diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 76ceb47c..bf4f61a5 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -1,14 +1,17 @@ package socks import ( + "context" "encoding/binary" "io" + gonet "net" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/transport/internet" ) const ( @@ -137,13 +140,13 @@ func (s *ServerSession) auth5(nMethod byte, reader io.Reader, writer io.Writer) return "", nil } -func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { +func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer net.Conn) (*protocol.RequestHeader, *TempUDPConn, error) { var ( username string err error ) if username, err = s.auth5(nMethod, reader, writer); err != nil { - return nil, err + return nil, nil, err } var cmd byte @@ -151,7 +154,7 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri buffer := buf.StackNew() if _, err := buffer.ReadFullFrom(reader, 3); err != nil { buffer.Release() - return nil, errors.New("failed to read request").Base(err) + return nil, nil, errors.New("failed to read request").Base(err) } cmd = buffer.Byte(1) buffer.Release() @@ -168,28 +171,29 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri case cmdUDPAssociate: if !s.config.UdpEnabled { writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) - return nil, errors.New("UDP is not enabled.") + return nil, nil, errors.New("UDP is not enabled.") } request.Command = protocol.RequestCommandUDP case cmdTCPBind: writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) - return nil, errors.New("TCP bind is not supported.") + return nil, nil, errors.New("TCP bind is not supported.") default: writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0)) - return nil, errors.New("unknown command ", cmd) + return nil, nil, errors.New("unknown command ", cmd) } request.Version = socks5Version addr, port, err := addrParser.ReadAddressPort(nil, reader) if err != nil { - return nil, errors.New("failed to read address").Base(err) + return nil, nil, errors.New("failed to read address").Base(err) } request.Address = addr request.Port = port responseAddress := s.address responsePort := s.port + var tempUDPConn *TempUDPConn //nolint:gocritic // Use if else chain for clarity if request.Command == protocol.RequestCommandUDP { if s.config.Address != nil { @@ -199,20 +203,34 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri // Use conn.LocalAddr() IP as remote address in the response by default responseAddress = s.localAddress } + udpHub, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: responseAddress.IP(), Port: 0}, nil) + if err != nil { + return nil, nil, errors.New("failed to create UDP listener").Base(err) + } + responsePort = net.Port(udpHub.LocalAddr().(*net.UDPAddr).Port) + expectedRemote := &gonet.UDPAddr{} + if request.Address.IP().IsUnspecified() { + expectedRemote.IP = writer.RemoteAddr().(*net.TCPAddr).IP // unix? + } else { + expectedRemote.IP = request.Address.IP() // panic? + expectedRemote.Port = int(request.Port) // 0 is allowed + } + tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemote) } if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil { - return nil, err + common.CloseIfExists(tempUDPConn) + return nil, nil, err } - return request, nil + return request, tempUDPConn, nil } // Handshake performs a Socks4/4a/5 handshake. -func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { +func (s *ServerSession) Handshake(reader io.Reader, writer net.Conn) (*protocol.RequestHeader, *TempUDPConn, error) { buffer := buf.StackNew() if _, err := buffer.ReadFullFrom(reader, 2); err != nil { buffer.Release() - return nil, errors.New("insufficient header").Base(err) + return nil, nil, errors.New("insufficient header").Base(err) } version := buffer.Byte(0) @@ -221,11 +239,12 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol switch version { case socks4Version: - return s.handshake4(cmd, reader, writer) + header, err := s.handshake4(cmd, reader, writer) + return header, nil, err case socks5Version: return s.handshake5(cmd, reader, writer) default: - return nil, errors.New("unknown Socks version: ", version) + return nil, nil, errors.New("unknown Socks version: ", version) } } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 478410f3..fab66476 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -29,7 +29,6 @@ type Server struct { config *ServerConfig policyManager policy.Manager cone bool - udpFilter *UDPFilter httpServer *http.Server } @@ -46,7 +45,6 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { } if config.AuthType == AuthType_PASSWORD { httpConfig.Accounts = config.Accounts - s.udpFilter = new(UDPFilter) // We only use this when auth is enabled } s.httpServer, _ = http.NewServer(ctx, httpConfig) return s, nil @@ -60,11 +58,7 @@ func (s *Server) policy() policy.Session { // Network implements proxy.Inbound. func (s *Server) Network() []net.Network { - list := []net.Network{net.Network_TCP} - if s.config.UdpEnabled { - list = append(list, net.Network_UDP) - } - return list + return []net.Network{net.Network_TCP} } // Process implements proxy.Inbound. @@ -94,8 +88,6 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con return s.httpServer.ProcessWithFirstbyte(ctx, network, conn, dispatcher, firstbyte...) } return s.processTCP(ctx, conn, dispatcher, firstbyte) - case net.Network_UDP: - return s.handleUDPPayload(ctx, conn, dispatcher) default: return errors.New("unknown network: ", network) } @@ -126,7 +118,8 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche Reader: buf.NewReader(conn), Buffer: buf.MultiBuffer{buf.FromBytes(firstbyte)}, } - request, err := svrSession.Handshake(reader, conn) + request, tempUDPConn, err := svrSession.Handshake(reader, conn) + defer common.CloseIfExists(tempUDPConn) if err != nil { if inbound.Source.IsValid() { log.Record(&log.AccessMessage{ @@ -170,26 +163,25 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche } if request.Command == protocol.RequestCommandUDP { - if s.udpFilter != nil { - s.udpFilter.Add(conn.RemoteAddr()) + if tempUDPConn == nil { + return errors.New("UDP associate with listen port failed") } - return s.handleUDP(conn) + tempUDPConn.SetTimeout(plcy.Timeouts.ConnectionIdle) + errCh := make(chan error, 1) + go func() { + errCh <- s.handleUDPPayload(ctx, tempUDPConn, dispatcher) + }() + // Associated TCP keeps the UDP alive + // Close UDP if TCP connection is closed + // Or Close TCP if UDP is idle timeout + io.Copy(buf.DiscardBytes, conn) + tempUDPConn.Close() + return <-errCh } - return nil } -func (*Server) handleUDP(c io.Reader) error { - // The TCP connection closes after this method returns. We need to wait until - // the client closes it. - return common.Error2(io.Copy(buf.DiscardBytes, c)) -} - func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { - if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) { - errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String()) - return nil - } udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { payload := packet.Payload errors.LogDebug(ctx, "writing back UDP response with ", payload.Len(), " bytes") diff --git a/proxy/socks/temp_udp_listen.go b/proxy/socks/temp_udp_listen.go new file mode 100644 index 00000000..5e0dab03 --- /dev/null +++ b/proxy/socks/temp_udp_listen.go @@ -0,0 +1,72 @@ +package socks + +import ( + "context" + "net" + "sync/atomic" + "time" + + "github.com/xtls/xray-core/common/signal" +) + +func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, expectedRemote *net.UDPAddr) *TempUDPConn { + t := &TempUDPConn{ + PacketConn: udpConn, + AssociatedTCPConn: tcpConn, + } + t.ExpectedRemote.Store(expectedRemote) + return t +} + +// TempUDPConn wait for the first packet to determine the remote address +// SetTimeout MUST be called before any read/write operation +type TempUDPConn struct { + net.PacketConn + AssociatedTCPConn net.Conn + ExpectedRemote atomic.Pointer[net.UDPAddr] + Timer *signal.ActivityTimer +} + +func (c *TempUDPConn) Read(b []byte) (n int, err error) { + var remote net.Addr + for { + n, remote, err = c.PacketConn.ReadFrom(b) + if err != nil { + return + } + remote := remote.(*net.UDPAddr) + expected := c.ExpectedRemote.Load() + if remote.IP.Equal(expected.IP) { + if remote.Port == expected.Port { + c.Timer.Update() + return + } + if expected.Port == 0 { + c.ExpectedRemote.Store(remote) + c.Timer.Update() + return + } + } + } +} + +func (c *TempUDPConn) Write(b []byte) (n int, err error) { + c.Timer.Update() + return c.PacketConn.WriteTo(b, c.ExpectedRemote.Load()) +} + +func (c *TempUDPConn) RemoteAddr() net.Addr { + return c.ExpectedRemote.Load() +} + +func (c *TempUDPConn) SetTimeout(d time.Duration) { + c.Timer = signal.CancelAfterInactivity(context.Background(), func() { + c.Close() + }, d) +} + +func (c *TempUDPConn) Close() error { + c.Timer.SetTimeout(0) + c.AssociatedTCPConn.Close() + return c.PacketConn.Close() +} diff --git a/proxy/socks/udpfilter.go b/proxy/socks/udpfilter.go deleted file mode 100644 index 9ae3e697..00000000 --- a/proxy/socks/udpfilter.go +++ /dev/null @@ -1,31 +0,0 @@ -package socks - -import ( - "net" - "sync" -) - -/* -In the sock implementation of * ray, UDP authentication is flawed and can be bypassed. -Tracking a UDP connection may be a bit troublesome. -Here is a simple solution. -We create a filter, add remote IP to the pool when it try to establish a UDP connection with auth. -And drop UDP packets from unauthorized IP. -After discussion, we believe it is not necessary to add a timeout mechanism to this filter. -*/ - -type UDPFilter struct { - ips sync.Map -} - -func (f *UDPFilter) Add(addr net.Addr) bool { - ip, _, _ := net.SplitHostPort(addr.String()) - f.ips.Store(ip, true) - return true -} - -func (f *UDPFilter) Check(addr net.Addr) bool { - ip, _, _ := net.SplitHostPort(addr.String()) - _, ok := f.ips.Load(ip) - return ok -}