mirror of
https://github.com/Snawoot/hola-proxy.git
synced 2026-04-02 16:28:11 +00:00
258 lines
7.3 KiB
Go
258 lines
7.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// AuthProvider returns the value that ProxyHandler places in the
// Proxy-Authorization header of every request it sends to the upstream proxy.
// It is invoked anew for each outgoing request.
type AuthProvider func() string
|
|
|
|
type ProxyHandler struct {
|
|
auth AuthProvider
|
|
upstreamAddr string
|
|
tlsName string
|
|
logger *CondLogger
|
|
dialer *net.Dialer
|
|
httptransport http.RoundTripper
|
|
resolver *Resolver
|
|
}
|
|
|
|
func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
|
|
dialer := &net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}
|
|
netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port))
|
|
httptransport := &http.Transport{
|
|
MaxIdleConns: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
Proxy: http.ProxyURL(upstream.URL()),
|
|
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
return dialer.DialContext(ctx, "tcp", netaddr)
|
|
},
|
|
}
|
|
return &ProxyHandler{
|
|
auth: auth,
|
|
upstreamAddr: netaddr,
|
|
tlsName: upstream.TLSName,
|
|
logger: logger,
|
|
dialer: dialer,
|
|
httptransport: httptransport,
|
|
resolver: resolver,
|
|
}
|
|
}
|
|
|
|
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
|
s.logger.Info("Request: %v %v %v", req.RemoteAddr, req.Method, req.URL)
|
|
if strings.ToUpper(req.Method) == "CONNECT" {
|
|
req.Header.Set("Proxy-Authorization", s.auth())
|
|
rawreq, err := httputil.DumpRequest(req, false)
|
|
if err != nil {
|
|
s.logger.Error("Can't dump request: %v", err)
|
|
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
|
if err != nil {
|
|
s.logger.Error("Can't dial upstream: %v", err)
|
|
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
if s.tlsName != "" {
|
|
conn = tls.Client(conn, &tls.Config{
|
|
ServerName: s.tlsName,
|
|
})
|
|
defer conn.Close()
|
|
}
|
|
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write upstream: %v", err)
|
|
http.Error(wr, "Can't write upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
bufrd := bufio.NewReader(conn)
|
|
proxyResp, err := http.ReadResponse(bufrd, req)
|
|
responseBytes := make([]byte, 0)
|
|
if err != nil {
|
|
s.logger.Error("Can't read response from upstream: %v", err)
|
|
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
if proxyResp.StatusCode == http.StatusForbidden &&
|
|
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
|
|
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.",
|
|
req.URL.String())
|
|
|
|
conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
|
if err != nil {
|
|
s.logger.Error("Can't dial upstream: %v", err)
|
|
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
if s.tlsName != "" {
|
|
conn = tls.Client(conn, &tls.Config{
|
|
ServerName: s.tlsName,
|
|
})
|
|
defer conn.Close()
|
|
}
|
|
|
|
err = rewriteConnectReq(req, s.resolver)
|
|
if err != nil {
|
|
s.logger.Error("Can't rewrite request: %v", err)
|
|
http.Error(wr, "Can't rewrite request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
rawreq, err = httputil.DumpRequest(req, false)
|
|
if err != nil {
|
|
s.logger.Error("Can't dump request: %v", err)
|
|
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write tls upstream: %v", err)
|
|
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
} else {
|
|
responseBytes, err = httputil.DumpResponse(proxyResp, false)
|
|
if err != nil {
|
|
s.logger.Error("Can't dump response: %v", err)
|
|
http.Error(wr, "Can't dump response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
buffered := bufrd.Buffered()
|
|
if buffered > 0 {
|
|
trailer := make([]byte, buffered)
|
|
bufrd.Read(trailer)
|
|
responseBytes = append(responseBytes, trailer...)
|
|
}
|
|
}
|
|
bufrd = nil
|
|
|
|
// Upgrade client connection
|
|
localconn, _, err := hijack(wr)
|
|
if err != nil {
|
|
s.logger.Error("Can't hijack client connection: %v", err)
|
|
http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer localconn.Close()
|
|
|
|
if len(responseBytes) > 0 {
|
|
_, err = localconn.Write(responseBytes)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
proxy(req.Context(), localconn, conn)
|
|
} else {
|
|
delHopHeaders(req.Header)
|
|
orig_req := req.Clone(req.Context())
|
|
req.RequestURI = ""
|
|
req.Header.Set("Proxy-Authorization", s.auth())
|
|
resp, err := s.httptransport.RoundTrip(req)
|
|
if err != nil {
|
|
s.logger.Error("HTTP fetch error: %v", err)
|
|
http.Error(wr, "Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if resp.StatusCode == http.StatusForbidden &&
|
|
resp.Header.Get("X-Hola-Error") == "Forbidden Host" {
|
|
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&tunnel workaround.",
|
|
req.URL.String())
|
|
resp.Body.Close()
|
|
|
|
// Prepare tunnel request
|
|
proxyReq, err := makeConnReq(orig_req.RequestURI, s.resolver)
|
|
if err != nil {
|
|
s.logger.Error("Can't rewrite request: %v", err)
|
|
http.Error(wr, "Can't rewrite request", http.StatusBadGateway)
|
|
return
|
|
}
|
|
proxyReq.Header.Set("Proxy-Authorization", s.auth())
|
|
rawreq, _ := httputil.DumpRequest(proxyReq, false)
|
|
|
|
// Prepare upstream conn
|
|
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
|
if err != nil {
|
|
s.logger.Error("Can't dial tls upstream: %v", err)
|
|
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
if s.tlsName != "" {
|
|
conn = tls.Client(conn, &tls.Config{
|
|
ServerName: s.tlsName,
|
|
})
|
|
defer conn.Close()
|
|
}
|
|
|
|
// Send proxy request
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write tls upstream: %v", err)
|
|
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// Read proxy response
|
|
bufrd := bufio.NewReader(conn)
|
|
proxyResp, err := http.ReadResponse(bufrd, proxyReq)
|
|
if err != nil {
|
|
s.logger.Error("Can't read response from upstream: %v", err)
|
|
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
if proxyResp.StatusCode != http.StatusOK {
|
|
delHopHeaders(proxyResp.Header)
|
|
copyHeader(wr.Header(), proxyResp.Header)
|
|
wr.WriteHeader(proxyResp.StatusCode)
|
|
}
|
|
|
|
// Send tunneled request
|
|
orig_req.RequestURI = ""
|
|
orig_req.Header.Set("Connection", "close")
|
|
rawreq, _ = httputil.DumpRequest(orig_req, false)
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write tls upstream: %v", err)
|
|
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// Read tunneled response
|
|
resp, err = http.ReadResponse(bufrd, orig_req)
|
|
if err != nil {
|
|
s.logger.Error("Can't read response from upstream: %v", err)
|
|
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
}
|
|
defer resp.Body.Close()
|
|
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
|
|
delHopHeaders(resp.Header)
|
|
copyHeader(wr.Header(), resp.Header)
|
|
wr.WriteHeader(resp.StatusCode)
|
|
io.Copy(wr, resp.Body)
|
|
}
|
|
}
|