mirror of
https://github.com/Snawoot/hola-proxy.git
synced 2026-04-03 12:48:15 +00:00
skip upstream agent resolve
This commit is contained in:
77
handler.go
77
handler.go
@@ -2,36 +2,51 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuthProvider func() string
|
||||
|
||||
type ProxyHandler struct {
|
||||
auth AuthProvider
|
||||
upstream string
|
||||
upstreamAddr string
|
||||
tlsName string
|
||||
logger *CondLogger
|
||||
dialer *net.Dialer
|
||||
httptransport http.RoundTripper
|
||||
resolver *Resolver
|
||||
}
|
||||
|
||||
func NewProxyHandler(upstream string, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
|
||||
proxyurl, err := url.Parse("https://" + upstream)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port))
|
||||
httptransport := &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyurl),
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
Proxy: http.ProxyURL(upstream.URL()),
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", netaddr)
|
||||
},
|
||||
}
|
||||
return &ProxyHandler{
|
||||
auth: auth,
|
||||
upstream: upstream,
|
||||
upstreamAddr: netaddr,
|
||||
tlsName: upstream.TLSName,
|
||||
logger: logger,
|
||||
dialer: dialer,
|
||||
httptransport: httptransport,
|
||||
resolver: resolver,
|
||||
}
|
||||
@@ -48,17 +63,25 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", s.upstream, nil)
|
||||
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dial tls upstream: %v", err)
|
||||
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
||||
s.logger.Error("Can't dial upstream: %v", err)
|
||||
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if s.tlsName != "" {
|
||||
conn = tls.Client(conn, &tls.Config{
|
||||
ServerName: s.tlsName,
|
||||
})
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
_, err = conn.Write(rawreq)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't write tls upstream: %v", err)
|
||||
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
||||
s.logger.Error("Can't write upstream: %v", err)
|
||||
http.Error(wr, "Can't write upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
bufrd := bufio.NewReader(conn)
|
||||
@@ -74,14 +97,22 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
|
||||
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.",
|
||||
req.URL.String())
|
||||
conn.Close()
|
||||
conn, err = tls.Dial("tcp", s.upstream, nil)
|
||||
|
||||
conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dial tls upstream: %v", err)
|
||||
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
||||
s.logger.Error("Can't dial upstream: %v", err)
|
||||
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if s.tlsName != "" {
|
||||
conn = tls.Client(conn, &tls.Config{
|
||||
ServerName: s.tlsName,
|
||||
})
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
err = rewriteConnectReq(req, s.resolver)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't rewrite request: %v", err)
|
||||
@@ -101,7 +132,6 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
defer conn.Close()
|
||||
responseBytes, err = httputil.DumpResponse(proxyResp, false)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dump response: %v", err)
|
||||
@@ -160,8 +190,8 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
proxyReq.Header.Set("Proxy-Authorization", s.auth())
|
||||
rawreq, _ := httputil.DumpRequest(proxyReq, false)
|
||||
|
||||
// Prepare upstream TLS conn
|
||||
conn, err := tls.Dial("tcp", s.upstream, nil)
|
||||
// Prepare upstream conn
|
||||
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dial tls upstream: %v", err)
|
||||
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
||||
@@ -169,6 +199,13 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if s.tlsName != "" {
|
||||
conn = tls.Client(conn, &tls.Config{
|
||||
ServerName: s.tlsName,
|
||||
})
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
// Send proxy request
|
||||
_, err = conn.Write(rawreq)
|
||||
if err != nil {
|
||||
|
||||
@@ -106,9 +106,9 @@ func (c *FallbackConfig) ShuffleAgents() {
|
||||
|
||||
func (c *FallbackConfig) Clone() *FallbackConfig {
|
||||
return &FallbackConfig{
|
||||
Agents: append([]FallbackAgent(nil), c.Agents...),
|
||||
Agents: append([]FallbackAgent(nil), c.Agents...),
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
TTL: c.TTL,
|
||||
TTL: c.TTL,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,7 +338,7 @@ func httpClientWithProxy(agent *FallbackAgent) *http.Client {
|
||||
t.Proxy = http.ProxyURL(agent.ToProxy())
|
||||
addr := net.JoinHostPort(agent.IP, fmt.Sprintf("%d", agent.Port))
|
||||
t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp4", addr)
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
|
||||
39
utils.go
39
utils.go
@@ -18,6 +18,26 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Endpoint struct {
|
||||
Host string
|
||||
Port uint16
|
||||
TLSName string
|
||||
}
|
||||
|
||||
func (e *Endpoint) URL() *url.URL {
|
||||
if e.TLSName == "" {
|
||||
return &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort(e.Host, fmt.Sprintf("%d", e.Port)),
|
||||
}
|
||||
} else {
|
||||
return &url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(e.TLSName, fmt.Sprintf("%d", e.Port)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func basic_auth_header(login, password string) string {
|
||||
return "basic " + base64.StdEncoding.EncodeToString(
|
||||
[]byte(login+":"+password))
|
||||
@@ -123,14 +143,15 @@ func print_proxies(country string, proxy_type string, limit uint, timeout time.D
|
||||
return 0
|
||||
}
|
||||
|
||||
func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (string, error) {
|
||||
var hostname string
|
||||
for k := range tunnels.IPList {
|
||||
func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_port_field string) (*Endpoint, error) {
|
||||
var hostname, ip string
|
||||
for k, v := range tunnels.IPList {
|
||||
hostname = k
|
||||
ip = v
|
||||
break
|
||||
}
|
||||
if hostname == "" {
|
||||
return "", errors.New("No tunnels found in API response")
|
||||
if hostname == "" || ip == "" {
|
||||
return nil, errors.New("No tunnels found in API response")
|
||||
}
|
||||
|
||||
var port uint16
|
||||
@@ -157,10 +178,14 @@ func get_endpoint(tunnels *ZGetTunnelsResponse, typ string, trial bool, force_po
|
||||
port = tunnels.Port.Peer
|
||||
}
|
||||
} else {
|
||||
return "", errors.New("Unsupported port type")
|
||||
return nil, errors.New("Unsupported port type")
|
||||
}
|
||||
}
|
||||
return net.JoinHostPort(hostname, strconv.FormatUint(uint64(port), 10)), nil
|
||||
return &Endpoint{
|
||||
Host: ip,
|
||||
Port: port,
|
||||
TLSName: hostname,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
|
||||
Reference in New Issue
Block a user