From 8503230891d7f8d4abed4d5e2c26d822fffd9f84 Mon Sep 17 00:00:00 2001 From: Fangliding Date: Sun, 17 May 2026 03:00:49 +0800 Subject: [PATCH] Socks5 UDP with standard RFC behavior --- proxy/socks/protocol.go | 37 ++++++++++++++++++++++------------ proxy/socks/server.go | 37 ++++++++++++---------------------- proxy/socks/temp_udp_listen.go | 36 +++++++++++++++++++++++++++++++++ proxy/socks/udpfilter.go | 31 ---------------------------- 4 files changed, 73 insertions(+), 68 deletions(-) create mode 100644 proxy/socks/temp_udp_listen.go delete mode 100644 proxy/socks/udpfilter.go diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 76ceb47c..a5233508 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -137,13 +137,13 @@ func (s *ServerSession) auth5(nMethod byte, reader io.Reader, writer io.Writer) return "", nil } -func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { +func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, *TempUDPConn, error) { var ( username string err error ) if username, err = s.auth5(nMethod, reader, writer); err != nil { - return nil, err + return nil, nil, err } var cmd byte @@ -151,7 +151,7 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri buffer := buf.StackNew() if _, err := buffer.ReadFullFrom(reader, 3); err != nil { buffer.Release() - return nil, errors.New("failed to read request").Base(err) + return nil, nil, errors.New("failed to read request").Base(err) } cmd = buffer.Byte(1) buffer.Release() @@ -168,28 +168,29 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri case cmdUDPAssociate: if !s.config.UdpEnabled { writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) - return nil, errors.New("UDP is not enabled.") + return nil, nil, errors.New("UDP is not enabled.") } request.Command = protocol.RequestCommandUDP case cmdTCPBind: writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) - return nil, errors.New("TCP bind is not supported.") + return nil, nil, errors.New("TCP bind is not supported.") default: writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) - return nil, errors.New("unknown command ", cmd) + return nil, nil, errors.New("unknown command ", cmd) } request.Version = socks5Version addr, port, err := addrParser.ReadAddressPort(nil, reader) if err != nil { - return nil, errors.New("failed to read address").Base(err) + return nil, nil, errors.New("failed to read address").Base(err) } request.Address = addr request.Port = port responseAddress := s.address responsePort := s.port + var tempUDPConn *TempUDPConn //nolint:gocritic // Use if else chain for clarity if request.Command == protocol.RequestCommandUDP { if s.config.Address != nil { @@ -199,20 +200,29 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri // Use conn.LocalAddr() IP as remote address in the response by default responseAddress = s.localAddress } + udpHub, err := net.ListenUDP("udp", &net.UDPAddr{IP: responseAddress.IP()}) + if err != nil { + return nil, nil, errors.New("failed to create UDP listener").Base(err) + } + responsePort = net.Port(udpHub.LocalAddr().(*net.UDPAddr).Port) + tempUDPConn = &TempUDPConn{ + UDPConn: udpHub, + } } if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil { - return nil, err + common.CloseIfExists(tempUDPConn) + return nil, nil, err } - return request, nil + return request, tempUDPConn, nil } // Handshake performs a Socks4/4a/5 handshake. -func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { +func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, *TempUDPConn, error) { buffer := buf.StackNew() if _, err := buffer.ReadFullFrom(reader, 2); err != nil { buffer.Release() - return nil, errors.New("insufficient header").Base(err) + return nil, nil, errors.New("insufficient header").Base(err) } version := buffer.Byte(0) @@ -221,11 +231,12 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol switch version { case socks4Version: - return s.handshake4(cmd, reader, writer) + header, err := s.handshake4(cmd, reader, writer) + return header, nil, err case socks5Version: return s.handshake5(cmd, reader, writer) default: - return nil, errors.New("unknown Socks version: ", version) + return nil, nil, errors.New("unknown Socks version: ", version) } } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 478410f3..8c4dc694 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -29,7 +29,6 @@ type Server struct { config *ServerConfig policyManager policy.Manager cone bool - udpFilter *UDPFilter httpServer *http.Server } @@ -46,7 +45,6 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { } if config.AuthType == AuthType_PASSWORD { httpConfig.Accounts = config.Accounts - s.udpFilter = new(UDPFilter) // We only use this when auth is enabled } s.httpServer, _ = http.NewServer(ctx, httpConfig) return s, nil @@ -60,11 +58,7 @@ func (s *Server) policy() policy.Session { // Network implements proxy.Inbound. func (s *Server) Network() []net.Network { - list := []net.Network{net.Network_TCP} - if s.config.UdpEnabled { - list = append(list, net.Network_UDP) - } - return list + return []net.Network{net.Network_TCP} } // Process implements proxy.Inbound. @@ -94,8 +88,6 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con return s.httpServer.ProcessWithFirstbyte(ctx, network, conn, dispatcher, firstbyte...) } return s.processTCP(ctx, conn, dispatcher, firstbyte) - case net.Network_UDP: - return s.handleUDPPayload(ctx, conn, dispatcher) default: return errors.New("unknown network: ", network) } @@ -126,7 +118,8 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche Reader: buf.NewReader(conn), Buffer: buf.MultiBuffer{buf.FromBytes(firstbyte)}, } - request, err := svrSession.Handshake(reader, conn) + request, tempUDPConn, err := svrSession.Handshake(reader, conn) + defer common.CloseIfExists(tempUDPConn) if err != nil { if inbound.Source.IsValid() { log.Record(&log.AccessMessage{ @@ -170,26 +163,22 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche } if request.Command == protocol.RequestCommandUDP { - if s.udpFilter != nil { - s.udpFilter.Add(conn.RemoteAddr()) + if tempUDPConn == nil { + return errors.New("UDP associate with listen port failed") } - return s.handleUDP(conn) + ctx, cancel := context.WithCancel(ctx) + errCh := make(chan error, 1) + go func() { + errCh <- s.handleUDPPayload(ctx, tempUDPConn, dispatcher) + }() + io.Copy(buf.DiscardBytes, conn) + cancel() + return <-errCh } - return nil } -func (*Server) handleUDP(c io.Reader) error { - // The TCP connection closes after this method returns. We need to wait until - // the client closes it. - return common.Error2(io.Copy(buf.DiscardBytes, c)) -} - func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { - if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) { - errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String()) - return nil - } udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { payload := packet.Payload errors.LogDebug(ctx, "writing back UDP response with ", payload.Len(), " bytes") diff --git a/proxy/socks/temp_udp_listen.go b/proxy/socks/temp_udp_listen.go new file mode 100644 index 00000000..5fec57f3 --- /dev/null +++ b/proxy/socks/temp_udp_listen.go @@ -0,0 +1,36 @@ +package socks + +import ( + "net" + sync "sync" + + "github.com/xtls/xray-core/common/errors" +) + +type TempUDPConn struct { + *net.UDPConn + once sync.Once + remote net.Addr +} + +func (c *TempUDPConn) Read(b []byte) (n int, err error) { + n, addr, err := c.ReadFrom(b) + if err != nil { + return 0, err + } + c.once.Do(func() { + c.remote = addr + }) + return n, nil +} + +func (c *TempUDPConn) Write(b []byte) (n int, err error) { + if c.remote == nil { + return 0, errors.New("remote address not determined yet") + } + return c.UDPConn.WriteTo(b, c.remote) +} + +func (c *TempUDPConn) RemoteAddr() net.Addr { + return c.remote +} diff --git a/proxy/socks/udpfilter.go b/proxy/socks/udpfilter.go deleted file mode 100644 index 9ae3e697..00000000 --- a/proxy/socks/udpfilter.go +++ /dev/null @@ -1,31 +0,0 @@ -package socks - -import ( - "net" - "sync" -) - -/* -In the sock implementation of * ray, UDP authentication is flawed and can be bypassed. -Tracking a UDP connection may be a bit troublesome. -Here is a simple solution. -We create a filter, add remote IP to the pool when it try to establish a UDP connection with auth. -And drop UDP packets from unauthorized IP. -After discussion, we believe it is not necessary to add a timeout mechanism to this filter. -*/ - -type UDPFilter struct { - ips sync.Map -} - -func (f *UDPFilter) Add(addr net.Addr) bool { - ip, _, _ := net.SplitHostPort(addr.String()) - f.ips.Store(ip, true) - return true -} - -func (f *UDPFilter) Check(addr net.Addr) bool { - ip, _, _ := net.SplitHostPort(addr.String()) - _, ok := f.ips.Load(ip) - return ok -}