Files
windscribe-proxy/resolver.go
Vladislav Yarmak d0b50bc52b resolver: WIP
2021-06-25 17:58:15 +03:00

192 lines
4.1 KiB
Go

package main
import (
"context"
"errors"
"net"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
"github.com/ReneKroon/ttlcache/v2"
)
// Resolver answers A/AAAA lookups by sending DNS queries to a single
// upstream server created by NewResolver.
type Resolver struct {
	// upstream is the DNS server every query is exchanged with.
	upstream upstream.Upstream
}
// DOT is the byte value of '.', used by absDomain to detect names that are
// already fully qualified.
const DOT = 0x2e

// DNS_CACHE_SIZE_LIMIT caps the number of entries held by each of the
// ResolvingDialer's per-address-family caches.
const DNS_CACHE_SIZE_LIMIT = 1 << 10
// NewResolver creates a Resolver bound to the upstream DNS server described
// by address (parsed by upstream.AddressToUpstream). timeout is forwarded as
// the upstream's Options.Timeout. Any parsing/setup error from the upstream
// package is returned unchanged with a nil Resolver.
func NewResolver(address string, timeout time.Duration) (*Resolver, error) {
	up, err := upstream.AddressToUpstream(address, upstream.Options{Timeout: timeout})
	if err != nil {
		return nil, err
	}
	res := &Resolver{upstream: up}
	return res, nil
}
// ResolveA queries the upstream for the IPv4 (A) records of domain and
// returns their addresses as strings, in answer-section order. The result is
// never nil: an empty domain, a failed exchange, or an answer without A
// records all produce an empty slice. Non-A records in the answer (such as
// CNAMEs) are ignored.
func (r *Resolver) ResolveA(domain string) []string {
	addrs := []string{}
	if domain == "" {
		return addrs
	}

	query := &dns.Msg{}
	query.Id = dns.Id()
	query.RecursionDesired = true
	query.Question = []dns.Question{{
		Name:   absDomain(domain),
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}}

	reply, err := r.upstream.Exchange(query)
	if err != nil {
		return addrs
	}
	for _, answer := range reply.Answer {
		rec, isA := answer.(*dns.A)
		if !isA {
			continue
		}
		addrs = append(addrs, rec.A.String())
	}
	return addrs
}
// ResolveAAAA queries the upstream for the IPv6 (AAAA) records of domain and
// returns their addresses as strings, in answer-section order. The result is
// never nil: an empty domain, a failed exchange, or an answer without AAAA
// records all produce an empty slice. Non-AAAA records in the answer are
// ignored.
func (r *Resolver) ResolveAAAA(domain string) []string {
	addrs := []string{}
	if domain == "" {
		return addrs
	}

	query := &dns.Msg{}
	query.Id = dns.Id()
	query.RecursionDesired = true
	query.Question = []dns.Question{{
		Name:   absDomain(domain),
		Qtype:  dns.TypeAAAA,
		Qclass: dns.ClassINET,
	}}

	reply, err := r.upstream.Exchange(query)
	if err != nil {
		return addrs
	}
	for _, answer := range reply.Answer {
		rec, isAAAA := answer.(*dns.AAAA)
		if !isAAAA {
			continue
		}
		addrs = append(addrs, rec.AAAA.String())
	}
	return addrs
}
// Resolve prefers IPv4: it returns the A-record addresses of domain when
// there are any, and otherwise whatever ResolveAAAA yields.
func (r *Resolver) Resolve(domain string) []string {
	if v4 := r.ResolveA(domain); len(v4) > 0 {
		return v4
	}
	return r.ResolveAAAA(domain)
}
// ResolvingDialer is a ContextDialer wrapper that owns its own DNS upstream
// and per-address-family lookup caches. Connections are ultimately made by
// the wrapped next dialer.
type ResolvingDialer struct {
	// next performs the actual dialing.
	next ContextDialer
	// upstream is the DNS server used by the cache loaders.
	upstream upstream.Upstream
	// cache4 / cache6 memoize A / AAAA lookups; NewResolvingDialer installs
	// resolveA / resolveAAAA as their loaders.
	cache4, cache6 *ttlcache.Cache
}
// NewResolvingDialer wraps next with a dialer that looks names up via the
// upstream DNS server at resolverAddress (parsed by
// upstream.AddressToUpstream, with timeout as its Options.Timeout).
//
// Two caches are created: cache4 loads through resolveA and cache6 through
// resolveAAAA. Each is limited to DNS_CACHE_SIZE_LIMIT entries and does not
// extend an entry's TTL when it is read.
func NewResolvingDialer(resolverAddress string, timeout time.Duration, next ContextDialer) (*ResolvingDialer, error) {
	up, err := upstream.AddressToUpstream(resolverAddress, upstream.Options{Timeout: timeout})
	if err != nil {
		return nil, err
	}

	dialer := &ResolvingDialer{
		next:     next,
		upstream: up,
		cache4:   ttlcache.NewCache(),
		cache6:   ttlcache.NewCache(),
	}
	dialer.cache4.SetLoaderFunction(dialer.resolveA)
	dialer.cache6.SetLoaderFunction(dialer.resolveAAAA)

	// Both caches share the same size limit and TTL policy.
	for _, c := range []*ttlcache.Cache{dialer.cache4, dialer.cache6} {
		c.SetCacheSizeLimit(DNS_CACHE_SIZE_LIMIT)
		c.SkipTTLExtensionOnHit(true)
	}
	return dialer, nil
}
// resolveA is the loader installed on cache4: it delegates to resolve with
// an A-record query and returns the address as the cached value.
func (d *ResolvingDialer) resolveA(domain string) (interface{}, time.Duration, error) {
	addr, ttl, err := d.resolve(domain, dns.TypeA)
	return addr, ttl, err
}
// resolveAAAA is the loader installed on cache6: it delegates to resolve
// with an AAAA-record query and returns the address as the cached value.
func (d *ResolvingDialer) resolveAAAA(domain string) (interface{}, time.Duration, error) {
	addr, ttl, err := d.resolve(domain, dns.TypeAAAA)
	return addr, ttl, err
}
// resolve sends one recursive DNS query of type typ (expected to be
// dns.TypeA or dns.TypeAAAA) for domain to the upstream. It returns the first
// answer record of that same type as a textual IP address, together with the
// record's TTL. It is the shared body of the cache loaders resolveA and
// resolveAAAA.
//
// Records of other types in the answer (for example a CNAME preceding the
// address) are skipped. An empty domain, a failed exchange, or an answer
// without a matching record is reported as an error.
func (d *ResolvingDialer) resolve(domain string, typ uint16) (string, time.Duration, error) {
	if len(domain) == 0 {
		return "", 0, errors.New("empty domain name")
	}
	req := dns.Msg{}
	req.Id = dns.Id()
	req.RecursionDesired = true
	req.Question = []dns.Question{
		{Name: absDomain(domain), Qtype: typ, Qclass: dns.ClassINET},
	}
	reply, err := d.upstream.Exchange(&req)
	if err != nil {
		return "", 0, err
	}
	if reply == nil {
		return "", 0, errors.New("no data in DNS response")
	}

	// NOTE(review): a record TTL of 0 would reach ttlcache as a zero
	// duration, which ttlcache v2 appears to treat as "use the cache-wide TTL"
	// (never set here, i.e. no expiry). Clamp to one second so such records
	// are refreshed promptly instead of being pinned; confirm against the
	// ttlcache version in use.
	ttlOf := func(seconds uint32) time.Duration {
		if seconds == 0 {
			return time.Second
		}
		return time.Duration(seconds) * time.Second
	}

	for _, rr := range reply.Answer {
		// Only accept a record whose type matches the query. Previously only
		// *dns.A was recognised, so every AAAA lookup failed.
		switch rec := rr.(type) {
		case *dns.A:
			if typ == dns.TypeA {
				return rec.A.String(), ttlOf(rec.Hdr.Ttl), nil
			}
		case *dns.AAAA:
			if typ == dns.TypeAAAA {
				return rec.AAAA.String(), ttlOf(rec.Hdr.Ttl), nil
			}
		}
	}
	return "", 0, errors.New("no data in DNS response")
}

// DialContext connects to address on the named network through the next
// dialer. When the host part of address is a domain name, it is first
// resolved via the dialer's caches and next receives a numeric host:port:
//   - networks ending in '4' (e.g. "tcp4") use only A records (cache4);
//   - networks ending in '6' (e.g. "tcp6") use only AAAA records (cache6);
//   - any other network tries A first and falls back to AAAA, matching
//     Resolver.Resolve.
//
// Addresses that cannot be split into host and port are rejected. Numeric
// hosts, empty hosts, and an empty network are passed to next unchanged.
func (d *ResolvingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	name, port, err := net.SplitHostPort(address)
	if err != nil {
		return nil, err
	}
	if net.ParseIP(name) != nil || len(name) == 0 || len(network) == 0 {
		// Nothing to resolve (numeric or empty host), or no family hint.
		return d.next.DialContext(ctx, network, address)
	}

	var resolved interface{}
	switch network[len(network)-1] {
	case '4':
		resolved, err = d.cache4.Get(name)
	case '6':
		resolved, err = d.cache6.Get(name)
	default:
		resolved, err = d.cache4.Get(name)
		if err != nil {
			resolved, err = d.cache6.Get(name)
		}
	}
	if err != nil {
		return nil, err
	}

	// The cache loaders store textual IP addresses; reject anything else
	// rather than panicking on the type assertion.
	ip, ok := resolved.(string)
	if !ok || ip == "" {
		return nil, errors.New("no usable address resolved for " + name)
	}
	return d.next.DialContext(ctx, network, net.JoinHostPort(ip, port))
}
// Dial is DialContext with a background context, for callers that use the
// context-free net.Dialer-style signature.
func (d *ResolvingDialer) Dial(network, address string) (net.Conn, error) {
	ctx := context.Background()
	return d.DialContext(ctx, network, address)
}
// absDomain returns domain in fully-qualified form by appending a trailing
// dot when it does not already end with one. The empty string is returned
// unchanged.
func absDomain(domain string) string {
	n := len(domain)
	switch {
	case n == 0:
		return ""
	case domain[n-1] == DOT:
		return domain
	default:
		return domain + "."
	}
}