diff --git a/handler.go b/handler.go index 996e32e..278866b 100644 --- a/handler.go +++ b/handler.go @@ -15,10 +15,10 @@ type ProxyHandler struct { logger *CondLogger dialer ContextDialer httptransport http.RoundTripper - resolver *Resolver } func NewProxyHandler(dialer ContextDialer, resolver *Resolver, logger *CondLogger) *ProxyHandler { + dialer = NewRetryDialer(dialer, resolver, logger) httptransport := &http.Transport{ MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -30,7 +30,6 @@ func NewProxyHandler(dialer ContextDialer, resolver *Resolver, logger *CondLogge logger: logger, dialer: dialer, httptransport: httptransport, - resolver: resolver, } } diff --git a/retry.go b/retry.go new file mode 100644 index 0000000..6444646 --- /dev/null +++ b/retry.go @@ -0,0 +1,39 @@ +package main + +import ( + "context" + "net" +) + +type RetryDialer struct { + dialer ContextDialer + resolver *Resolver + logger *CondLogger +} + +func NewRetryDialer(dialer ContextDialer, resolver *Resolver, logger *CondLogger) *RetryDialer { + return &RetryDialer{ + dialer: dialer, + resolver: resolver, + logger: logger, + } +} + +func (d *RetryDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.dialer.DialContext(ctx, network, address) + if err == UpstreamBlockedError { + d.logger.Info("Destination %s blocked by upstream. Rescuing it with resolve&tunnel workaround.", address) + host, port, err1 := net.SplitHostPort(address) + if err1 != nil { + return conn, err + } + + ips := d.resolver.Resolve(host) + if len(ips) == 0 { + return conn, err + } + + return d.dialer.DialContext(ctx, network, net.JoinHostPort(ips[0], port)) + } + return conn, err +} diff --git a/upstream.go b/upstream.go index 018d428..3a4a651 100644 --- a/upstream.go +++ b/upstream.go @@ -19,6 +19,8 @@ const ( PROXY_AUTHORIZATION_HEADER = "Proxy-Authorization" ) +var UpstreamBlockedError = errors.New("blocked by upstream") + type ContextDialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } @@ -104,6 +106,10 @@ func (d *ProxyDialer) DialContext(ctx context.Context, network, address string) } if proxyResp.StatusCode != http.StatusOK { + if proxyResp.StatusCode == http.StatusForbidden && + proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" { + return nil, UpstreamBlockedError + } return nil, errors.New("Bad response from upstream proxy server") } diff --git a/utils.go b/utils.go index 92083c1..2820a38 100644 --- a/utils.go +++ b/utils.go @@ -292,77 +292,3 @@ func copyBody(wr io.Writer, body io.Reader) { } } } - -func rewriteConnectReq(req *http.Request, resolver *Resolver) error { - origHost := req.Host - origAddr, origPort, err := net.SplitHostPort(origHost) - if err == nil { - origHost = origAddr - } - addrs := resolver.Resolve(origHost) - if len(addrs) == 0 { - return errors.New("Can't resolve host") - } - if origPort == "" { - req.URL.Host = addrs[0] - req.Host = addrs[0] - req.RequestURI = addrs[0] - } else { - req.URL.Host = net.JoinHostPort(addrs[0], origPort) - req.Host = net.JoinHostPort(addrs[0], origPort) - req.RequestURI = net.JoinHostPort(addrs[0], origPort) - } - return nil -} - -func rewriteReq(req *http.Request, resolver *Resolver) error { - origHost := req.URL.Host - origAddr, origPort, err := net.SplitHostPort(origHost) - if err == nil { - origHost = origAddr - } - addrs := resolver.Resolve(origHost) - if len(addrs) == 0 { - return errors.New("Can't resolve host") - } - if origPort == "" { - req.URL.Host = addrs[0] - req.Host = addrs[0] - } else { - req.URL.Host = net.JoinHostPort(addrs[0], origPort) - req.Host = net.JoinHostPort(addrs[0], origPort) - } - req.Header.Set("Host", origHost) - return nil -} - -func makeConnReq(uri string, resolver *Resolver) (*http.Request, error) { - parsed_url, err := url.Parse(uri) - if err != nil { - return nil, err - } - origAddr, origPort, err := net.SplitHostPort(parsed_url.Host) - if err != nil { - origAddr = parsed_url.Host - switch strings.ToLower(parsed_url.Scheme) { - case "https": - origPort = "443" - case "http": - origPort = "80" - default: - return nil, errors.New("Unknown scheme") - } - } - addrs := resolver.Resolve(origAddr) - if len(addrs) == 0 { - return nil, errors.New("Can't resolve host") - } - new_uri := net.JoinHostPort(addrs[0], origPort) - req, err := http.NewRequest("CONNECT", "http://"+new_uri, nil) - if err != nil { - return nil, err - } - req.RequestURI = new_uri - req.Host = new_uri - return req, nil -}