From bdc39bc155e0fbca282d1d09790daca271ed405e Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Thu, 28 May 2026 13:12:16 +0000 Subject: [PATCH] little refactor --- proxy/socks/protocol.go | 20 +++++-------- proxy/socks/temp_udp_listen.go | 53 +++++++++++++++------------------- 2 files changed, 31 insertions(+), 42 deletions(-) diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 19efa444..bf4f61a5 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -208,20 +208,14 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer net.Co return nil, nil, errors.New("failed to create UDP listener").Base(err) } responsePort = net.Port(udpHub.LocalAddr().(*net.UDPAddr).Port) - expectedRemoteIP, _, _ := net.SplitHostPort(writer.RemoteAddr().String()) - tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemoteIP) - if !request.Address.IP().IsUnspecified() { - // only specified an IP without port - if request.Port == 0 { - tempUDPConn.ExpectedRemoteIP = request.Address.String() - } else { // specified both IP and port - var udpRemote gonet.Addr = &gonet.UDPAddr{ - IP: request.Address.IP(), - Port: int(request.Port), - } - tempUDPConn.remote.Store(&udpRemote) - } + expectedRemote := &gonet.UDPAddr{} + if request.Address.IP().IsUnspecified() { + expectedRemote.IP = writer.RemoteAddr().(*net.TCPAddr).IP // unix? + } else { + expectedRemote.IP = request.Address.IP() // panic? + expectedRemote.Port = int(request.Port) // 0 is allowed } + tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemote) } if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil { common.CloseIfExists(tempUDPConn) diff --git a/proxy/socks/temp_udp_listen.go b/proxy/socks/temp_udp_listen.go index 84d62733..06474d95 100644 --- a/proxy/socks/temp_udp_listen.go +++ b/proxy/socks/temp_udp_listen.go @@ -6,71 +6,66 @@ import ( "sync/atomic" "time" - "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/signal" ) -func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, remoteIP string) *TempUDPConn { - return &TempUDPConn{ - PacketConn: udpConn, - AssociateTCPConn: tcpConn, - ExpectedRemoteIP: remoteIP, +func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, expectedRemote *net.UDPAddr) *TempUDPConn { + t := &TempUDPConn{ + PacketConn: udpConn, + AssociatedTCPConn: tcpConn, } + t.ExpectedRemote.Store(expectedRemote) + return t } // TempUDPConn wait for the first packet to determine the remote address // SetTimeout MUST be called before any read/write operation type TempUDPConn struct { net.PacketConn - AssociateTCPConn net.Conn - ExpectedRemoteIP string - - timer *signal.ActivityTimer - remote atomic.Pointer[net.Addr] + AssociatedTCPConn net.Conn + ExpectedRemote atomic.Pointer[net.UDPAddr] + Timer *signal.ActivityTimer } func (c *TempUDPConn) Read(b []byte) (n int, err error) { - c.timer.Update() + c.Timer.Update() var remote net.Addr for { n, remote, err = c.PacketConn.ReadFrom(b) if err != nil { return } - if load := c.remote.Load(); load == nil { - if remoteIP, _, _ := net.SplitHostPort(remote.String()); remoteIP == c.ExpectedRemoteIP { - c.remote.Store(&remote) + remote := remote.(*net.UDPAddr) + expected := c.ExpectedRemote.Load() + if remote.IP.Equal(expected.IP) { + if remote.Port == expected.Port { + return + } + if expected.Port == 0 { + c.ExpectedRemote.Store(remote) return } - } else if remote.String() == (*load).String() { - return } } } func (c *TempUDPConn) Write(b []byte) (n int, err error) { - c.timer.Update() - if c.remote.Load() == nil { - return 0, errors.New("remote address not determined yet") - } - return c.PacketConn.WriteTo(b, *c.remote.Load()) + c.Timer.Update() + return c.PacketConn.WriteTo(b, c.ExpectedRemote.Load()) } func (c *TempUDPConn) RemoteAddr() net.Addr { - if c.remote.Load() == nil { - return nil - } - return *c.remote.Load() + return c.ExpectedRemote.Load() } func (c *TempUDPConn) SetTimeout(d time.Duration) { - c.timer = signal.CancelAfterInactivity(context.Background(), func() { + c.Timer = signal.CancelAfterInactivity(context.Background(), func() { c.Close() }, d) } func (c *TempUDPConn) Close() error { - c.timer.SetTimeout(0) - c.AssociateTCPConn.Close() + c.Timer.SetTimeout(0) + c.AssociatedTCPConn.Close() return c.PacketConn.Close() }