diff --git a/common/geodata/domain_matcher.go b/common/geodata/domain_matcher.go index be9e62bf..2c231d98 100644 --- a/common/geodata/domain_matcher.go +++ b/common/geodata/domain_matcher.go @@ -8,6 +8,7 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/geodata/strmatcher" + "github.com/xtls/xray-core/common/utils" ) type DomainMatcher interface { @@ -25,7 +26,7 @@ type DomainMatcherFactory interface { type MphDomainMatcherFactory struct { sync.Mutex - shared map[string]strmatcher.MatcherGroup // TODO: cleanup + shared *utils.WeakCacheMap[string, strmatcher.MphValueMatcher] } func buildDomainRulesKey(rules []*DomainRule) string { @@ -65,7 +66,7 @@ func (f *MphDomainMatcherFactory) BuildMatcher(rules []*DomainRule) (DomainMatch if key != "" { f.Lock() defer f.Unlock() - if g := f.shared[key]; g != nil { + if g, ok := f.shared.Load(key); ok { errors.LogDebug(context.Background(), "geodata mph domain matcher cache HIT for ", len(rules), " rules") return g, nil } @@ -102,14 +103,14 @@ func (f *MphDomainMatcherFactory) BuildMatcher(rules []*DomainRule) (DomainMatch return nil, err } if key != "" { - f.shared[key] = g + f.shared.Store(key, g) } return g, nil } type CompactDomainMatcherFactory struct { sync.Mutex - shared map[string]strmatcher.MatcherSet // TODO: cleanup + shared *utils.WeakCacheMap[string, strmatcher.LinearAnyMatcher] } func (f *CompactDomainMatcherFactory) getOrCreateFrom(rule *GeoSiteRule) (strmatcher.MatcherSet, error) { @@ -118,7 +119,7 @@ func (f *CompactDomainMatcherFactory) getOrCreateFrom(rule *GeoSiteRule) (strmat f.Lock() defer f.Unlock() - if s := f.shared[key]; s != nil { + if s, ok := f.shared.Load(key); ok { errors.LogDebug(context.Background(), "geodata geosite matcher cache HIT ", key) return s, nil } @@ -138,7 +139,7 @@ func (f *CompactDomainMatcherFactory) getOrCreateFrom(rule *GeoSiteRule) (strmat } s.Add(m) } - f.shared[key] = s + f.shared.Store(key, s) return s, err } @@ -230,8 +231,8 @@ func parseDomain(d *Domain) (strmatcher.Matcher, error) { func newDomainMatcherFactory() DomainMatcherFactory { switch runtime.GOOS { case "ios", "android": - return &CompactDomainMatcherFactory{shared: make(map[string]strmatcher.MatcherSet)} + return &CompactDomainMatcherFactory{shared: utils.NewWeakCacheMap[string, strmatcher.LinearAnyMatcher]()} default: - return &MphDomainMatcherFactory{shared: make(map[string]strmatcher.MatcherGroup)} + return &MphDomainMatcherFactory{shared: utils.NewWeakCacheMap[string, strmatcher.MphValueMatcher]()} } } diff --git a/common/geodata/ip_matcher.go b/common/geodata/ip_matcher.go index 1eba5dbf..703b6497 100644 --- a/common/geodata/ip_matcher.go +++ b/common/geodata/ip_matcher.go @@ -11,6 +11,7 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/utils" "go4.org/netipx" ) @@ -806,7 +807,7 @@ func (mm *HeuristicMultiIPMatcher) SetReverse(reverse bool) { type IPSetFactory struct { sync.Mutex - shared map[string]*IPSet // TODO: cleanup + shared *utils.WeakCacheMap[string, IPSet] } func (f *IPSetFactory) GetOrCreateFromGeoIPRules(rules []*GeoIPRule) (*IPSet, error) { @@ -815,7 +816,7 @@ func (f *IPSetFactory) GetOrCreateFromGeoIPRules(rules []*GeoIPRule) (*IPSet, er f.Lock() defer f.Unlock() - if ipset := f.shared[key]; ipset != nil { + if ipset, ok := f.shared.Load(key); ok { errors.LogDebug(context.Background(), "geodata geoip matcher cache HIT ", key) return ipset, nil } @@ -835,7 +836,7 @@ func (f *IPSetFactory) GetOrCreateFromGeoIPRules(rules []*GeoIPRule) (*IPSet, er return nil }) if err == nil { - f.shared[key] = ipset + f.shared.Store(key, ipset) } return ipset, err } @@ -1018,5 +1019,5 @@ func buildOptimizedIPMatcher(f *IPSetFactory, rules []*IPRule) (IPMatcher, error } func newIPSetFactory() *IPSetFactory { - return &IPSetFactory{shared: make(map[string]*IPSet)} + return &IPSetFactory{shared: utils.NewWeakCacheMap[string, IPSet]()} } diff --git a/common/utils/weak_cache.go b/common/utils/weak_cache.go new file mode 100644 index 00000000..c14e1912 --- /dev/null +++ b/common/utils/weak_cache.go @@ -0,0 +1,45 @@ +package utils + +import ( + "runtime" + "sync" + "weak" +) + +// WeakCacheMap is a map that holds weak references to values. +// Use for shared expensive objects and automatic cleanup when no longer used. +// This object can be GC and no goroutine is used for cleanup. +type WeakCacheMap[K comparable, V any] struct { + mu sync.Mutex + m map[K]weak.Pointer[V] +} + +func NewWeakCacheMap[K comparable, V any]() *WeakCacheMap[K, V] { + return &WeakCacheMap[K, V]{ + m: make(map[K]weak.Pointer[V]), + } +} + +func (c *WeakCacheMap[K, V]) Load(key K) (value *V, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + weakPtr := c.m[key].Value() + if weakPtr != nil { + return weakPtr, true + } + return nil, false +} + +func (c *WeakCacheMap[K, V]) Store(key K, value *V) { + c.mu.Lock() + defer c.mu.Unlock() + weakPtr := weak.Make(value) + c.m[key] = weakPtr + runtime.AddCleanup(value, func(any) { + c.mu.Lock() + defer c.mu.Unlock() + if c.m[key] == weakPtr { + delete(c.m, key) + } + }, nil) +}