mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-06-14 17:13:17 +00:00
little refactor
This commit is contained in:
+7
-13
@@ -208,20 +208,14 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer net.Co
|
||||
return nil, nil, errors.New("failed to create UDP listener").Base(err)
|
||||
}
|
||||
responsePort = net.Port(udpHub.LocalAddr().(*net.UDPAddr).Port)
|
||||
expectedRemoteIP, _, _ := net.SplitHostPort(writer.RemoteAddr().String())
|
||||
tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemoteIP)
|
||||
if !request.Address.IP().IsUnspecified() {
|
||||
// only specified an IP without port
|
||||
if request.Port == 0 {
|
||||
tempUDPConn.ExpectedRemoteIP = request.Address.String()
|
||||
} else { // specified both IP and port
|
||||
var udpRemote gonet.Addr = &gonet.UDPAddr{
|
||||
IP: request.Address.IP(),
|
||||
Port: int(request.Port),
|
||||
}
|
||||
tempUDPConn.remote.Store(&udpRemote)
|
||||
}
|
||||
expectedRemote := &gonet.UDPAddr{}
|
||||
if request.Address.IP().IsUnspecified() {
|
||||
expectedRemote.IP = writer.RemoteAddr().(*net.TCPAddr).IP // unix?
|
||||
} else {
|
||||
expectedRemote.IP = request.Address.IP() // panic?
|
||||
expectedRemote.Port = int(request.Port) // 0 is allowed
|
||||
}
|
||||
tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemote)
|
||||
}
|
||||
if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil {
|
||||
common.CloseIfExists(tempUDPConn)
|
||||
|
||||
@@ -6,71 +6,66 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
)
|
||||
|
||||
func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, remoteIP string) *TempUDPConn {
|
||||
return &TempUDPConn{
|
||||
PacketConn: udpConn,
|
||||
AssociateTCPConn: tcpConn,
|
||||
ExpectedRemoteIP: remoteIP,
|
||||
func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, expectedRemote *net.UDPAddr) *TempUDPConn {
|
||||
t := &TempUDPConn{
|
||||
PacketConn: udpConn,
|
||||
AssociatedTCPConn: tcpConn,
|
||||
}
|
||||
t.ExpectedRemote.Store(expectedRemote)
|
||||
return t
|
||||
}
|
||||
|
||||
// TempUDPConn wait for the first packet to determine the remote address
|
||||
// SetTimeout MUST be called before any read/write operation
|
||||
type TempUDPConn struct {
|
||||
net.PacketConn
|
||||
AssociateTCPConn net.Conn
|
||||
ExpectedRemoteIP string
|
||||
|
||||
timer *signal.ActivityTimer
|
||||
remote atomic.Pointer[net.Addr]
|
||||
AssociatedTCPConn net.Conn
|
||||
ExpectedRemote atomic.Pointer[net.UDPAddr]
|
||||
Timer *signal.ActivityTimer
|
||||
}
|
||||
|
||||
func (c *TempUDPConn) Read(b []byte) (n int, err error) {
|
||||
c.timer.Update()
|
||||
c.Timer.Update()
|
||||
var remote net.Addr
|
||||
for {
|
||||
n, remote, err = c.PacketConn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if load := c.remote.Load(); load == nil {
|
||||
if remoteIP, _, _ := net.SplitHostPort(remote.String()); remoteIP == c.ExpectedRemoteIP {
|
||||
c.remote.Store(&remote)
|
||||
remote := remote.(*net.UDPAddr)
|
||||
expected := c.ExpectedRemote.Load()
|
||||
if remote.IP.Equal(expected.IP) {
|
||||
if remote.Port == expected.Port {
|
||||
return
|
||||
}
|
||||
if expected.Port == 0 {
|
||||
c.ExpectedRemote.Store(remote)
|
||||
return
|
||||
}
|
||||
} else if remote.String() == (*load).String() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TempUDPConn) Write(b []byte) (n int, err error) {
|
||||
c.timer.Update()
|
||||
if c.remote.Load() == nil {
|
||||
return 0, errors.New("remote address not determined yet")
|
||||
}
|
||||
return c.PacketConn.WriteTo(b, *c.remote.Load())
|
||||
c.Timer.Update()
|
||||
return c.PacketConn.WriteTo(b, c.ExpectedRemote.Load())
|
||||
}
|
||||
|
||||
func (c *TempUDPConn) RemoteAddr() net.Addr {
|
||||
if c.remote.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
return *c.remote.Load()
|
||||
return c.ExpectedRemote.Load()
|
||||
}
|
||||
|
||||
func (c *TempUDPConn) SetTimeout(d time.Duration) {
|
||||
c.timer = signal.CancelAfterInactivity(context.Background(), func() {
|
||||
c.Timer = signal.CancelAfterInactivity(context.Background(), func() {
|
||||
c.Close()
|
||||
}, d)
|
||||
}
|
||||
|
||||
func (c *TempUDPConn) Close() error {
|
||||
c.timer.SetTimeout(0)
|
||||
c.AssociateTCPConn.Close()
|
||||
c.Timer.SetTimeout(0)
|
||||
c.AssociatedTCPConn.Close()
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user