skip upstream agent resolve

This commit is contained in:
Vladislav Yarmak
2021-03-16 00:46:49 +02:00
parent 3b09f31616
commit 3cb79059b2
3 changed files with 92 additions and 30 deletions

View File

@@ -2,36 +2,51 @@ package main
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
)
type AuthProvider func() string
type ProxyHandler struct {
auth AuthProvider
upstream string
upstreamAddr string
tlsName string
logger *CondLogger
dialer *net.Dialer
httptransport http.RoundTripper
resolver *Resolver
}
func NewProxyHandler(upstream string, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
proxyurl, err := url.Parse("https://" + upstream)
if err != nil {
panic(err)
func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port))
httptransport := &http.Transport{
Proxy: http.ProxyURL(proxyurl),
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
Proxy: http.ProxyURL(upstream.URL()),
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", netaddr)
},
}
return &ProxyHandler{
auth: auth,
upstream: upstream,
upstreamAddr: netaddr,
tlsName: upstream.TLSName,
logger: logger,
dialer: dialer,
httptransport: httptransport,
resolver: resolver,
}
@@ -48,17 +63,25 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
return
}
conn, err := tls.Dial("tcp", s.upstream, nil)
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial tls upstream: %v", err)
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
s.logger.Error("Can't dial upstream: %v", err)
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
return
}
defer conn.Close()
if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}
_, err = conn.Write(rawreq)
if err != nil {
s.logger.Error("Can't write tls upstream: %v", err)
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
s.logger.Error("Can't write upstream: %v", err)
http.Error(wr, "Can't write upstream", http.StatusBadGateway)
return
}
bufrd := bufio.NewReader(conn)
@@ -74,14 +97,22 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.",
req.URL.String())
conn.Close()
conn, err = tls.Dial("tcp", s.upstream, nil)
conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial tls upstream: %v", err)
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
s.logger.Error("Can't dial upstream: %v", err)
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
return
}
defer conn.Close()
if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}
err = rewriteConnectReq(req, s.resolver)
if err != nil {
s.logger.Error("Can't rewrite request: %v", err)
@@ -101,7 +132,6 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
return
}
} else {
defer conn.Close()
responseBytes, err = httputil.DumpResponse(proxyResp, false)
if err != nil {
s.logger.Error("Can't dump response: %v", err)
@@ -160,8 +190,8 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
proxyReq.Header.Set("Proxy-Authorization", s.auth())
rawreq, _ := httputil.DumpRequest(proxyReq, false)
// Prepare upstream TLS conn
conn, err := tls.Dial("tcp", s.upstream, nil)
// Prepare upstream conn
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial tls upstream: %v", err)
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
@@ -169,6 +199,13 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
}
defer conn.Close()
if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}
// Send proxy request
_, err = conn.Write(rawreq)
if err != nil {

View File

@@ -106,9 +106,9 @@ func (c *FallbackConfig) ShuffleAgents() {
func (c *FallbackConfig) Clone() *FallbackConfig {
return &FallbackConfig{
Agents: append([]FallbackAgent(nil), c.Agents...),
Agents: append([]FallbackAgent(nil), c.Agents...),
UpdatedAt: c.UpdatedAt,
TTL: c.TTL,
TTL: c.TTL,
}
}
@@ -338,7 +338,7 @@ func httpClientWithProxy(agent *FallbackAgent) *http.Client {
t.Proxy = http.ProxyURL(agent.ToProxy())
addr := net.JoinHostPort(agent.IP, fmt.Sprintf("%d", agent.Port))
t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp4", addr)
return dialer.DialContext(ctx, "tcp", addr)
}
}
return &http.Client{

View File

@@ -18,6 +18,26 @@ import (
"time"
)
type Endpoint struct {
Host string
Port uint16
TLSName string
}
func (e *Endpoint) URL() *url.URL {
if e.TLSName == "" {
return &url.URL{
Scheme: "http",
Host: net.JoinHostPort(e.Host, fmt.Sprintf("%d", e.Port)),
}
} else {
return &url.URL{
Scheme: "https",
Host: net.JoinHostPort(e.TLSName, fmt.Sprintf("%d", e.Port)),
}
}
}
func basic_auth_header(login, password string) string {
return "basic " + base64.StdEncoding.EncodeToString(
[]byte(login+":"+password))
@@ -123,14 +143,15 @@ func print_proxies(country string, proxy_type string, limit uint, timeout time.D
return 0
}
func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (string, error) {
var hostname string
for k := range tunnels.IPList {
func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (*Endpoint, error) {
var hostname, ip string
for k, v := range tunnels.IPList {
hostname = k
ip = v
break
}
if hostname == "" {
return "", errors.New("No tunnels found in API response")
if hostname == "" || ip == "" {
return nil, errors.New("No tunnels found in API response")
}
var port uint16
@@ -157,10 +178,14 @@ func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_po
port = tunnels.Port.Peer
}
} else {
return "", errors.New("Unsupported port type")
return nil, errors.New("Unsupported port type")
}
}
return net.JoinHostPort(hostname, strconv.FormatUint(uint64(port), 10)), nil
return &Endpoint{
Host: ip,
Port: port,
TLSName: hostname,
}, nil
}
// Hop-by-hop headers. These are removed when sent to the backend.