chore: packet deadline support CreateReadWaiter interface

This commit is contained in:
wwqgtxx
2023-05-20 11:44:11 +08:00
parent 2b1e69153b
commit b047ca0294
3 changed files with 171 additions and 70 deletions

View File

@@ -2,10 +2,13 @@ package deadline
import (
"os"
"runtime"
"github.com/Dreamacro/clash/common/net/packet"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type SingPacketConn struct {
@@ -33,6 +36,7 @@ var _ packet.EnhanceSingPacketConn = (*EnhanceSingPacketConn)(nil)
type singReadResult struct {
buffer *buf.Buffer
destination M.Socksaddr
err error
}
type singPacketConn struct {
@@ -41,26 +45,34 @@ type singPacketConn struct {
}
func (c *singPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
destination = result.destination
err = result.err
buffer.Resize(result.buffer.Start(), 0)
n := copy(buffer.FreeBytes(), result.buffer.Bytes())
buffer.Truncate(n)
result.buffer.Advance(n)
if result.buffer.IsEmpty() {
result.buffer.Release()
FOR:
for {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
if result, ok := result.(*singReadResult); ok {
destination = result.destination
err = result.err
buffer.Resize(result.buffer.Start(), 0)
n := copy(buffer.FreeBytes(), result.buffer.Bytes())
buffer.Truncate(n)
result.buffer.Advance(n)
if result.buffer.IsEmpty() {
result.buffer.Release()
}
c.netPacketConn.resultCh <- nil // finish cache read
return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
}
c.netPacketConn.resultCh <- nil // finish cache read
return
} else {
c.netPacketConn.resultCh <- nil
break
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
if c.netPacketConn.disablePipe.Load() {
@@ -82,15 +94,87 @@ func (c *singPacketConn) pipeReadPacket(bufLen int, bufStart int) {
buffer := buf.NewSize(bufLen)
buffer.Advance(bufStart)
destination, err := c.singPacketConn.ReadPacket(buffer)
c.netPacketConn.resultCh <- &readResult{
singReadResult: singReadResult{
buffer: buffer,
destination: destination,
},
err: err,
}
result := &singReadResult{}
result.destination = destination
result.err = err
c.netPacketConn.resultCh <- result
}
func (c *singPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.singPacketConn.WritePacket(buffer, destination)
}
func (c *singPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
prw, isReadWaiter := bufio.CreatePacketReadWaiter(c.singPacketConn)
if isReadWaiter {
return &singPacketReadWaiter{
netPacketConn: c.netPacketConn,
packetReadWaiter: prw,
}, true
}
return nil, false
}
var _ N.PacketReadWaiter = (*singPacketReadWaiter)(nil)
type singPacketReadWaiter struct {
netPacketConn *NetPacketConn
packetReadWaiter N.PacketReadWaiter
}
type singWaitReadResult singReadResult
func (c *singPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
c.packetReadWaiter.InitializeReadWaiter(newBuffer)
}
func (c *singPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
FOR:
for {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
if result, ok := result.(*singWaitReadResult); ok {
destination = result.destination
err = result.err
c.netPacketConn.resultCh <- nil // finish cache read
return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
}
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
if c.netPacketConn.disablePipe.Load() {
return c.packetReadWaiter.WaitReadPacket()
} else if c.netPacketConn.deadline.Load().IsZero() {
c.netPacketConn.inRead.Store(true)
defer c.netPacketConn.inRead.Store(false)
destination, err = c.packetReadWaiter.WaitReadPacket()
return
}
<-c.netPacketConn.resultCh
go c.pipeWaitReadPacket()
return c.WaitReadPacket()
}
func (c *singPacketReadWaiter) pipeWaitReadPacket() {
destination, err := c.packetReadWaiter.WaitReadPacket()
result := &singWaitReadResult{}
result.destination = destination
result.err = err
c.netPacketConn.resultCh <- result
}
func (c *singPacketReadWaiter) Upstream() any {
return c.packetReadWaiter
}