From 05ac2bc1460e61fe116c4925ab62a4f04aa37056 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Tue, 23 Mar 2021 17:34:12 +0200 Subject: [PATCH] rework proxy handler on new stack --- handler.go | 275 ++++++++++++----------------------------------------- main.go | 11 ++- utils.go | 60 ++++++++++++ 3 files changed, 130 insertions(+), 216 deletions(-) diff --git a/handler.go b/handler.go index 153fe23..996e32e 100644 --- a/handler.go +++ b/handler.go @@ -1,50 +1,32 @@ package main import ( - "bufio" - "context" - "crypto/tls" "fmt" - "io" - "net" "net/http" - "net/http/httputil" "strings" "time" ) +const BAD_REQ_MSG = "Bad Request\n" + type AuthProvider func() string type ProxyHandler struct { - auth AuthProvider - upstreamAddr string - tlsName string logger *CondLogger - dialer *net.Dialer + dialer ContextDialer httptransport http.RoundTripper resolver *Resolver } -func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler { - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - } - netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port)) +func NewProxyHandler(dialer ContextDialer, resolver *Resolver, logger *CondLogger) *ProxyHandler { httptransport := &http.Transport{ MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - Proxy: http.ProxyURL(upstream.URL()), - DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, "tcp", netaddr) - }, + DialContext: dialer.DialContext, } return &ProxyHandler{ - auth: auth, - upstreamAddr: netaddr, - tlsName: upstream.TLSName, logger: logger, dialer: dialer, httptransport: httptransport, @@ -52,101 +34,16 @@ func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, } } -func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { - s.logger.Info("Request: %v %v %v", req.RemoteAddr, req.Method, req.URL) - if strings.ToUpper(req.Method) == "CONNECT" { - req.Header.Set("Proxy-Authorization", s.auth()) - rawreq, err := httputil.DumpRequest(req, false) - if err != nil { - s.logger.Error("Can't dump request: %v", err) - http.Error(wr, "Can't dump request", http.StatusInternalServerError) - return - } - - conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr) - if err != nil { - s.logger.Error("Can't dial upstream: %v", err) - http.Error(wr, "Can't dial upstream", http.StatusBadGateway) - return - } - defer conn.Close() - - if s.tlsName != "" { - conn = tls.Client(conn, &tls.Config{ - ServerName: s.tlsName, - }) - defer conn.Close() - } - - _, err = conn.Write(rawreq) - if err != nil { - s.logger.Error("Can't write upstream: %v", err) - http.Error(wr, "Can't write upstream", http.StatusBadGateway) - return - } - bufrd := bufio.NewReader(conn) - proxyResp, err := http.ReadResponse(bufrd, req) - responseBytes := make([]byte, 0) - if err != nil { - s.logger.Error("Can't read response from upstream: %v", err) - http.Error(wr, "Can't read response from upstream", http.StatusBadGateway) - return - } - - if proxyResp.StatusCode == http.StatusForbidden && - proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" { - s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.", - req.URL.String()) - - conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr) - if err != nil { - s.logger.Error("Can't dial upstream: %v", err) - http.Error(wr, "Can't dial upstream", http.StatusBadGateway) - return - } - defer conn.Close() - - if s.tlsName != "" { - conn = tls.Client(conn, &tls.Config{ - ServerName: s.tlsName, - }) - defer conn.Close() - } - - err = rewriteConnectReq(req, s.resolver) - if err != nil { - s.logger.Error("Can't rewrite request: %v", err) - http.Error(wr, "Can't rewrite request", http.StatusInternalServerError) - return - } - rawreq, err = httputil.DumpRequest(req, false) - if err != nil { - s.logger.Error("Can't dump request: %v", err) - http.Error(wr, "Can't dump request", http.StatusInternalServerError) - return - } - _, err = conn.Write(rawreq) - if err != nil { - s.logger.Error("Can't write tls upstream: %v", err) - http.Error(wr, "Can't write tls upstream", http.StatusBadGateway) - return - } - } else { - responseBytes, err = httputil.DumpResponse(proxyResp, false) - if err != nil { - s.logger.Error("Can't dump response: %v", err) - http.Error(wr, "Can't dump response", http.StatusInternalServerError) - return - } - buffered := bufrd.Buffered() - if buffered > 0 { - trailer := make([]byte, buffered) - bufrd.Read(trailer) - responseBytes = append(responseBytes, trailer...) - } - } - bufrd = nil +func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { + ctx := req.Context() + conn, err := s.dialer.DialContext(ctx, "tcp", req.RequestURI) + if err != nil { + s.logger.Error("Can't satisfy CONNECT request: %v", err) + http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway) + return + } + if req.ProtoMajor == 0 || req.ProtoMajor == 1 { // Upgrade client connection localconn, _, err := hijack(wr) if err != nil { @@ -156,102 +53,56 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { } defer localconn.Close() - if len(responseBytes) > 0 { - _, err = localconn.Write(responseBytes) - if err != nil { - return - } - } + // Inform client connection is built + fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor) + proxy(req.Context(), localconn, conn) + } else if req.ProtoMajor == 2 { + wr.Header()["Date"] = nil + wr.WriteHeader(http.StatusOK) + flush(wr) + proxyh2(req.Context(), req.Body, wr, conn) } else { - delHopHeaders(req.Header) - orig_req := req.Clone(req.Context()) - req.RequestURI = "" - req.Header.Set("Proxy-Authorization", s.auth()) - resp, err := s.httptransport.RoundTrip(req) - if err != nil { - s.logger.Error("HTTP fetch error: %v", err) - http.Error(wr, "Server Error", http.StatusInternalServerError) - return - } - if resp.StatusCode == http.StatusForbidden && - resp.Header.Get("X-Hola-Error") == "Forbidden Host" { - s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&tunnel workaround.", - req.URL.String()) - resp.Body.Close() - - // Prepare tunnel request - proxyReq, err := makeConnReq(orig_req.RequestURI, s.resolver) - if err != nil { - s.logger.Error("Can't rewrite request: %v", err) - http.Error(wr, "Can't rewrite request", http.StatusBadGateway) - return - } - proxyReq.Header.Set("Proxy-Authorization", s.auth()) - rawreq, _ := httputil.DumpRequest(proxyReq, false) - - // Prepare upstream conn - conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr) - if err != nil { - s.logger.Error("Can't dial tls upstream: %v", err) - http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway) - return - } - defer conn.Close() - - if s.tlsName != "" { - conn = tls.Client(conn, &tls.Config{ - ServerName: s.tlsName, - }) - defer conn.Close() - } - - // Send proxy request - _, err = conn.Write(rawreq) - if err != nil { - s.logger.Error("Can't write tls upstream: %v", err) - http.Error(wr, "Can't write tls upstream", http.StatusBadGateway) - return - } - - // Read proxy response - bufrd := bufio.NewReader(conn) - proxyResp, err := http.ReadResponse(bufrd, proxyReq) - if err != nil { - s.logger.Error("Can't read response from upstream: %v", err) - http.Error(wr, "Can't read response from upstream", http.StatusBadGateway) - return - } - if proxyResp.StatusCode != http.StatusOK { - delHopHeaders(proxyResp.Header) - copyHeader(wr.Header(), proxyResp.Header) - wr.WriteHeader(proxyResp.StatusCode) - } - - // Send tunneled request - orig_req.RequestURI = "" - orig_req.Header.Set("Connection", "close") - rawreq, _ = httputil.DumpRequest(orig_req, false) - _, err = conn.Write(rawreq) - if err != nil { - s.logger.Error("Can't write tls upstream: %v", err) - http.Error(wr, "Can't write tls upstream", http.StatusBadGateway) - return - } - - // Read tunneled response - resp, err = http.ReadResponse(bufrd, orig_req) - if err != nil { - s.logger.Error("Can't read response from upstream: %v", err) - http.Error(wr, "Can't read response from upstream", http.StatusBadGateway) - return - } - } - defer resp.Body.Close() - s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status) - delHopHeaders(resp.Header) - copyHeader(wr.Header(), resp.Header) - wr.WriteHeader(resp.StatusCode) - io.Copy(wr, resp.Body) + s.logger.Error("Unsupported protocol version: %s", req.Proto) + http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest) + return + } +} + +func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) { + req.RequestURI = "" + if req.ProtoMajor == 2 { + req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http + req.URL.Host = req.Host + } + resp, err := s.httptransport.RoundTrip(req) + if err != nil { + s.logger.Error("HTTP fetch error: %v", err) + http.Error(wr, "Server Error", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status) + delHopHeaders(resp.Header) + copyHeader(wr.Header(), resp.Header) + wr.WriteHeader(resp.StatusCode) + flush(wr) + copyBody(wr, resp.Body) +} + +func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { + s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL) + + isConnect := strings.ToUpper(req.Method) == "CONNECT" + if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 || + req.Host == "" && req.ProtoMajor == 2 { + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return + } + delHopHeaders(req.Header) + if isConnect { + s.HandleTunnel(wr, req) + } else { + s.HandleRequest(wr, req) } } diff --git a/main.go b/main.go index c58989e..e7ee7a1 100644 --- a/main.go +++ b/main.go @@ -4,10 +4,9 @@ import ( "flag" "fmt" "log" - "os" - // "os/signal" - // "syscall" + "net" "net/http" + "os" "time" ) @@ -126,9 +125,13 @@ func run() int { logWriter.Close() return 5 } + var dialer ContextDialer = NewProxyDialer(endpoint.NetAddr(), endpoint.TLSName, auth, &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }) mainLogger.Info("Endpoint: %s", endpoint.URL().String()) mainLogger.Info("Starting proxy server...") - handler := NewProxyHandler(endpoint, auth, resolver, proxyLogger) + handler := NewProxyHandler(dialer, resolver, proxyLogger) mainLogger.Info("Init complete.") err = http.ListenAndServe(args.bind_address, handler) mainLogger.Critical("Server terminated with a reason: %v", err) diff --git a/utils.go b/utils.go index 8ea99e1..92083c1 100644 --- a/utils.go +++ b/utils.go @@ -18,6 +18,8 @@ import ( "time" ) +const COPY_BUF = 128 * 1024 + type Endpoint struct { Host string Port uint16 @@ -38,6 +40,10 @@ func (e *Endpoint) URL() *url.URL { } } +func (e *Endpoint) NetAddr() string { + return net.JoinHostPort(e.Host, fmt.Sprintf("%d", e.Port)) +} + func basic_auth_header(login, password string) string { return "basic " + base64.StdEncoding.EncodeToString( []byte(login+":"+password)) @@ -69,6 +75,36 @@ func proxy(ctx context.Context, left, right net.Conn) { return } +func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { + wg := sync.WaitGroup{} + ltr := func(dst net.Conn, src io.Reader) { + defer wg.Done() + io.Copy(dst, src) + dst.Close() + } + rtl := func(dst io.Writer, src io.Reader) { + defer wg.Done() + copyBody(dst, src) + } + wg.Add(2) + go ltr(right, leftreader) + go rtl(leftwriter, right) + groupdone := make(chan struct{}, 1) + go func() { + wg.Wait() + groupdone <- struct{}{} + }() + select { + case <-ctx.Done(): + leftreader.Close() + right.Close() + case <-groupdone: + return + } + <-groupdone + return +} + func print_countries(timeout time.Duration) int { var ( countries CountryList @@ -233,6 +269,30 @@ func hijack(hijackable interface{}) (net.Conn, *bufio.ReadWriter, error) { return conn, rw, nil } +func flush(flusher interface{}) bool { + f, ok := flusher.(http.Flusher) + if !ok { + return false + } + f.Flush() + return true +} + +func copyBody(wr io.Writer, body io.Reader) { + buf := make([]byte, COPY_BUF) + for { + bread, read_err := body.Read(buf) + var write_err error + if bread > 0 { + _, write_err = wr.Write(buf[:bread]) + flush(wr) + } + if read_err != nil || write_err != nil { + break + } + } +} + func rewriteConnectReq(req *http.Request, resolver *Resolver) error { origHost := req.Host origAddr, origPort, err := net.SplitHostPort(origHost)