Socks5 UDP with standard RFC behavior

This commit is contained in:
Fangliding
2026-05-17 03:00:49 +08:00
parent 1bdb488c9e
commit 8503230891
4 changed files with 73 additions and 68 deletions
+24 -13
View File
@@ -137,13 +137,13 @@ func (s *ServerSession) auth5(nMethod byte, reader io.Reader, writer io.Writer)
return "", nil
}
func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, *TempUDPConn, error) {
var (
username string
err error
)
if username, err = s.auth5(nMethod, reader, writer); err != nil {
return nil, err
return nil, nil, err
}
var cmd byte
@@ -151,7 +151,7 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri
buffer := buf.StackNew()
if _, err := buffer.ReadFullFrom(reader, 3); err != nil {
buffer.Release()
return nil, errors.New("failed to read request").Base(err)
return nil, nil, errors.New("failed to read request").Base(err)
}
cmd = buffer.Byte(1)
buffer.Release()
@@ -168,28 +168,29 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri
case cmdUDPAssociate:
if !s.config.UdpEnabled {
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
return nil, errors.New("UDP is not enabled.")
return nil, nil, errors.New("UDP is not enabled.")
}
request.Command = protocol.RequestCommandUDP
case cmdTCPBind:
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
return nil, errors.New("TCP bind is not supported.")
return nil, nil, errors.New("TCP bind is not supported.")
default:
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
return nil, errors.New("unknown command ", cmd)
return nil, nil, errors.New("unknown command ", cmd)
}
request.Version = socks5Version
addr, port, err := addrParser.ReadAddressPort(nil, reader)
if err != nil {
return nil, errors.New("failed to read address").Base(err)
return nil, nil, errors.New("failed to read address").Base(err)
}
request.Address = addr
request.Port = port
responseAddress := s.address
responsePort := s.port
var tempUDPConn *TempUDPConn
//nolint:gocritic // Use if else chain for clarity
if request.Command == protocol.RequestCommandUDP {
if s.config.Address != nil {
@@ -199,20 +200,29 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri
// Use conn.LocalAddr() IP as remote address in the response by default
responseAddress = s.localAddress
}
udpHub, err := net.ListenUDP("udp", &net.UDPAddr{IP: responseAddress.IP()})
if err != nil {
return nil, nil, errors.New("failed to create UDP listener").Base(err)
}
responsePort = net.Port(udpHub.LocalAddr().(*net.UDPAddr).Port)
tempUDPConn = &TempUDPConn{
UDPConn: udpHub,
}
}
if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil {
return nil, err
common.CloseIfExists(tempUDPConn)
return nil, nil, err
}
return request, nil
return request, tempUDPConn, nil
}
// Handshake performs a Socks4/4a/5 handshake.
func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, *TempUDPConn, error) {
buffer := buf.StackNew()
if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
buffer.Release()
return nil, errors.New("insufficient header").Base(err)
return nil, nil, errors.New("insufficient header").Base(err)
}
version := buffer.Byte(0)
@@ -221,11 +231,12 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
switch version {
case socks4Version:
return s.handshake4(cmd, reader, writer)
header, err := s.handshake4(cmd, reader, writer)
return header, nil, err
case socks5Version:
return s.handshake5(cmd, reader, writer)
default:
return nil, errors.New("unknown Socks version: ", version)
return nil, nil, errors.New("unknown Socks version: ", version)
}
}
+13 -24
View File
@@ -29,7 +29,6 @@ type Server struct {
config *ServerConfig
policyManager policy.Manager
cone bool
udpFilter *UDPFilter
httpServer *http.Server
}
@@ -46,7 +45,6 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
}
if config.AuthType == AuthType_PASSWORD {
httpConfig.Accounts = config.Accounts
s.udpFilter = new(UDPFilter) // We only use this when auth is enabled
}
s.httpServer, _ = http.NewServer(ctx, httpConfig)
return s, nil
@@ -60,11 +58,7 @@ func (s *Server) policy() policy.Session {
// Network implements proxy.Inbound.
func (s *Server) Network() []net.Network {
list := []net.Network{net.Network_TCP}
if s.config.UdpEnabled {
list = append(list, net.Network_UDP)
}
return list
return []net.Network{net.Network_TCP}
}
// Process implements proxy.Inbound.
@@ -94,8 +88,6 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
return s.httpServer.ProcessWithFirstbyte(ctx, network, conn, dispatcher, firstbyte...)
}
return s.processTCP(ctx, conn, dispatcher, firstbyte)
case net.Network_UDP:
return s.handleUDPPayload(ctx, conn, dispatcher)
default:
return errors.New("unknown network: ", network)
}
@@ -126,7 +118,8 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
Reader: buf.NewReader(conn),
Buffer: buf.MultiBuffer{buf.FromBytes(firstbyte)},
}
request, err := svrSession.Handshake(reader, conn)
request, tempUDPConn, err := svrSession.Handshake(reader, conn)
defer common.CloseIfExists(tempUDPConn)
if err != nil {
if inbound.Source.IsValid() {
log.Record(&log.AccessMessage{
@@ -170,26 +163,22 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
}
if request.Command == protocol.RequestCommandUDP {
if s.udpFilter != nil {
s.udpFilter.Add(conn.RemoteAddr())
if tempUDPConn == nil {
return errors.New("UDP associate with listen port failed")
}
return s.handleUDP(conn)
ctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
errCh <- s.handleUDPPayload(ctx, tempUDPConn, dispatcher)
}()
io.Copy(buf.DiscardBytes, conn)
cancel()
return <-errCh
}
return nil
}
func (*Server) handleUDP(c io.Reader) error {
// The TCP connection closes after this method returns. We need to wait until
// the client closes it.
return common.Error2(io.Copy(buf.DiscardBytes, c))
}
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) {
errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String())
return nil
}
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
payload := packet.Payload
errors.LogDebug(ctx, "writing back UDP response with ", payload.Len(), " bytes")
+36
View File
@@ -0,0 +1,36 @@
package socks
import (
"net"
sync "sync"
"github.com/xtls/xray-core/common/errors"
)
type TempUDPConn struct {
*net.UDPConn
once sync.Once
remote net.Addr
}
func (c *TempUDPConn) Read(b []byte) (n int, err error) {
n, addr, err := c.ReadFrom(b)
if err != nil {
return 0, err
}
c.once.Do(func() {
c.remote = addr
})
return n, nil
}
func (c *TempUDPConn) Write(b []byte) (n int, err error) {
if c.remote == nil {
return 0, errors.New("remote address not determined yet")
}
return c.UDPConn.WriteTo(b, c.remote)
}
func (c *TempUDPConn) RemoteAddr() net.Addr {
return c.remote
}
-31
View File
@@ -1,31 +0,0 @@
package socks
import (
"net"
"sync"
)
/*
In the sock implementation of * ray, UDP authentication is flawed and can be bypassed.
Tracking a UDP connection may be a bit troublesome.
Here is a simple solution.
We create a filter, add remote IP to the pool when it try to establish a UDP connection with auth.
And drop UDP packets from unauthorized IP.
After discussion, we believe it is not necessary to add a timeout mechanism to this filter.
*/
type UDPFilter struct {
ips sync.Map
}
func (f *UDPFilter) Add(addr net.Addr) bool {
ip, _, _ := net.SplitHostPort(addr.String())
f.ips.Store(ip, true)
return true
}
func (f *UDPFilter) Check(addr net.Addr) bool {
ip, _, _ := net.SplitHostPort(addr.String())
_, ok := f.ips.Load(ip)
return ok
}