Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shadowsocks2022 Client Implementation Improvements #2770

Merged
merged 7 commits into from
Nov 24, 2023
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/miekg/dns v1.1.57
github.com/mustafaturan/bus v1.0.2
github.com/pelletier/go-toml v1.9.5
github.com/pion/transport/v2 v2.2.1
github.com/pires/go-proxyproto v0.7.0
github.com/quic-go/quic-go v0.40.0
github.com/refraction-networking/utls v1.5.4
Expand Down Expand Up @@ -68,7 +69,6 @@ require (
github.com/pion/logging v0.2.2 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/sctp v1.8.7 // indirect
github.com/pion/transport/v2 v2.2.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
Expand Down
1 change: 1 addition & 0 deletions infra/conf/v5cfg/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package v5cfg

import (
"context"

"github.com/golang/protobuf/proto"

core "github.com/v2fly/v2ray-core/v5"
Expand Down
170 changes: 147 additions & 23 deletions proxy/shadowsocks2022/client_session.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of contention on sessionMap. The lock is taken once per packet on the encode path, and twice per packet on the decode path.

Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ import (
"github.com/v2fly/v2ray-core/v5/common/buf"
"github.com/v2fly/v2ray-core/v5/common/net"
"github.com/v2fly/v2ray-core/v5/transport/internet"

"github.com/pion/transport/v2/replaydetector"
)

func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession {
session := &ClientUDPSession{
locker: &sync.Mutex{},
locker: &sync.RWMutex{},
conn: conn,
packetProcessor: packetProcessor,
sessionMap: make(map[string]*ClientUDPSessionConn),
sessionMapAlias: make(map[string]string),
}
session.ctx, session.finish = context.WithCancel(ctx)

Expand All @@ -27,16 +30,87 @@ func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetPro
}

type ClientUDPSession struct {
locker *sync.Mutex
locker *sync.RWMutex

conn io.ReadWriteCloser
packetProcessor UDPClientPacketProcessor
sessionMap map[string]*ClientUDPSessionConn

sessionMapAlias map[string]string

ctx context.Context
finish func()
}

func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProcessorCachedState {
c.locker.RLock()
defer c.locker.RUnlock()

state, ok := c.sessionMap[sessionID]
if !ok {
return nil
}
return state.cachedProcessorState
}

func (c *ClientUDPSession) GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState {
c.locker.RLock()
defer c.locker.RUnlock()

clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return nil
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return nil
}

if serverState, ok := state.trackedServerSessionID[serverSessionID]; !ok {
return nil
} else {
return serverState.cachedRecvProcessorState
}
}

func (c *ClientUDPSession) getCachedStateAlias(serverSessionID string) string {
state, ok := c.sessionMapAlias[serverSessionID]
if !ok {
return ""
}
return state
}

func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.RLock()
defer c.locker.RUnlock()

state, ok := c.sessionMap[sessionID]
if !ok {
return
}
state.cachedProcessorState = cache
}

func (c *ClientUDPSession) PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.RLock()
defer c.locker.RUnlock()

clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return
}

if serverState, ok := state.trackedServerSessionID[serverSessionID]; ok {
serverState.cachedRecvProcessorState = cache
return
}
}

func (c *ClientUDPSession) Close() error {
c.finish()
return c.conn.Close()
Expand All @@ -45,7 +119,7 @@ func (c *ClientUDPSession) Close() error {
func (c *ClientUDPSession) WriteUDPRequest(request *UDPRequest) error {
buffer := buf.New()
defer buffer.Release()
err := c.packetProcessor.EncodeUDPRequest(request, buffer)
err := c.packetProcessor.EncodeUDPRequest(request, buffer, c)
if request.Payload != nil {
request.Payload.Release()
}
Expand All @@ -69,7 +143,7 @@ func (c *ClientUDPSession) KeepReading() {
return
}
if n != 0 {
err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp)
err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp, c)
if err != nil {
newError("unable to decode udp response").Base(err).WriteToLog()
continue
Expand All @@ -78,13 +152,14 @@ func (c *ClientUDPSession) KeepReading() {
{
timeDifference := int64(udpResp.TimeStamp) - time.Now().Unix()
if timeDifference < -30 || timeDifference > 30 {
newError("udp packet timestamp difference too large, packet discarded").WriteToLog()
newError("udp packet timestamp difference too large, packet discarded, time diff = ", timeDifference).WriteToLog()
continue
}
}

c.locker.Lock()
c.locker.RLock()
session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])]
c.locker.RUnlock()
if ok {
select {
case session.readChan <- udpResp:
Expand All @@ -93,7 +168,6 @@ func (c *ClientUDPSession) KeepReading() {
} else {
newError("misbehaving server: unknown client session ID").Base(err).WriteToLog()
}
c.locker.Unlock()
}
}
}
Expand All @@ -108,32 +182,47 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
connctx, connfinish := context.WithCancel(c.ctx)

sessionConn := &ClientUDPSessionConn{
sessionID: string(sessionID),
readChan: make(chan *UDPResponse, 16),
parent: c,
ctx: connctx,
finish: connfinish,
nextWritePacketID: 0,
sessionID: string(sessionID),
readChan: make(chan *UDPResponse, 16),
database64128 marked this conversation as resolved.
Show resolved Hide resolved
parent: c,
ctx: connctx,
finish: connfinish,
nextWritePacketID: 0,
trackedServerSessionID: make(map[string]*ClientUDPSessionServerTracker),
}
c.locker.Lock()
c.sessionMap[sessionConn.sessionID] = sessionConn
c.locker.Unlock()
return sessionConn, nil
}

type ClientUDPSessionServerTracker struct {
cachedRecvProcessorState UDPClientPacketProcessorCachedState
rxReplayDetector replaydetector.ReplayDetector
lastSeen time.Time
}

type ClientUDPSessionConn struct {
sessionID string
readChan chan *UDPResponse
parent *ClientUDPSession

nextWritePacketID uint64
nextWritePacketID uint64
trackedServerSessionID map[string]*ClientUDPSessionServerTracker

cachedProcessorState UDPClientPacketProcessorCachedState

ctx context.Context
finish func()
}

func (c *ClientUDPSessionConn) Close() error {
c.parent.locker.Lock()
delete(c.parent.sessionMap, c.sessionID)
for k := range c.trackedServerSessionID {
delete(c.parent.sessionMapAlias, k)
}
c.parent.locker.Unlock()
c.finish()
return nil
}
Expand All @@ -160,13 +249,48 @@ func (c *ClientUDPSessionConn) WriteTo(p []byte, addr gonet.Addr) (n int, err er
}

func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case <-c.ctx.Done():
return 0, nil, io.EOF
case resp := <-c.readChan:
n = copy(p, resp.Payload.Bytes())
resp.Payload.Release()
addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
}
return
for {
select {
case <-c.ctx.Done():
return 0, nil, io.EOF
case resp := <-c.readChan:
n = copy(p, resp.Payload.Bytes())
resp.Payload.Release()

var trackedState *ClientUDPSessionServerTracker
if trackedStateReceived, ok := c.trackedServerSessionID[string(resp.SessionID[:])]; !ok {
expiredServerSessionID := make([]string, 0)
for key, value := range c.trackedServerSessionID {
if time.Since(value.lastSeen) > 125*time.Second {
expiredServerSessionID = append(expiredServerSessionID, key)
}
}
for _, key := range expiredServerSessionID {
delete(c.trackedServerSessionID, key)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine to delete k-v pairs from a map while iterating over it. There's no need to cache the keys to delete in a separate slice.

The replay window during a server session change is only 60 seconds. Why does it cache the state for 125 seconds?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have adjusted this value to 65 seconds and removed mark and remove for Server Session ID.


state := &ClientUDPSessionServerTracker{
rxReplayDetector: replaydetector.New(1024, ^uint64(0)),
}
c.trackedServerSessionID[string(resp.SessionID[:])] = state
c.parent.locker.RLock()
c.parent.sessionMapAlias[string(resp.SessionID[:])] = string(resp.ClientSessionID[:])
c.parent.locker.RUnlock()
trackedState = state
} else {
trackedState = trackedStateReceived
}

if accept, ok := trackedState.rxReplayDetector.Check(resp.PacketID); ok {
accept()
} else {
newError("misbehaving server: replayed packet").Base(err).WriteToLog()
continue
}
trackedState.lastSeen = time.Now()

addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
}
return n, addr, nil
}
}
4 changes: 2 additions & 2 deletions proxy/shadowsocks2022/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (t *TCPRequest) EncodeTCPRequestHeader(effectivePsk []byte,
paddingLength := TCPMinPaddingLength
if initialPayload == nil {
initialPayload = []byte{}
paddingLength += rand.Intn(TCPMaxPaddingLength) // TODO INSECURE RANDOM USED
paddingLength += 1 + rand.Intn(TCPMaxPaddingLength) // TODO INSECURE RANDOM USED
}

variableLengthHeader := &TCPRequestHeader3VariableLength{
Expand Down Expand Up @@ -206,7 +206,7 @@ func (t *TCPRequest) DecodeTCPResponseHeader(effectivePsk []byte, in io.Reader)
}
timeDifference := int64(fixedLengthHeader.Timestamp) - time.Now().Unix()
if timeDifference < -30 || timeDifference > 30 {
return newError("timestamp is too far away")
return newError("timestamp is too far away, timeDifference = ", timeDifference)
}

t.s2cSaltAssert = fixedLengthHeader.RequestSalt
Expand Down
13 changes: 11 additions & 2 deletions proxy/shadowsocks2022/ss2022.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,18 @@ const (
UDPHeaderTypeServerToClientStream = byte(0x01)
)

type UDPClientPacketProcessorCachedStateContainer interface {
GetCachedState(sessionID string) UDPClientPacketProcessorCachedState
PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState)
GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState
PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState)
}

type UDPClientPacketProcessorCachedState interface{}

// UDPClientPacketProcessor
// Caller retain and receive all ownership of the buffer
type UDPClientPacketProcessor interface {
EncodeUDPRequest(request *UDPRequest, out *buf.Buffer) error
DecodeUDPResp(input []byte, resp *UDPResponse) error
EncodeUDPRequest(request *UDPRequest, out *buf.Buffer, cache UDPClientPacketProcessorCachedStateContainer) error
DecodeUDPResp(input []byte, resp *UDPResponse, cache UDPClientPacketProcessorCachedStateContainer) error
}
41 changes: 37 additions & 4 deletions proxy/shadowsocks2022/udp_aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,14 @@ type respHeader struct {
Padding []byte
}

func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out *buf.Buffer) error {
type cachedUDPState struct {
sessionAEAD cipher.AEAD
sessionRecvAEAD cipher.AEAD
}

func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out *buf.Buffer,
cache UDPClientPacketProcessorCachedStateContainer,
) error {
separateHeaderStruct := separateHeader{PacketID: request.PacketID, SessionID: request.SessionID}
separateHeaderBuffer := buf.New()
defer separateHeaderBuffer.Release()
Expand Down Expand Up @@ -102,14 +109,28 @@ func (p *AESUDPClientPacketProcessor) EncodeUDPRequest(request *UDPRequest, out
}
}
{
mainPacketAEADMaterialized := p.mainPacketAEAD(separateHeaderBufferBytes[0:8])
cacheKey := string(separateHeaderBufferBytes[0:8])
receivedCacheInterface := cache.GetCachedState(cacheKey)
cachedState := &cachedUDPState{}
if receivedCacheInterface != nil {
cachedState = receivedCacheInterface.(*cachedUDPState)
}
if cachedState.sessionAEAD == nil {
cachedState.sessionAEAD = p.mainPacketAEAD(separateHeaderBufferBytes[0:8])
cache.PutCachedState(cacheKey, cachedState)
}

mainPacketAEADMaterialized := cachedState.sessionAEAD

encryptedDest := out.Extend(int32(mainPacketAEADMaterialized.Overhead()) + requestBodyBuffer.Len())
mainPacketAEADMaterialized.Seal(encryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], requestBodyBuffer.Bytes(), nil)
}
return nil
}

func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPResponse) error {
func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPResponse,
cache UDPClientPacketProcessorCachedStateContainer,
) error {
separateHeaderBuffer := buf.New()
defer separateHeaderBuffer.Release()
{
Expand All @@ -126,7 +147,19 @@ func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPRespo
resp.PacketID = separateHeaderStruct.PacketID
resp.SessionID = separateHeaderStruct.SessionID
{
mainPacketAEADMaterialized := p.mainPacketAEAD(separateHeaderBuffer.Bytes()[0:8])
cacheKey := string(separateHeaderBuffer.Bytes()[0:8])
receivedCacheInterface := cache.GetCachedServerState(cacheKey)
cachedState := &cachedUDPState{}
if receivedCacheInterface != nil {
cachedState = receivedCacheInterface.(*cachedUDPState)
}

if cachedState.sessionRecvAEAD == nil {
cachedState.sessionRecvAEAD = p.mainPacketAEAD(separateHeaderBuffer.Bytes()[0:8])
cache.PutCachedServerState(cacheKey, cachedState)
}

mainPacketAEADMaterialized := cachedState.sessionRecvAEAD
decryptedDestBuffer := buf.New()
decryptedDest := decryptedDestBuffer.Extend(int32(len(input)) - 16 - int32(mainPacketAEADMaterialized.Overhead()))
_, err := mainPacketAEADMaterialized.Open(decryptedDest[:0], separateHeaderBuffer.Bytes()[4:16], input[16:], nil)
Expand Down
Loading