mirror of
https://github.com/Snawoot/hola-proxy.git
synced 2026-04-02 17:48:13 +00:00
221 lines
7.9 KiB
Go
221 lines
7.9 KiB
Go
package main
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"crypto/tls"
|
|
"strings"
|
|
"net/url"
|
|
"bufio"
|
|
)
|
|
|
|
// AuthProvider returns the value placed in the Proxy-Authorization header of
// requests sent to the upstream proxy. The handler calls it each time it
// builds an upstream request, so the returned value may differ between calls.
type AuthProvider func() (header string)
|
|
|
|
type ProxyHandler struct {
|
|
auth AuthProvider
|
|
upstream string
|
|
logger *CondLogger
|
|
httptransport http.RoundTripper
|
|
resolver *Resolver
|
|
}
|
|
|
|
func NewProxyHandler(upstream string, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
|
|
proxyurl, err := url.Parse("https://" + upstream)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
httptransport := &http.Transport{
|
|
Proxy: http.ProxyURL(proxyurl),
|
|
}
|
|
return &ProxyHandler{
|
|
auth: auth,
|
|
upstream: upstream,
|
|
logger: logger,
|
|
httptransport: httptransport,
|
|
resolver: resolver,
|
|
}
|
|
}
|
|
|
|
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
|
s.logger.Info("Request: %v %v %v", req.RemoteAddr, req.Method, req.URL)
|
|
if strings.ToUpper(req.Method) == "CONNECT" {
|
|
req.Header.Set("Proxy-Authorization", s.auth())
|
|
rawreq, err := httputil.DumpRequest(req, false)
|
|
if err != nil {
|
|
s.logger.Error("Can't dump request: %v", err)
|
|
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
conn, err := tls.Dial("tcp", s.upstream, nil)
|
|
if err != nil {
|
|
s.logger.Error("Can't dial tls upstream: %v", err)
|
|
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write tls upstream: %v", err)
|
|
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
bufrd := bufio.NewReader(conn)
|
|
proxyResp, err := http.ReadResponse(bufrd, req)
|
|
responseBytes := make([]byte, 0)
|
|
if err != nil {
|
|
s.logger.Error("Can't read response from upstream: %v", err)
|
|
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
if (proxyResp.StatusCode == http.StatusForbidden &&
|
|
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host") {
|
|
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.",
|
|
req.URL.String())
|
|
conn.Close()
|
|
conn, err = tls.Dial("tcp", s.upstream, nil)
|
|
if err != nil {
|
|
s.logger.Error("Can't dial tls upstream: %v", err)
|
|
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
err = rewriteConnectReq(req, s.resolver)
|
|
if err != nil {
|
|
s.logger.Error("Can't rewrite request: %v", err)
|
|
http.Error(wr, "Can't rewrite request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
rawreq, err = httputil.DumpRequest(req, false)
|
|
if err != nil {
|
|
s.logger.Error("Can't dump request: %v", err)
|
|
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write tls upstream: %v", err)
|
|
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
} else {
|
|
defer conn.Close()
|
|
responseBytes, err = httputil.DumpResponse(proxyResp, false)
|
|
if err != nil {
|
|
s.logger.Error("Can't dump response: %v", err)
|
|
http.Error(wr, "Can't dump response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
buffered := bufrd.Buffered()
|
|
if buffered > 0 {
|
|
trailer := make([]byte, buffered)
|
|
bufrd.Read(trailer)
|
|
responseBytes = append(responseBytes, trailer...)
|
|
}
|
|
}
|
|
bufrd = nil
|
|
|
|
// Upgrade client connection
|
|
localconn, _, err := hijack(wr)
|
|
if err != nil {
|
|
s.logger.Error("Can't hijack client connection: %v", err)
|
|
http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer localconn.Close()
|
|
|
|
if len(responseBytes) > 0 {
|
|
_, err = localconn.Write(responseBytes)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
proxy(req.Context(), localconn, conn)
|
|
} else {
|
|
delHopHeaders(req.Header)
|
|
orig_req := req.Clone(req.Context())
|
|
req.RequestURI = ""
|
|
req.Header.Set("Proxy-Authorization", s.auth())
|
|
resp, err := s.httptransport.RoundTrip(req)
|
|
if err != nil {
|
|
s.logger.Error("HTTP fetch error: %v", err)
|
|
http.Error(wr, "Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if (resp.StatusCode == http.StatusForbidden &&
|
|
resp.Header.Get("X-Hola-Error") == "Forbidden Host") {
|
|
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&tunnel workaround.",
|
|
req.URL.String())
|
|
resp.Body.Close()
|
|
|
|
// Prepare tunnel request
|
|
proxyReq, err := makeConnReq(orig_req.RequestURI, s.resolver)
|
|
if err != nil {
|
|
s.logger.Error("Can't rewrite request: %v", err)
|
|
http.Error(wr, "Can't rewrite request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
proxyReq.Header.Set("Proxy-Authorization", s.auth())
|
|
rawreq, _ := httputil.DumpRequest(proxyReq, false)
|
|
|
|
// Prepare upstream TLS conn
|
|
conn, err := tls.Dial("tcp", s.upstream, nil)
|
|
if err != nil {
|
|
s.logger.Error("Can't dial tls upstream: %v", err)
|
|
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Send proxy request
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write tls upstream: %v", err)
|
|
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// Read proxy response
|
|
bufrd := bufio.NewReader(conn)
|
|
proxyResp, err := http.ReadResponse(bufrd, proxyReq)
|
|
if err != nil {
|
|
s.logger.Error("Can't read response from upstream: %v", err)
|
|
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
if proxyResp.StatusCode != http.StatusOK {
|
|
delHopHeaders(proxyResp.Header)
|
|
copyHeader(wr.Header(), proxyResp.Header)
|
|
wr.WriteHeader(proxyResp.StatusCode)
|
|
}
|
|
|
|
// Send tunneled request
|
|
orig_req.RequestURI = ""
|
|
orig_req.Header.Set("Connection", "close")
|
|
rawreq, _ = httputil.DumpRequest(orig_req, false)
|
|
_, err = conn.Write(rawreq)
|
|
if err != nil {
|
|
s.logger.Error("Can't write tls upstream: %v", err)
|
|
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// Read tunneled response
|
|
resp, err = http.ReadResponse(bufrd, orig_req)
|
|
if err != nil {
|
|
s.logger.Error("Can't read response from upstream: %v", err)
|
|
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
|
return
|
|
}
|
|
}
|
|
defer resp.Body.Close()
|
|
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
|
|
delHopHeaders(resp.Header)
|
|
copyHeader(wr.Header(), resp.Header)
|
|
wr.WriteHeader(resp.StatusCode)
|
|
io.Copy(wr, resp.Body)
|
|
}
|
|
}
|