diff --git a/app/router/strategy_leastload.go b/app/router/strategy_leastload.go index 1bf3cbc0..289982a8 100644 --- a/app/router/strategy_leastload.go +++ b/app/router/strategy_leastload.go @@ -3,6 +3,7 @@ package router import ( "context" "math" + "slices" "sort" "time" @@ -77,7 +78,7 @@ func (s *LeastLoadStrategy) PickOutbound(candidates []string) string { } func (s *LeastLoadStrategy) pickOutbounds(candidates []string) []*node { - qualified := s.getNodes(candidates, time.Duration(s.settings.MaxRTT)) + qualified := s.getNodes(candidates) selects := s.selectLeastLoad(qualified) return selects } @@ -138,7 +139,7 @@ func (s *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node { return nodes[:count] } -func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node { +func (s *LeastLoadStrategy) getNodes(candidates []string) []*node { if s.observer == nil { errors.LogError(s.ctx, "observer is nil") return make([]*node, 0) @@ -151,12 +152,10 @@ func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) results := observeResult.(*observatory.ObservationResult) - outboundlist := outboundList(candidates) - var ret []*node for _, v := range results.Status { - if v.Alive && (v.Delay < maxRTT.Milliseconds() || maxRTT == 0) && outboundlist.contains(v.OutboundTag) { + if s.shouldSelectNode(v, candidates) { record := &node{ Tag: v.OutboundTag, CountAll: 1, @@ -172,8 +171,8 @@ func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) record.RTTDeviationCost = time.Duration(s.costs.Apply(v.OutboundTag, float64(v.HealthPing.Deviation))) record.CountAll = int(v.HealthPing.All) record.CountFail = int(v.HealthPing.Fail) - } + ret = append(ret, record) } } @@ -182,6 +181,23 @@ func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) return ret } +func (s *LeastLoadStrategy) shouldSelectNode(v *observatory.OutboundStatus, candidates []string) bool { + maxRTT := time.Duration(s.settings.MaxRTT) + if !v.Alive { + return false + } + if maxRTT != 0 && v.Delay >= maxRTT.Milliseconds() { + return false + } + if !slices.Contains(candidates, v.OutboundTag) { + return false + } + if v.HealthPing != nil && v.HealthPing.All > 0 && s.settings.Tolerance > 0 && float64(v.HealthPing.Fail)/float64(v.HealthPing.All) > float64(s.settings.Tolerance) { + return false + } + return true +} + func leastloadSort(nodes []*node) { sort.Slice(nodes, func(i, j int) bool { left := nodes[i]