little refactor

This commit is contained in:
RPRX
2026-05-28 13:12:16 +00:00
committed by GitHub
parent dfad7c05bd
commit bdc39bc155
2 changed files with 31 additions and 42 deletions
+7 -13
View File
@@ -208,20 +208,14 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer net.Co
return nil, nil, errors.New("failed to create UDP listener").Base(err)
}
responsePort = net.Port(udpHub.LocalAddr().(*net.UDPAddr).Port)
expectedRemoteIP, _, _ := net.SplitHostPort(writer.RemoteAddr().String())
tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemoteIP)
if !request.Address.IP().IsUnspecified() {
// only specified an IP without port
if request.Port == 0 {
tempUDPConn.ExpectedRemoteIP = request.Address.String()
} else { // specified both IP and port
var udpRemote gonet.Addr = &gonet.UDPAddr{
IP: request.Address.IP(),
Port: int(request.Port),
}
tempUDPConn.remote.Store(&udpRemote)
}
expectedRemote := &gonet.UDPAddr{}
if request.Address.IP().IsUnspecified() {
expectedRemote.IP = writer.RemoteAddr().(*net.TCPAddr).IP // unix?
} else {
expectedRemote.IP = request.Address.IP() // panic?
expectedRemote.Port = int(request.Port) // 0 is allowed
}
tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemote)
}
if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil {
common.CloseIfExists(tempUDPConn)
+24 -29
View File
@@ -6,71 +6,66 @@ import (
"sync/atomic"
"time"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/signal"
)
func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, remoteIP string) *TempUDPConn {
return &TempUDPConn{
PacketConn: udpConn,
AssociateTCPConn: tcpConn,
ExpectedRemoteIP: remoteIP,
func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, expectedRemote *net.UDPAddr) *TempUDPConn {
t := &TempUDPConn{
PacketConn: udpConn,
AssociatedTCPConn: tcpConn,
}
t.ExpectedRemote.Store(expectedRemote)
return t
}
// TempUDPConn wait for the first packet to determine the remote address
// SetTimeout MUST be called before any read/write operation
type TempUDPConn struct {
net.PacketConn
AssociateTCPConn net.Conn
ExpectedRemoteIP string
timer *signal.ActivityTimer
remote atomic.Pointer[net.Addr]
AssociatedTCPConn net.Conn
ExpectedRemote atomic.Pointer[net.UDPAddr]
Timer *signal.ActivityTimer
}
func (c *TempUDPConn) Read(b []byte) (n int, err error) {
c.timer.Update()
c.Timer.Update()
var remote net.Addr
for {
n, remote, err = c.PacketConn.ReadFrom(b)
if err != nil {
return
}
if load := c.remote.Load(); load == nil {
if remoteIP, _, _ := net.SplitHostPort(remote.String()); remoteIP == c.ExpectedRemoteIP {
c.remote.Store(&remote)
remote := remote.(*net.UDPAddr)
expected := c.ExpectedRemote.Load()
if remote.IP.Equal(expected.IP) {
if remote.Port == expected.Port {
return
}
if expected.Port == 0 {
c.ExpectedRemote.Store(remote)
return
}
} else if remote.String() == (*load).String() {
return
}
}
}
func (c *TempUDPConn) Write(b []byte) (n int, err error) {
c.timer.Update()
if c.remote.Load() == nil {
return 0, errors.New("remote address not determined yet")
}
return c.PacketConn.WriteTo(b, *c.remote.Load())
c.Timer.Update()
return c.PacketConn.WriteTo(b, c.ExpectedRemote.Load())
}
func (c *TempUDPConn) RemoteAddr() net.Addr {
if c.remote.Load() == nil {
return nil
}
return *c.remote.Load()
return c.ExpectedRemote.Load()
}
func (c *TempUDPConn) SetTimeout(d time.Duration) {
c.timer = signal.CancelAfterInactivity(context.Background(), func() {
c.Timer = signal.CancelAfterInactivity(context.Background(), func() {
c.Close()
}, d)
}
func (c *TempUDPConn) Close() error {
c.timer.SetTimeout(0)
c.AssociateTCPConn.Close()
c.Timer.SetTimeout(0)
c.AssociatedTCPConn.Close()
return c.PacketConn.Close()
}