Chore: adjust mitm proxy

This commit is contained in:
yaling888
2022-04-15 00:29:21 +08:00
parent ca76e5cf0e
commit 6327cf7434
4 changed files with 58 additions and 58 deletions

View File

@@ -13,8 +13,6 @@ import (
"github.com/Dreamacro/clash/transport/socks5"
)
var ErrCertUnsupported = errors.New("tls: client cert unsupported")
func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client {
return &http.Client{
Transport: &http.Transport{

View File

@@ -13,7 +13,6 @@ import (
"strings"
"time"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant"
@@ -63,7 +62,7 @@ readLoop:
session := newSession(conn, request, response)
source = parseSourceAddress(session.request, c, source)
source = parseSourceAddress(session.request, c.RemoteAddr(), source)
session.request.RemoteAddr = source.String()
if !trusted {
@@ -80,42 +79,45 @@ readLoop:
break readLoop // close connection
}
if couldBeWithManInTheMiddleAttack(session.request.URL.Host, opt) {
b := make([]byte, 1)
if _, err = session.conn.Read(b); err != nil {
handleError(opt, session, err)
if strings.HasSuffix(session.request.URL.Host, ":80") {
goto readLoop
}
b := make([]byte, 1)
if _, err = session.conn.Read(b); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
buff := make([]byte, session.conn.(*N.BufferedConn).Buffered())
if _, err = session.conn.Read(buff); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
mrConn := &multiReaderConn{
Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn),
}
// TLS handshake.
if b[0] == 0x16 {
// TODO serve by generic host name maybe better?
tlsConn := tls.Server(mrConn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client
if err = tlsConn.Handshake(); err != nil {
session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
_ = writeResponse(session, false)
break readLoop // close connection
}
buff := make([]byte, session.conn.(*N.BufferedConn).Buffered())
_, _ = session.conn.Read(buff)
mrc := &multiReaderConn{
Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn),
}
// TLS handshake.
if b[0] == 0x16 {
// TODO serve by generic host name maybe better?
tlsConn := tls.Server(mrc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client
if err = tlsConn.Handshake(); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
c = tlsConn
goto startOver // hijack and decrypt tls connection
}
// maybe it's the others encrypted connection
in <- inbound.NewHTTPS(session.request, mrc)
c = tlsConn
} else {
c = mrConn
}
// maybe it's a http connection
goto readLoop
goto startOver
}
prepareRequest(c, session.request)
@@ -149,7 +151,7 @@ readLoop:
session.request.RequestURI = ""
if session.request.URL.Host == "" {
session.response = session.NewErrorResponse(errors.New("invalid URL"))
session.response = session.NewErrorResponse(ErrInvalidURL)
} else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
@@ -202,9 +204,7 @@ func writeResponse(session *Session, keepAlive bool) error {
session.response.Header.Set("Keep-Alive", "timeout=25")
}
// session.response.Close = !keepAlive // let handler do it
return session.response.Write(session.conn)
return session.writeResponse()
}
func handleApiRequest(session *Session, opt *Option) error {
@@ -224,7 +224,7 @@ func handleApiRequest(session *Session, opt *Option) error {
session.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
session.response.ContentLength = int64(len(b))
return session.response.Write(session.conn)
return session.writeResponse()
}
b := `<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">
@@ -254,7 +254,7 @@ func handleApiRequest(session *Session, opt *Option) error {
session.response.Header.Set("Content-Type", "text/html;charset=utf-8")
session.response.ContentLength = int64(len(b))
return session.response.Write(session.conn)
return session.writeResponse()
}
func handleError(opt *Option, session *Session, err error) {
@@ -292,38 +292,26 @@ func prepareRequest(conn net.Conn, request *http.Request) {
H.RemoveExtraHTTPHostPort(request)
}
func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool {
if opt.CertConfig == nil {
return false
}
if _, port, err := net.SplitHostPort(hostname); err == nil && (port == "443" || port == "8443") {
return true
}
return false
}
func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr {
func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr {
if source != nil {
return source
}
sourceAddress := req.Header.Get("Origin-Request-Source-Address")
if sourceAddress == "" {
return c.RemoteAddr()
return connSource
}
req.Header.Del("Origin-Request-Source-Address")
host, port, err := net.SplitHostPort(sourceAddress)
if err != nil {
return c.RemoteAddr()
return connSource
}
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return c.RemoteAddr()
return connSource
}
if ip := net.ParseIP(host); ip != nil {
@@ -333,7 +321,7 @@ func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr
}
}
return c.RemoteAddr()
return connSource
}
func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client {

View File

@@ -39,6 +39,13 @@ func (s *Session) NewErrorResponse(err error) *http.Response {
return NewErrorResponse(s.request, err)
}
func (s *Session) writeResponse() error {
if s.response == nil {
return ErrInvalidResponse
}
return s.response.Write(s.conn)
}
func newSession(conn net.Conn, request *http.Request, response *http.Response) *Session {
return &Session{
conn: conn,

View File

@@ -3,6 +3,7 @@ package mitm
import (
"bytes"
"compress/gzip"
"errors"
"fmt"
"io"
"io/ioutil"
@@ -14,6 +15,12 @@ import (
"golang.org/x/text/transform"
)
var (
ErrCertUnsupported = errors.New("tls: client cert unsupported")
ErrInvalidResponse = errors.New("invalid response")
ErrInvalidURL = errors.New("invalid URL")
)
type multiReaderConn struct {
net.Conn
reader io.Reader