diff --git a/resolver.go b/resolver.go index d5c1420..30c2e7a 100644 --- a/resolver.go +++ b/resolver.go @@ -14,14 +14,23 @@ import ( ) func FromURL(u string) (*net.Resolver, error) { +begin: parsed, err := url.Parse(u) if err != nil { return nil, err } host := parsed.Hostname() port := parsed.Port() - switch strings.ToLower(parsed.Scheme) { - case "", "dns": + switch scheme := strings.ToLower(parsed.Scheme); scheme { + case "": + switch { + case strings.HasPrefix(u, "//"): + u = "dns:" + u + default: + u = "dns://" + u + } + goto begin + case "udp", "dns": if port == "" { port = "53" } @@ -31,12 +40,20 @@ func FromURL(u string) (*net.Resolver, error) { port = "53" } return NewTCPResolver(net.JoinHostPort(host, port)), nil - case "http", "https": + case "http", "https", "doh": if port == "" { - port = "443" + if scheme == "http" { + port = "80" + } else { + port = "443" + } + } + if scheme == "doh" { + parsed.Scheme = "https" + u = parsed.String() } return dns.NewDoHResolver(u, dns.DoHAddresses(net.JoinHostPort(host, port))) - case "tls": + case "tls", "dot": if port == "" { port = "853" } @@ -55,12 +72,7 @@ type FastResolver struct { upstreams []LookupNetIPer } -type lookupReply struct { - addrs []netip.Addr - err error -} - -func FastResolverFromURLs(urls ...string) (*FastResolver, error) { +func FastResolverFromURLs(urls ...string) (LookupNetIPer, error) { resolvers := make([]LookupNetIPer, 0, len(urls)) for i, u := range urls { res, err := FromURL(u) @@ -69,6 +81,9 @@ func FastResolverFromURLs(urls ...string) (*FastResolver, error) { } resolvers = append(resolvers, res) } + if len(resolvers) == 1 { + return resolvers[0], nil + } return NewFastResolver(resolvers...), nil } @@ -80,34 +95,38 @@ func NewFastResolver(resolvers ...LookupNetIPer) *FastResolver { func (r FastResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { ctx, cl := context.WithCancel(ctx) - drain := make(chan lookupReply, len(r.upstreams)) + defer cl() + errors := make(chan error) + success := make(chan []netip.Addr) for _, res := range r.upstreams { go func(res LookupNetIPer) { addrs, err := res.LookupNetIP(ctx, network, host) - drain <- lookupReply{addrs, err} + if err == nil { + select { + case success <- addrs: + case <-ctx.Done(): + } + } else { + select { + case errors <- err: + case <-ctx.Done(): + } + } }(res) } - i := 0 - var resAddrs []netip.Addr var resErr error - for ; i < len(r.upstreams); i++ { - pair := <-drain - if pair.err != nil { - resErr = multierror.Append(resErr, pair.err) - } else { - cl() - resAddrs = pair.addrs - resErr = nil - break + for _ = range r.upstreams { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case resAddrs := <-success: + return resAddrs, nil + case err := <-errors: + resErr = multierror.Append(resErr, err) } } - go func() { - for i = i + 1; i < len(r.upstreams); i++ { - <-drain - } - }() - return resAddrs, resErr + return nil, resErr } func NewPlainResolver(addr string) *net.Resolver {