From 385867e82b26d6a727527045395d9a76dfbcbb4b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 9 Jan 2026 10:28:10 +0000 Subject: [PATCH] Fix race condition in WireGuard server with concurrent peer connections Add mutex protection to server.go to prevent race condition when multiple peers connect simultaneously. The shared routingInfo field was being overwritten by concurrent Process() calls, causing connections to fail. - Add sync.RWMutex to protect access to routing info - Only update routing info if not already set or dispatcher changed - Use local copy of routing info in forwardConnection to avoid races - Existing tests pass Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com> --- proxy/wireguard/server.go | 42 +++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index 6144f5c7..c124c1ee 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -4,6 +4,7 @@ import ( "context" goerrors "errors" "io" + "sync" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -26,6 +27,10 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0) type Server struct { bindServer *netBindServer + // Use a mutex-protected default routing info for forwarded connections + // Since we cannot determine which peer initiated a forwarded connection from gvisor, + // we use the most recently set routing info as default + infoMutex sync.RWMutex info routingInfo policyManager policy.Manager } @@ -78,12 +83,18 @@ func (*Server) Network() []net.Network { // Process implements proxy.Inbound. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { - s.info = routingInfo{ - ctx: ctx, - dispatcher: dispatcher, - inboundTag: session.InboundFromContext(ctx), - contentTag: session.ContentFromContext(ctx), + // Use RWMutex to safely handle concurrent access to routing info + // Only update if not set or if dispatcher is different + s.infoMutex.Lock() + if s.info.dispatcher == nil || s.info.dispatcher != dispatcher { + s.info = routingInfo{ + ctx: ctx, + dispatcher: dispatcher, + inboundTag: session.InboundFromContext(ctx), + contentTag: session.ContentFromContext(ctx), + } } + s.infoMutex.Unlock() ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) if err != nil { @@ -120,18 +131,23 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con } func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { - if s.info.dispatcher == nil { - errors.LogError(s.info.ctx, "unexpected: dispatcher == nil") + // Safely read routing info + s.infoMutex.RLock() + info := s.info + s.infoMutex.RUnlock() + + if info.dispatcher == nil { + errors.LogError(info.ctx, "unexpected: dispatcher == nil") return } defer conn.Close() - ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) + ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx)) sid := session.NewID() ctx = c.ContextWithID(ctx, sid) inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs) - if s.info.inboundTag != nil { - inbound = *s.info.inboundTag + if info.inboundTag != nil { + inbound = *info.inboundTag } inbound.Name = "wireguard" inbound.CanSpliceCopy = 3 @@ -141,8 +157,8 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { // Currently we have no way to link to the original source address inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) ctx = session.ContextWithInbound(ctx, &inbound) - if s.info.contentTag != nil { - ctx = session.ContextWithContent(ctx, s.info.contentTag) + if info.contentTag != nil { + ctx = session.ContextWithContent(ctx, info.contentTag) } ctx = session.SubContextFromMuxInbound(ctx) @@ -156,7 +172,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { Reason: "", }) - link, err := s.info.dispatcher.Dispatch(ctx, dest) + link, err := info.dispatcher.Dispatch(ctx, dest) if err != nil { errors.LogErrorInner(ctx, err, "dispatch connection") }