From 6327cf74342ee6113412c17e2bbb88f592836104 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Fri, 15 Apr 2022 00:29:21 +0800 Subject: [PATCH] Chore: adjust mitm proxy --- listener/mitm/client.go | 2 - listener/mitm/proxy.go | 100 +++++++++++++++++---------------------- listener/mitm/session.go | 7 +++ listener/mitm/utils.go | 7 +++ 4 files changed, 58 insertions(+), 58 deletions(-) diff --git a/listener/mitm/client.go b/listener/mitm/client.go index 278de173..a9ef56d1 100644 --- a/listener/mitm/client.go +++ b/listener/mitm/client.go @@ -13,8 +13,6 @@ import ( "github.com/Dreamacro/clash/transport/socks5" ) -var ErrCertUnsupported = errors.New("tls: client cert unsupported") - func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client { return &http.Client{ Transport: &http.Transport{ diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 88fef3db..ce5e8f2e 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -13,7 +13,6 @@ import ( "strings" "time" - "github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/common/cache" N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" @@ -63,7 +62,7 @@ readLoop: session := newSession(conn, request, response) - source = parseSourceAddress(session.request, c, source) + source = parseSourceAddress(session.request, c.RemoteAddr(), source) session.request.RemoteAddr = source.String() if !trusted { @@ -80,42 +79,45 @@ readLoop: break readLoop // close connection } - if couldBeWithManInTheMiddleAttack(session.request.URL.Host, opt) { - b := make([]byte, 1) - if _, err = session.conn.Read(b); err != nil { - handleError(opt, session, err) + if strings.HasSuffix(session.request.URL.Host, ":80") { + goto readLoop + } + + b := make([]byte, 1) + if _, err = session.conn.Read(b); err != nil { + handleError(opt, session, err) + break readLoop // close connection + } + + buff := make([]byte, session.conn.(*N.BufferedConn).Buffered()) + if _, err = session.conn.Read(buff); err != nil { + handleError(opt, session, err) + break readLoop // close connection + } + + mrConn := &multiReaderConn{ + Conn: session.conn, + reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn), + } + + // TLS handshake. + if b[0] == 0x16 { + // TODO serve by generic host name maybe better? + tlsConn := tls.Server(mrConn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) + + // Handshake with the local client + if err = tlsConn.Handshake(); err != nil { + session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err)) + _ = writeResponse(session, false) break readLoop // close connection } - buff := make([]byte, session.conn.(*N.BufferedConn).Buffered()) - _, _ = session.conn.Read(buff) - - mrc := &multiReaderConn{ - Conn: session.conn, - reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn), - } - - // TLS handshake. - if b[0] == 0x16 { - // TODO serve by generic host name maybe better? - tlsConn := tls.Server(mrc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) - - // Handshake with the local client - if err = tlsConn.Handshake(); err != nil { - handleError(opt, session, err) - break readLoop // close connection - } - - c = tlsConn - goto startOver // hijack and decrypt tls connection - } - - // maybe it's the others encrypted connection - in <- inbound.NewHTTPS(session.request, mrc) + c = tlsConn + } else { + c = mrConn } - // maybe it's a http connection - goto readLoop + goto startOver } prepareRequest(c, session.request) @@ -149,7 +151,7 @@ readLoop: session.request.RequestURI = "" if session.request.URL.Host == "" { - session.response = session.NewErrorResponse(errors.New("invalid URL")) + session.response = session.NewErrorResponse(ErrInvalidURL) } else { client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) @@ -202,9 +204,7 @@ func writeResponse(session *Session, keepAlive bool) error { session.response.Header.Set("Keep-Alive", "timeout=25") } - // session.response.Close = !keepAlive // let handler do it - - return session.response.Write(session.conn) + return session.writeResponse() } func handleApiRequest(session *Session, opt *Option) error { @@ -224,7 +224,7 @@ func handleApiRequest(session *Session, opt *Option) error { session.response.Header.Set("Content-Type", "application/x-x509-ca-cert") session.response.ContentLength = int64(len(b)) - return session.response.Write(session.conn) + return session.writeResponse() } b := ` @@ -254,7 +254,7 @@ func handleApiRequest(session *Session, opt *Option) error { session.response.Header.Set("Content-Type", "text/html;charset=utf-8") session.response.ContentLength = int64(len(b)) - return session.response.Write(session.conn) + return session.writeResponse() } func handleError(opt *Option, session *Session, err error) { @@ -292,38 +292,26 @@ func prepareRequest(conn net.Conn, request *http.Request) { H.RemoveExtraHTTPHostPort(request) } -func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool { - if opt.CertConfig == nil { - return false - } - - if _, port, err := net.SplitHostPort(hostname); err == nil && (port == "443" || port == "8443") { - return true - } - - return false -} - -func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr { +func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr { if source != nil { return source } sourceAddress := req.Header.Get("Origin-Request-Source-Address") if sourceAddress == "" { - return c.RemoteAddr() + return connSource } req.Header.Del("Origin-Request-Source-Address") host, port, err := net.SplitHostPort(sourceAddress) if err != nil { - return c.RemoteAddr() + return connSource } p, err := strconv.ParseUint(port, 10, 16) if err != nil { - return c.RemoteAddr() + return connSource } if ip := net.ParseIP(host); ip != nil { @@ -333,7 +321,7 @@ func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr } } - return c.RemoteAddr() + return connSource } func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client { diff --git a/listener/mitm/session.go b/listener/mitm/session.go index 42c7faf7..c2622a69 100644 --- a/listener/mitm/session.go +++ b/listener/mitm/session.go @@ -39,6 +39,13 @@ func (s *Session) NewErrorResponse(err error) *http.Response { return NewErrorResponse(s.request, err) } +func (s *Session) writeResponse() error { + if s.response == nil { + return ErrInvalidResponse + } + return s.response.Write(s.conn) +} + func newSession(conn net.Conn, request *http.Request, response *http.Response) *Session { return &Session{ conn: conn, diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go index 8ca8054d..5c3b15bd 100644 --- a/listener/mitm/utils.go +++ b/listener/mitm/utils.go @@ -3,6 +3,7 @@ package mitm import ( "bytes" "compress/gzip" + "errors" "fmt" "io" "io/ioutil" @@ -14,6 +15,12 @@ import ( "golang.org/x/text/transform" ) +var ( + ErrCertUnsupported = errors.New("tls: client cert unsupported") + ErrInvalidResponse = errors.New("invalid response") + ErrInvalidURL = errors.New("invalid URL") +) + type multiReaderConn struct { net.Conn reader io.Reader