mirror of
https://github.com/Snawoot/hola-proxy.git
synced 2026-04-02 10:28:12 +00:00
rework proxy handler on new stack
This commit is contained in:
275
handler.go
275
handler.go
@@ -1,50 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const BAD_REQ_MSG = "Bad Request\n"
|
||||
|
||||
type AuthProvider func() string
|
||||
|
||||
type ProxyHandler struct {
|
||||
auth AuthProvider
|
||||
upstreamAddr string
|
||||
tlsName string
|
||||
logger *CondLogger
|
||||
dialer *net.Dialer
|
||||
dialer ContextDialer
|
||||
httptransport http.RoundTripper
|
||||
resolver *Resolver
|
||||
}
|
||||
|
||||
func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port))
|
||||
func NewProxyHandler(dialer ContextDialer, resolver *Resolver, logger *CondLogger) *ProxyHandler {
|
||||
httptransport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
Proxy: http.ProxyURL(upstream.URL()),
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", netaddr)
|
||||
},
|
||||
DialContext: dialer.DialContext,
|
||||
}
|
||||
return &ProxyHandler{
|
||||
auth: auth,
|
||||
upstreamAddr: netaddr,
|
||||
tlsName: upstream.TLSName,
|
||||
logger: logger,
|
||||
dialer: dialer,
|
||||
httptransport: httptransport,
|
||||
@@ -52,101 +34,16 @@ func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
s.logger.Info("Request: %v %v %v", req.RemoteAddr, req.Method, req.URL)
|
||||
if strings.ToUpper(req.Method) == "CONNECT" {
|
||||
req.Header.Set("Proxy-Authorization", s.auth())
|
||||
rawreq, err := httputil.DumpRequest(req, false)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dump request: %v", err)
|
||||
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dial upstream: %v", err)
|
||||
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if s.tlsName != "" {
|
||||
conn = tls.Client(conn, &tls.Config{
|
||||
ServerName: s.tlsName,
|
||||
})
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
_, err = conn.Write(rawreq)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't write upstream: %v", err)
|
||||
http.Error(wr, "Can't write upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
bufrd := bufio.NewReader(conn)
|
||||
proxyResp, err := http.ReadResponse(bufrd, req)
|
||||
responseBytes := make([]byte, 0)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't read response from upstream: %v", err)
|
||||
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if proxyResp.StatusCode == http.StatusForbidden &&
|
||||
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
|
||||
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.",
|
||||
req.URL.String())
|
||||
|
||||
conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dial upstream: %v", err)
|
||||
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if s.tlsName != "" {
|
||||
conn = tls.Client(conn, &tls.Config{
|
||||
ServerName: s.tlsName,
|
||||
})
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
err = rewriteConnectReq(req, s.resolver)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't rewrite request: %v", err)
|
||||
http.Error(wr, "Can't rewrite request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
rawreq, err = httputil.DumpRequest(req, false)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dump request: %v", err)
|
||||
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(rawreq)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't write tls upstream: %v", err)
|
||||
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
responseBytes, err = httputil.DumpResponse(proxyResp, false)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dump response: %v", err)
|
||||
http.Error(wr, "Can't dump response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
buffered := bufrd.Buffered()
|
||||
if buffered > 0 {
|
||||
trailer := make([]byte, buffered)
|
||||
bufrd.Read(trailer)
|
||||
responseBytes = append(responseBytes, trailer...)
|
||||
}
|
||||
}
|
||||
bufrd = nil
|
||||
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
conn, err := s.dialer.DialContext(ctx, "tcp", req.RequestURI)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't satisfy CONNECT request: %v", err)
|
||||
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
|
||||
// Upgrade client connection
|
||||
localconn, _, err := hijack(wr)
|
||||
if err != nil {
|
||||
@@ -156,102 +53,56 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
defer localconn.Close()
|
||||
|
||||
if len(responseBytes) > 0 {
|
||||
_, err = localconn.Write(responseBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Inform client connection is built
|
||||
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
|
||||
|
||||
proxy(req.Context(), localconn, conn)
|
||||
} else if req.ProtoMajor == 2 {
|
||||
wr.Header()["Date"] = nil
|
||||
wr.WriteHeader(http.StatusOK)
|
||||
flush(wr)
|
||||
proxyh2(req.Context(), req.Body, wr, conn)
|
||||
} else {
|
||||
delHopHeaders(req.Header)
|
||||
orig_req := req.Clone(req.Context())
|
||||
req.RequestURI = ""
|
||||
req.Header.Set("Proxy-Authorization", s.auth())
|
||||
resp, err := s.httptransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
s.logger.Error("HTTP fetch error: %v", err)
|
||||
http.Error(wr, "Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode == http.StatusForbidden &&
|
||||
resp.Header.Get("X-Hola-Error") == "Forbidden Host" {
|
||||
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&tunnel workaround.",
|
||||
req.URL.String())
|
||||
resp.Body.Close()
|
||||
|
||||
// Prepare tunnel request
|
||||
proxyReq, err := makeConnReq(orig_req.RequestURI, s.resolver)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't rewrite request: %v", err)
|
||||
http.Error(wr, "Can't rewrite request", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
proxyReq.Header.Set("Proxy-Authorization", s.auth())
|
||||
rawreq, _ := httputil.DumpRequest(proxyReq, false)
|
||||
|
||||
// Prepare upstream conn
|
||||
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't dial tls upstream: %v", err)
|
||||
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if s.tlsName != "" {
|
||||
conn = tls.Client(conn, &tls.Config{
|
||||
ServerName: s.tlsName,
|
||||
})
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
// Send proxy request
|
||||
_, err = conn.Write(rawreq)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't write tls upstream: %v", err)
|
||||
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// Read proxy response
|
||||
bufrd := bufio.NewReader(conn)
|
||||
proxyResp, err := http.ReadResponse(bufrd, proxyReq)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't read response from upstream: %v", err)
|
||||
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if proxyResp.StatusCode != http.StatusOK {
|
||||
delHopHeaders(proxyResp.Header)
|
||||
copyHeader(wr.Header(), proxyResp.Header)
|
||||
wr.WriteHeader(proxyResp.StatusCode)
|
||||
}
|
||||
|
||||
// Send tunneled request
|
||||
orig_req.RequestURI = ""
|
||||
orig_req.Header.Set("Connection", "close")
|
||||
rawreq, _ = httputil.DumpRequest(orig_req, false)
|
||||
_, err = conn.Write(rawreq)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't write tls upstream: %v", err)
|
||||
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// Read tunneled response
|
||||
resp, err = http.ReadResponse(bufrd, orig_req)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't read response from upstream: %v", err)
|
||||
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
|
||||
delHopHeaders(resp.Header)
|
||||
copyHeader(wr.Header(), resp.Header)
|
||||
wr.WriteHeader(resp.StatusCode)
|
||||
io.Copy(wr, resp.Body)
|
||||
s.logger.Error("Unsupported protocol version: %s", req.Proto)
|
||||
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) {
|
||||
req.RequestURI = ""
|
||||
if req.ProtoMajor == 2 {
|
||||
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
|
||||
req.URL.Host = req.Host
|
||||
}
|
||||
resp, err := s.httptransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
s.logger.Error("HTTP fetch error: %v", err)
|
||||
http.Error(wr, "Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
|
||||
delHopHeaders(resp.Header)
|
||||
copyHeader(wr.Header(), resp.Header)
|
||||
wr.WriteHeader(resp.StatusCode)
|
||||
flush(wr)
|
||||
copyBody(wr, resp.Body)
|
||||
}
|
||||
|
||||
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL)
|
||||
|
||||
isConnect := strings.ToUpper(req.Method) == "CONNECT"
|
||||
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
|
||||
req.Host == "" && req.ProtoMajor == 2 {
|
||||
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
delHopHeaders(req.Header)
|
||||
if isConnect {
|
||||
s.HandleTunnel(wr, req)
|
||||
} else {
|
||||
s.HandleRequest(wr, req)
|
||||
}
|
||||
}
|
||||
|
||||
11
main.go
11
main.go
@@ -4,10 +4,9 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
// "os/signal"
|
||||
// "syscall"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -126,9 +125,13 @@ func run() int {
|
||||
logWriter.Close()
|
||||
return 5
|
||||
}
|
||||
var dialer ContextDialer = NewProxyDialer(endpoint.NetAddr(), endpoint.TLSName, auth, &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
})
|
||||
mainLogger.Info("Endpoint: %s", endpoint.URL().String())
|
||||
mainLogger.Info("Starting proxy server...")
|
||||
handler := NewProxyHandler(endpoint, auth, resolver, proxyLogger)
|
||||
handler := NewProxyHandler(dialer, resolver, proxyLogger)
|
||||
mainLogger.Info("Init complete.")
|
||||
err = http.ListenAndServe(args.bind_address, handler)
|
||||
mainLogger.Critical("Server terminated with a reason: %v", err)
|
||||
|
||||
60
utils.go
60
utils.go
@@ -18,6 +18,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const COPY_BUF = 128 * 1024
|
||||
|
||||
type Endpoint struct {
|
||||
Host string
|
||||
Port uint16
|
||||
@@ -38,6 +40,10 @@ func (e *Endpoint) URL() *url.URL {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Endpoint) NetAddr() string {
|
||||
return net.JoinHostPort(e.Host, fmt.Sprintf("%d", e.Port))
|
||||
}
|
||||
|
||||
func basic_auth_header(login, password string) string {
|
||||
return "basic " + base64.StdEncoding.EncodeToString(
|
||||
[]byte(login+":"+password))
|
||||
@@ -69,6 +75,36 @@ func proxy(ctx context.Context, left, right net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
|
||||
wg := sync.WaitGroup{}
|
||||
ltr := func(dst net.Conn, src io.Reader) {
|
||||
defer wg.Done()
|
||||
io.Copy(dst, src)
|
||||
dst.Close()
|
||||
}
|
||||
rtl := func(dst io.Writer, src io.Reader) {
|
||||
defer wg.Done()
|
||||
copyBody(dst, src)
|
||||
}
|
||||
wg.Add(2)
|
||||
go ltr(right, leftreader)
|
||||
go rtl(leftwriter, right)
|
||||
groupdone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
groupdone <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
leftreader.Close()
|
||||
right.Close()
|
||||
case <-groupdone:
|
||||
return
|
||||
}
|
||||
<-groupdone
|
||||
return
|
||||
}
|
||||
|
||||
func print_countries(timeout time.Duration) int {
|
||||
var (
|
||||
countries CountryList
|
||||
@@ -233,6 +269,30 @@ func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) {
|
||||
return conn, rw, nil
|
||||
}
|
||||
|
||||
func flush(flusher interface{}) bool {
|
||||
f, ok := flusher.(http.Flusher)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
f.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
func copyBody(wr io.Writer, body io.Reader) {
|
||||
buf := make([]byte, COPY_BUF)
|
||||
for {
|
||||
bread, read_err := body.Read(buf)
|
||||
var write_err error
|
||||
if bread > 0 {
|
||||
_, write_err = wr.Write(buf[:bread])
|
||||
flush(wr)
|
||||
}
|
||||
if read_err != nil || write_err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteConnectReq(req *http.Request, resolver *Resolver) error {
|
||||
origHost := req.Host
|
||||
origAddr, origPort, err := net.SplitHostPort(origHost)
|
||||
|
||||
Reference in New Issue
Block a user