mirror of
https://github.com/Snawoot/hola-proxy.git
synced 2026-04-02 13:38:14 +00:00
resolver fixes
This commit is contained in:
79
resolver.go
79
resolver.go
@@ -14,14 +14,23 @@ import (
|
||||
)
|
||||
|
||||
func FromURL(u string) (*net.Resolver, error) {
|
||||
begin:
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
host := parsed.Hostname()
|
||||
port := parsed.Port()
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "", "dns":
|
||||
switch scheme := strings.ToLower(parsed.Scheme); scheme {
|
||||
case "":
|
||||
switch {
|
||||
case strings.HasPrefix(u, "//"):
|
||||
u = "dns:" + u
|
||||
default:
|
||||
u = "dns://" + u
|
||||
}
|
||||
goto begin
|
||||
case "udp", "dns":
|
||||
if port == "" {
|
||||
port = "53"
|
||||
}
|
||||
@@ -31,12 +40,20 @@ func FromURL(u string) (*net.Resolver, error) {
|
||||
port = "53"
|
||||
}
|
||||
return NewTCPResolver(net.JoinHostPort(host, port)), nil
|
||||
case "http", "https":
|
||||
case "http", "https", "doh":
|
||||
if port == "" {
|
||||
port = "443"
|
||||
if scheme == "http" {
|
||||
port = "80"
|
||||
} else {
|
||||
port = "443"
|
||||
}
|
||||
}
|
||||
if scheme == "doh" {
|
||||
parsed.Scheme = "https"
|
||||
u = parsed.String()
|
||||
}
|
||||
return dns.NewDoHResolver(u, dns.DoHAddresses(net.JoinHostPort(host, port)))
|
||||
case "tls":
|
||||
case "tls", "dot":
|
||||
if port == "" {
|
||||
port = "853"
|
||||
}
|
||||
@@ -55,12 +72,7 @@ type FastResolver struct {
|
||||
upstreams []LookupNetIPer
|
||||
}
|
||||
|
||||
type lookupReply struct {
|
||||
addrs []netip.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
func FastResolverFromURLs(urls ...string) (*FastResolver, error) {
|
||||
func FastResolverFromURLs(urls ...string) (LookupNetIPer, error) {
|
||||
resolvers := make([]LookupNetIPer, 0, len(urls))
|
||||
for i, u := range urls {
|
||||
res, err := FromURL(u)
|
||||
@@ -69,6 +81,9 @@ func FastResolverFromURLs(urls ...string) (*FastResolver, error) {
|
||||
}
|
||||
resolvers = append(resolvers, res)
|
||||
}
|
||||
if len(resolvers) == 1 {
|
||||
return resolvers[0], nil
|
||||
}
|
||||
return NewFastResolver(resolvers...), nil
|
||||
}
|
||||
|
||||
@@ -80,34 +95,38 @@ func NewFastResolver(resolvers ...LookupNetIPer) *FastResolver {
|
||||
|
||||
func (r FastResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||
ctx, cl := context.WithCancel(ctx)
|
||||
drain := make(chan lookupReply, len(r.upstreams))
|
||||
defer cl()
|
||||
errors := make(chan error)
|
||||
success := make(chan []netip.Addr)
|
||||
for _, res := range r.upstreams {
|
||||
go func(res LookupNetIPer) {
|
||||
addrs, err := res.LookupNetIP(ctx, network, host)
|
||||
drain <- lookupReply{addrs, err}
|
||||
if err == nil {
|
||||
select {
|
||||
case success <- addrs:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case errors <- err:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}(res)
|
||||
}
|
||||
|
||||
i := 0
|
||||
var resAddrs []netip.Addr
|
||||
var resErr error
|
||||
for ; i < len(r.upstreams); i++ {
|
||||
pair := <-drain
|
||||
if pair.err != nil {
|
||||
resErr = multierror.Append(resErr, pair.err)
|
||||
} else {
|
||||
cl()
|
||||
resAddrs = pair.addrs
|
||||
resErr = nil
|
||||
break
|
||||
for _ = range r.upstreams {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resAddrs := <-success:
|
||||
return resAddrs, nil
|
||||
case err := <-errors:
|
||||
resErr = multierror.Append(resErr, err)
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
for i = i + 1; i < len(r.upstreams); i++ {
|
||||
<-drain
|
||||
}
|
||||
}()
|
||||
return resAddrs, resErr
|
||||
return nil, resErr
|
||||
}
|
||||
|
||||
func NewPlainResolver(addr string) *net.Resolver {
|
||||
|
||||
Reference in New Issue
Block a user