Feature: MITM rewrite

This commit is contained in:
yaling888
2022-04-10 03:59:27 +08:00
parent 5a27ebd1b3
commit f036e06f6f
40 changed files with 2144 additions and 71 deletions

View File

@@ -42,7 +42,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
var resp *http.Response
if !trusted {
resp = authenticate(request, cache)
resp = Authenticate(request, cache)
trusted = resp == nil
}
@@ -66,19 +66,19 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
request.RequestURI = ""
removeHopByHopHeaders(request.Header)
removeExtraHTTPHostPort(request)
RemoveHopByHopHeaders(request.Header)
RemoveExtraHTTPHostPort(request)
if request.URL.Scheme == "" || request.URL.Host == "" {
resp = responseWith(request, http.StatusBadRequest)
resp = ResponseWith(request, http.StatusBadRequest)
} else {
resp, err = client.Do(request)
if err != nil {
resp = responseWith(request, http.StatusBadGateway)
resp = ResponseWith(request, http.StatusBadGateway)
}
}
removeHopByHopHeaders(resp.Header)
RemoveHopByHopHeaders(resp.Header)
}
if keepAlive {
@@ -98,12 +98,12 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
conn.Close()
}
func authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response {
func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response {
authenticator := authStore.Authenticator()
if authenticator != nil {
credential := parseBasicProxyAuthorization(request)
if credential == "" {
resp := responseWith(request, http.StatusProxyAuthRequired)
resp := ResponseWith(request, http.StatusProxyAuthRequired)
resp.Header.Set("Proxy-Authenticate", "Basic")
return resp
}
@@ -117,14 +117,14 @@ func authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http
if !authed {
log.Infoln("Auth failed from %s", request.RemoteAddr)
return responseWith(request, http.StatusForbidden)
return ResponseWith(request, http.StatusForbidden)
}
}
return nil
}
func responseWith(request *http.Request, statusCode int) *http.Response {
func ResponseWith(request *http.Request, statusCode int) *http.Response {
return &http.Response{
StatusCode: statusCode,
Status: http.StatusText(statusCode),

View File

@@ -8,8 +8,8 @@ import (
"strings"
)
// removeHopByHopHeaders remove hop-by-hop header
func removeHopByHopHeaders(header http.Header) {
// RemoveHopByHopHeaders remove hop-by-hop header
func RemoveHopByHopHeaders(header http.Header) {
// Strip hop-by-hop header based on RFC:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
// https://www.mnot.net/blog/2011/07/11/what_proxies_must_do
@@ -32,9 +32,9 @@ func removeHopByHopHeaders(header http.Header) {
}
}
// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com)
// RemoveExtraHTTPHostPort remove extra host port (example.com:80 --> example.com)
// It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com)
func removeExtraHTTPHostPort(req *http.Request) {
func RemoveExtraHTTPHostPort(req *http.Request) {
host := req.Host
if host == "" {
host = req.URL.Host

View File

@@ -1,6 +1,9 @@
package proxy
import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
@@ -8,9 +11,12 @@ import (
"sync"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/adapter/outbound"
"github.com/Dreamacro/clash/common/cert"
"github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/listener/http"
"github.com/Dreamacro/clash/listener/mitm"
"github.com/Dreamacro/clash/listener/mixed"
"github.com/Dreamacro/clash/listener/redir"
"github.com/Dreamacro/clash/listener/socks"
@@ -18,6 +24,8 @@ import (
"github.com/Dreamacro/clash/listener/tun"
"github.com/Dreamacro/clash/listener/tun/ipstack"
"github.com/Dreamacro/clash/log"
rewrites "github.com/Dreamacro/clash/rewrite"
"github.com/Dreamacro/clash/tunnel"
)
var (
@@ -34,6 +42,7 @@ var (
mixedListener *mixed.Listener
mixedUDPLister *socks.UDPListener
tunStackListener ipstack.Stack
mitmListener *mitm.Listener
// lock for recreate function
socksMux sync.Mutex
@@ -42,6 +51,7 @@ var (
tproxyMux sync.Mutex
mixedMux sync.Mutex
tunMux sync.Mutex
mitmMux sync.Mutex
)
type Ports struct {
@@ -50,6 +60,7 @@ type Ports struct {
RedirPort int `json:"redir-port"`
TProxyPort int `json:"tproxy-port"`
MixedPort int `json:"mixed-port"`
MitmPort int `json:"mitm-port"`
}
func AllowLan() bool {
@@ -331,6 +342,85 @@ func ReCreateTun(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.Co
tunStackListener, err = tun.New(tunConf, tunAddressPrefix, tcpIn, udpIn)
}
func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) {
mitmMux.Lock()
defer mitmMux.Unlock()
var err error
defer func() {
if err != nil {
log.Errorln("Start MITM server error: %s", err.Error())
}
}()
addr := genAddr(bindAddress, port, allowLan)
if mitmListener != nil {
if mitmListener.RawAddress() == addr {
return
}
outbound.MiddlemanServerAddress.Store("")
tunnel.MitmOutbound = nil
_ = mitmListener.Close()
mitmListener = nil
}
if portIsZero(addr) {
return
}
if err = initCert(); err != nil {
return
}
var (
rootCACert tls.Certificate
x509c *x509.Certificate
certOption *cert.Config
)
rootCACert, err = tls.LoadX509KeyPair(C.Path.RootCA(), C.Path.CAKey())
if err != nil {
return
}
privateKey := rootCACert.PrivateKey.(*rsa.PrivateKey)
x509c, err = x509.ParseCertificate(rootCACert.Certificate[0])
if err != nil {
return
}
certOption, err = cert.NewConfig(
x509c,
privateKey,
cert.NewAutoGCCertsStorage(),
)
if err != nil {
return
}
certOption.SetValidity(cert.TTL << 3)
certOption.SetOrganization("Clash ManInTheMiddle Proxy Services")
opt := &mitm.Option{
Addr: addr,
ApiHost: "mitm.clash",
CertConfig: certOption,
Handler: &rewrites.RewriteHandler{},
}
mitmListener, err = mitm.New(opt, tcpIn)
if err != nil {
return
}
outbound.MiddlemanServerAddress.Store(mitmListener.Address())
tunnel.MitmOutbound = outbound.NewMitm()
log.Infoln("Mitm proxy listening at: %s", mitmListener.Address())
}
// GetPorts return the ports of proxy servers
func GetPorts() *Ports {
ports := &Ports{}
@@ -365,6 +455,12 @@ func GetPorts() *Ports {
ports.MixedPort = port
}
if mitmListener != nil {
_, portStr, _ := net.SplitHostPort(mitmListener.Address())
port, _ := strconv.Atoi(portStr)
ports.MitmPort = port
}
return ports
}
@@ -387,6 +483,19 @@ func genAddr(host string, port int, allowLan bool) string {
return fmt.Sprintf("127.0.0.1:%d", port)
}
func initCert() error {
if _, err := os.Stat(C.Path.RootCA()); os.IsNotExist(err) {
log.Infoln("Can't find mitm_ca.crt, start generate")
err = cert.GenerateAndSave(C.Path.RootCA(), C.Path.CAKey())
if err != nil {
return err
}
log.Infoln("Generated CA private key and CA certificate finish")
}
return nil
}
func Cleanup() {
if tunStackListener != nil {
_ = tunStackListener.Close()

54
listener/mitm/client.go Normal file
View File

@@ -0,0 +1,54 @@
package mitm
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"time"
"github.com/Dreamacro/clash/adapter/inbound"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)
var ErrCertUnsupported = errors.New("tls: client cert unsupported")
func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client {
return &http.Client{
Transport: &http.Transport{
// excepted HTTP/2
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
// from http.DefaultTransport
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{
GetClientCertificate: func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, e error) {
return nil, ErrCertUnsupported
},
},
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, errors.New("unsupported network " + network)
}
dstAddr := socks5.ParseAddr(address)
if dstAddr == nil {
return nil, socks5.ErrAddressNotSupported
}
left, right := net.Pipe()
in <- inbound.NewMitm(dstAddr, source, userAgent, right)
return left, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}

357
listener/mitm/proxy.go Normal file
View File

@@ -0,0 +1,357 @@
package mitm
import (
"bytes"
"crypto/tls"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/cache"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant"
httpL "github.com/Dreamacro/clash/listener/http"
)
func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) {
var (
source net.Addr
client *http.Client
)
defer func() {
if client != nil {
client.CloseIdleConnections()
}
}()
startOver:
if tc, ok := c.(*net.TCPConn); ok {
_ = tc.SetKeepAlive(true)
}
var conn *N.BufferedConn
if bufConn, ok := c.(*N.BufferedConn); ok {
conn = bufConn
} else {
conn = N.NewBufferedConn(c)
}
trusted := cache == nil // disable authenticate if cache is nil
readLoop:
for {
_ = conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive
request, err := httpL.ReadRequest(conn.Reader())
if err != nil {
handleError(opt, nil, err)
break readLoop
}
var response *http.Response
session := NewSession(conn, request, response)
source = parseSourceAddress(session.request, c, source)
request.RemoteAddr = source.String()
if !trusted {
response = httpL.Authenticate(request, cache)
trusted = response == nil
}
if trusted {
if session.request.Method == http.MethodConnect {
// Manual writing to support CONNECT for http 1.0 (workaround for uplay client)
if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
if couldBeWithManInTheMiddleAttack(session.request.URL.Host, opt) {
b := make([]byte, 1)
if _, err = session.conn.Read(b); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
buf := make([]byte, session.conn.(*N.BufferedConn).Buffered())
_, _ = session.conn.Read(buf)
mc := &MultiReaderConn{
Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), session.conn),
}
// 22 is the TLS handshake.
// https://tools.ietf.org/html/rfc5246#section-6.2.1
if b[0] == 22 {
// TODO serve by generic host name maybe better?
tlsConn := tls.Server(mc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client
if err = tlsConn.Handshake(); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
c = tlsConn
goto startOver // hijack and decrypt tls connection
}
// maybe it's the others encrypted connection
in <- inbound.NewHTTPS(request, mc)
}
// maybe it's a http connection
goto readLoop
}
// hijack api
if getHostnameWithoutPort(session.request) == opt.ApiHost {
if err = handleApiRequest(session, opt); err != nil {
handleError(opt, session, err)
break readLoop
}
return
}
prepareRequest(c, session.request)
// hijack custom request and write back custom response if necessary
if opt.Handler != nil {
newReq, newRes := opt.Handler.HandleRequest(session)
if newReq != nil {
session.request = newReq
}
if newRes != nil {
session.response = newRes
if err = writeResponse(session, false); err != nil {
handleError(opt, session, err)
break readLoop
}
return
}
}
httpL.RemoveHopByHopHeaders(session.request.Header)
httpL.RemoveExtraHTTPHostPort(request)
session.request.RequestURI = ""
if session.request.URL.Scheme == "" || session.request.URL.Host == "" {
session.response = session.NewErrorResponse(errors.New("invalid URL"))
} else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
// send the request to remote server
session.response, err = client.Do(session.request)
if err != nil {
handleError(opt, session, err)
session.response = session.NewErrorResponse(err)
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
// TODO block unsupported host?
}
}
}
}
if err = writeResponseWithHandler(session, opt); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
}
_ = conn.Close()
}
func writeResponseWithHandler(session *Session, opt *Option) error {
if opt.Handler != nil {
res := opt.Handler.HandleResponse(session)
if res != nil {
body := res.Body
defer func(body io.ReadCloser) {
_ = body.Close()
}(body)
session.response = res
}
}
return writeResponse(session, true)
}
func writeResponse(session *Session, keepAlive bool) error {
httpL.RemoveHopByHopHeaders(session.response.Header)
if keepAlive {
session.response.Header.Set("Connection", "keep-alive")
session.response.Header.Set("Keep-Alive", "timeout=25")
}
// session.response.Close = !keepAlive // let handler do it
return session.response.Write(session.conn)
}
func handleApiRequest(session *Session, opt *Option) error {
if opt.CertConfig != nil && strings.ToLower(session.request.URL.Path) == "/cert.crt" {
b := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: opt.CertConfig.GetCA().Raw,
})
session.response = session.NewResponse(http.StatusOK, bytes.NewReader(b))
defer func(body io.ReadCloser) {
_ = body.Close()
}(session.response.Body)
session.response.Close = true
session.response.Header.Set("Content-Type", "application/x-x509-ca-cert")
session.response.ContentLength = int64(len(b))
return session.response.Write(session.conn)
}
b := `<!DOCTYPE HTML PUBLIC "-
<html>
<head>
<title>Clash ManInTheMiddle Proxy Services - 404 Not Found</title>
</head>
<body>
<h1>Not Found</h1>
<p>The requested URL %s was not found on this server.</p>
</body>
</html>
`
if opt.Handler != nil {
if opt.Handler.HandleApiRequest(session) {
return nil
}
}
b = fmt.Sprintf(b, session.request.URL.Path)
session.response = session.NewResponse(http.StatusNotFound, bytes.NewReader([]byte(b)))
defer func(body io.ReadCloser) {
_ = body.Close()
}(session.response.Body)
session.response.Close = true
session.response.Header.Set("Content-Type", "text/html;charset=utf-8")
session.response.ContentLength = int64(len(b))
return session.response.Write(session.conn)
}
func handleError(opt *Option, session *Session, err error) {
if opt.Handler != nil {
opt.Handler.HandleError(session, err)
return
}
// log.Errorln("[MITM] process mitm error: %v", err)
}
func prepareRequest(conn net.Conn, request *http.Request) {
host := request.Header.Get("Host")
if host != "" {
request.Host = host
}
if request.URL.Host == "" {
request.URL.Host = request.Host
}
request.URL.Scheme = "http"
if tlsConn, ok := conn.(*tls.Conn); ok {
cs := tlsConn.ConnectionState()
request.TLS = &cs
request.URL.Scheme = "https"
}
if request.Header.Get("Accept-Encoding") != "" {
request.Header.Set("Accept-Encoding", "gzip")
}
}
func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool {
if opt.CertConfig == nil {
return false
}
if _, port, err := net.SplitHostPort(hostname); err == nil && (port == "443" || port == "8443") {
return true
}
return false
}
func getHostnameWithoutPort(req *http.Request) string {
host := req.Host
if host == "" {
host = req.URL.Host
}
if pHost, _, err := net.SplitHostPort(host); err == nil {
host = pHost
}
return host
}
func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr {
if source != nil {
return source
}
sourceAddress := req.Header.Get("Origin-Request-Source-Address")
if sourceAddress == "" {
return c.RemoteAddr()
}
req.Header.Del("Origin-Request-Source-Address")
host, port, err := net.SplitHostPort(sourceAddress)
if err != nil {
return c.RemoteAddr()
}
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return c.RemoteAddr()
}
if ip := net.ParseIP(host); ip != nil {
return &net.TCPAddr{
IP: ip,
Port: int(p),
}
}
return c.RemoteAddr()
}
func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client {
if cli != nil {
return cli
}
return newClient(source, req.Header.Get("User-Agent"), in)
}

90
listener/mitm/server.go Normal file
View File

@@ -0,0 +1,90 @@
package mitm
import (
"crypto/tls"
"net"
"net/http"
"time"
"github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/cert"
C "github.com/Dreamacro/clash/constant"
)
type Handler interface {
HandleRequest(*Session) (*http.Request, *http.Response) // Session.Response maybe nil
HandleResponse(*Session) *http.Response
HandleApiRequest(*Session) bool
HandleError(*Session, error) // Session maybe nil
}
type Option struct {
Addr string
ApiHost string
TLSConfig *tls.Config
CertConfig *cert.Config
Handler Handler
}
type Listener struct {
*Option
listener net.Listener
addr string
closed bool
}
// RawAddress implements C.Listener
func (l *Listener) RawAddress() string {
return l.addr
}
// Address implements C.Listener
func (l *Listener) Address() string {
return l.listener.Addr().String()
}
// Close implements C.Listener
func (l *Listener) Close() error {
l.closed = true
return l.listener.Close()
}
// New the MITM proxy actually is a type of HTTP proxy
func New(option *Option, in chan<- C.ConnContext) (*Listener, error) {
return NewWithAuthenticate(option, in, false)
}
func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) {
l, err := net.Listen("tcp", option.Addr)
if err != nil {
return nil, err
}
var c *cache.Cache[string, bool]
if authenticate {
c = cache.New[string, bool](time.Second * 30)
}
hl := &Listener{
listener: l,
addr: option.Addr,
Option: option,
}
go func() {
for {
conn, err1 := hl.listener.Accept()
if err1 != nil {
if hl.closed {
break
}
continue
}
go HandleConn(conn, option, in, c)
}
}()
return hl, nil
}

56
listener/mitm/session.go Normal file
View File

@@ -0,0 +1,56 @@
package mitm
import (
"fmt"
"io"
"net"
"net/http"
C "github.com/Dreamacro/clash/constant"
)
var serverName = fmt.Sprintf("Clash server (%s)", C.Version)
type Session struct {
conn net.Conn
request *http.Request
response *http.Response
props map[string]any
}
func (s *Session) Request() *http.Request {
return s.request
}
func (s *Session) Response() *http.Response {
return s.response
}
func (s *Session) GetProperties(key string) (any, bool) {
v, ok := s.props[key]
return v, ok
}
func (s *Session) SetProperties(key string, val any) {
s.props[key] = val
}
func (s *Session) NewResponse(code int, body io.Reader) *http.Response {
res := NewResponse(code, body, s.request)
res.Header.Set("Server", serverName)
return res
}
func (s *Session) NewErrorResponse(err error) *http.Response {
return NewErrorResponse(s.request, err)
}
func NewSession(conn net.Conn, request *http.Request, response *http.Response) *Session {
return &Session{
conn: conn,
request: request,
response: response,
props: map[string]any{},
}
}

100
listener/mitm/utils.go Normal file
View File

@@ -0,0 +1,100 @@
package mitm
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"time"
"golang.org/x/text/encoding/charmap"
"golang.org/x/text/transform"
)
type MultiReaderConn struct {
net.Conn
reader io.Reader
}
func (c *MultiReaderConn) Read(buf []byte) (int, error) {
return c.reader.Read(buf)
}
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
if body == nil {
body = &bytes.Buffer{}
}
rc, ok := body.(io.ReadCloser)
if !ok {
rc = ioutil.NopCloser(body)
}
res := &http.Response{
StatusCode: code,
Status: fmt.Sprintf("%d %s", code, http.StatusText(code)),
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
Body: rc,
Request: req,
}
if req != nil {
res.Close = req.Close
res.Proto = req.Proto
res.ProtoMajor = req.ProtoMajor
res.ProtoMinor = req.ProtoMinor
}
return res
}
func NewErrorResponse(req *http.Request, err error) *http.Response {
res := NewResponse(http.StatusBadGateway, nil, req)
res.Close = true
date := res.Header.Get("Date")
if date == "" {
date = time.Now().Format(http.TimeFormat)
}
w := fmt.Sprintf(`199 "clash" %q %q`, err.Error(), date)
res.Header.Add("Warning", w)
res.Header.Set("Server", serverName)
return res
}
func ReadDecompressedBody(res *http.Response) ([]byte, error) {
rBody := res.Body
if res.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(rBody)
if err != nil {
return nil, err
}
rBody = gzReader
defer func(gzReader *gzip.Reader) {
_ = gzReader.Close()
}(gzReader)
}
return ioutil.ReadAll(rBody)
}
func DecodeLatin1(reader io.Reader) (string, error) {
r := transform.NewReader(reader, charmap.ISO8859_1.NewDecoder())
b, err := ioutil.ReadAll(r)
if err != nil {
return "", err
}
return string(b), nil
}
func EncodeLatin1(str string) ([]byte, error) {
return charmap.ISO8859_1.NewEncoder().Bytes([]byte(str))
}