diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index 6144f5c7..c124c1ee 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -4,6 +4,7 @@ import ( "context" goerrors "errors" "io" + "sync" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -26,6 +27,10 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0) type Server struct { bindServer *netBindServer + // Use a mutex-protected default routing info for forwarded connections + // Since we cannot determine which peer initiated a forwarded connection from gvisor, + // we use the most recently set routing info as default + infoMutex sync.RWMutex info routingInfo policyManager policy.Manager } @@ -78,12 +83,18 @@ func (*Server) Network() []net.Network { // Process implements proxy.Inbound. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { - s.info = routingInfo{ - ctx: ctx, - dispatcher: dispatcher, - inboundTag: session.InboundFromContext(ctx), - contentTag: session.ContentFromContext(ctx), + // Use RWMutex to safely handle concurrent access to routing info + // Only update if not set or if dispatcher is different + s.infoMutex.Lock() + if s.info.dispatcher == nil || s.info.dispatcher != dispatcher { + s.info = routingInfo{ + ctx: ctx, + dispatcher: dispatcher, + inboundTag: session.InboundFromContext(ctx), + contentTag: session.ContentFromContext(ctx), + } } + s.infoMutex.Unlock() ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) if err != nil { @@ -120,18 +131,23 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con } func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { - if s.info.dispatcher == nil { - errors.LogError(s.info.ctx, "unexpected: dispatcher == nil") + // Safely read routing info + s.infoMutex.RLock() + info := s.info + s.infoMutex.RUnlock() + + if info.dispatcher == nil { + errors.LogError(info.ctx, "unexpected: dispatcher == nil") return } defer conn.Close() - ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) + ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx)) sid := session.NewID() ctx = c.ContextWithID(ctx, sid) inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs) - if s.info.inboundTag != nil { - inbound = *s.info.inboundTag + if info.inboundTag != nil { + inbound = *info.inboundTag } inbound.Name = "wireguard" inbound.CanSpliceCopy = 3 @@ -141,8 +157,8 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { // Currently we have no way to link to the original source address inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) ctx = session.ContextWithInbound(ctx, &inbound) - if s.info.contentTag != nil { - ctx = session.ContextWithContent(ctx, s.info.contentTag) + if info.contentTag != nil { + ctx = session.ContextWithContent(ctx, info.contentTag) } ctx = session.SubContextFromMuxInbound(ctx) @@ -156,7 +172,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { Reason: "", }) - link, err := s.info.dispatcher.Dispatch(ctx, dest) + link, err := info.dispatcher.Dispatch(ctx, dest) if err != nil { errors.LogErrorInner(ctx, err, "dispatch connection") }