From e1930e6a18f61d32ecc8e21750cd5c02b4734a67 Mon Sep 17 00:00:00 2001 From: zakuwaki <79925675+zakuwaki@users.noreply.github.com> Date: Tue, 4 Jul 2023 10:42:21 +0800 Subject: [PATCH] feat: add bandwidth limiter --- adapter/router.go | 1 + box.go | 4 ++ limiter/builder.go | 103 +++++++++++++++++++++++++++++++++++++++++ limiter/limiter.go | 77 ++++++++++++++++++++++++++++++ limiter/manager.go | 11 +++++ option/config.go | 1 + option/limiter.go | 9 ++++ option/rule.go | 10 ++-- route/router.go | 18 +++++++ route/rule_abstract.go | 10 ++++ route/rule_default.go | 6 +++ 11 files changed, 246 insertions(+), 4 deletions(-) create mode 100644 limiter/builder.go create mode 100644 limiter/limiter.go create mode 100644 limiter/manager.go create mode 100644 option/limiter.go diff --git a/adapter/router.go b/adapter/router.go index 3cf9e6d4b1..f146d4a587 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -76,6 +76,7 @@ type Rule interface { Match(metadata *InboundContext) bool Outbound() string String() string + Limiters() []string } type DNSRule interface { diff --git a/box.go b/box.go index 3ceb7a55d0..a48f514b65 100644 --- a/box.go +++ b/box.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing-box/experimental" "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/inbound" + "github.com/sagernet/sing-box/limiter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/outbound" @@ -72,6 +73,9 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "create log factory") } + if len(options.Limiters) > 0 { + ctx = limiter.WithDefault(ctx, logFactory.NewLogger("limiter"), options.Limiters) + } router, err := route.NewRouter( ctx, logFactory, diff --git a/limiter/builder.go b/limiter/builder.go new file mode 100644 index 0000000000..b042fa3e88 --- /dev/null +++ b/limiter/builder.go @@ -0,0 +1,103 @@ +package limiter + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/dustin/go-humanize" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/service" +) + +const ( + limiterDefault = "default" + limiterUser = "user" + limiterInbound = "inbound" +) + +var _ Manager = (*defaultManager)(nil) + +type defaultManager struct { + mp *sync.Map +} + +func WithDefault(ctx context.Context, logger log.ContextLogger, options []option.Limiter) context.Context { + m := &defaultManager{mp: &sync.Map{}} + for i, option := range options { + if err := m.createLimiter(ctx, option); err != nil { + logger.ErrorContext(ctx, fmt.Sprintf("id=%d, %s", i, err)) + } else { + logger.InfoContext(ctx, fmt.Sprintf("id=%d, tag=%s, users=%v, inbounds=%v, download=%s, upload=%s", + i, option.Tag, option.AuthUser, option.Inbound, option.Download, option.Upload)) + } + } + return service.ContextWith[Manager](ctx, m) +} + +func buildKey(prefix string, tag string) string { + return fmt.Sprintf("%s-%s", prefix, tag) +} + +func (m *defaultManager) createLimiter(ctx context.Context, option option.Limiter) (err error) { + var download, upload uint64 + if len(option.Download) > 0 { + download, err = humanize.ParseBytes(option.Download) + if err != nil { + return err + } + } + if len(option.Upload) > 0 { + upload, err = humanize.ParseBytes(option.Upload) + } + if download == 0 && upload == 0 { + return E.New("bandwith must be set") + } + l := newLimiter(download, upload) + valid := false + if len(option.Tag) > 0 { + valid = true + m.mp.Store(buildKey(limiterDefault, option.Tag), newLimiter(download, upload)) + } + if len(option.AuthUser) > 0 { + valid = true + for _, user := range option.AuthUser { + m.mp.Store(buildKey(limiterUser, user), l) + } + } + if len(option.Inbound) > 0 { + valid = true + for _, inbound := range option.Inbound { + m.mp.Store(buildKey(limiterInbound, inbound), l) + } + } + if !valid { + return E.New("tag or constraint must be set") + } + return +} + +func (m *defaultManager) LoadLimiters(tags []string, user, inbound string) (limiters []*limiter) { + for _, t := range tags { + if v, ok := m.mp.Load(buildKey(limiterDefault, t)); ok { + limiters = append(limiters, v.(*limiter)) + } + } + if v, ok := m.mp.Load(buildKey(limiterUser, user)); ok { + limiters = append(limiters, v.(*limiter)) + } + if v, ok := m.mp.Load(buildKey(limiterInbound, inbound)); ok { + limiters = append(limiters, v.(*limiter)) + } + return +} + +func (m *defaultManager) NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn { + for _, limiter := range limiters { + conn = &connWithLimiter{Conn: conn, limiter: limiter, ctx: ctx} + } + return conn +} diff --git a/limiter/limiter.go b/limiter/limiter.go new file mode 100644 index 0000000000..8679c558ae --- /dev/null +++ b/limiter/limiter.go @@ -0,0 +1,77 @@ +package limiter + +import ( + "context" + "net" + + "golang.org/x/time/rate" +) + +type limiter struct { + downloadLimiter *rate.Limiter + uploadLimiter *rate.Limiter +} + +func newLimiter(download, upload uint64) *limiter { + var downloadLimiter, uploadLimiter *rate.Limiter + if download > 0 { + downloadLimiter = rate.NewLimiter(rate.Limit(float64(download)), int(download)) + } + if upload > 0 { + uploadLimiter = rate.NewLimiter(rate.Limit(float64(upload)), int(upload)) + } + return &limiter{downloadLimiter: downloadLimiter, uploadLimiter: uploadLimiter} +} + +type connWithLimiter struct { + net.Conn + limiter *limiter + ctx context.Context +} + +func (conn *connWithLimiter) Read(p []byte) (n int, err error) { + if conn.limiter == nil || conn.limiter.downloadLimiter == nil { + return conn.Conn.Read(p) + } + b := conn.limiter.downloadLimiter.Burst() + if b < len(p) { + p = p[:b] + } + n, err = conn.Conn.Read(p) + if err != nil { + return + } + err = conn.limiter.downloadLimiter.WaitN(conn.ctx, n) + if err != nil { + return + } + return +} + +func (conn *connWithLimiter) Write(p []byte) (n int, err error) { + if conn.limiter == nil || conn.limiter.uploadLimiter == nil { + return conn.Conn.Write(p) + } + var nn int + b := conn.limiter.uploadLimiter.Burst() + for { + end := len(p) + if end == 0 { + break + } + if b < len(p) { + end = b + } + err = conn.limiter.uploadLimiter.WaitN(conn.ctx, end) + if err != nil { + return + } + nn, err = conn.Conn.Write(p[:end]) + n += nn + if err != nil { + return + } + p = p[end:] + } + return +} diff --git a/limiter/manager.go b/limiter/manager.go new file mode 100644 index 0000000000..4521393074 --- /dev/null +++ b/limiter/manager.go @@ -0,0 +1,11 @@ +package limiter + +import ( + "context" + "net" +) + +type Manager interface { + LoadLimiters(tags []string, user, inbound string) []*limiter + NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn +} diff --git a/option/config.go b/option/config.go index ec471112e7..ddb9ddec10 100644 --- a/option/config.go +++ b/option/config.go @@ -16,6 +16,7 @@ type _Options struct { Inbounds []Inbound `json:"inbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"` Route *RouteOptions `json:"route,omitempty"` + Limiters []Limiter `json:"limiters,omitempty"` Experimental *ExperimentalOptions `json:"experimental,omitempty"` } diff --git a/option/limiter.go b/option/limiter.go new file mode 100644 index 0000000000..99f1dc629b --- /dev/null +++ b/option/limiter.go @@ -0,0 +1,9 @@ +package option + +type Limiter struct { + Tag string `json:"tag"` + Download string `json:"download,omitempty"` + Upload string `json:"upload,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` + Inbound Listable[string] `json:"inbound,omitempty"` +} diff --git a/option/rule.go b/option/rule.go index f78a752d91..c813eea92e 100644 --- a/option/rule.go +++ b/option/rule.go @@ -80,6 +80,7 @@ type DefaultRule struct { ClashMode string `json:"clash_mode,omitempty"` Invert bool `json:"invert,omitempty"` Outbound string `json:"outbound,omitempty"` + Limiter Listable[string] `json:"limiter,omitempty"` } func (r DefaultRule) IsValid() bool { @@ -90,10 +91,11 @@ func (r DefaultRule) IsValid() bool { } type LogicalRule struct { - Mode string `json:"mode"` - Rules []DefaultRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` + Mode string `json:"mode"` + Rules []DefaultRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Outbound string `json:"outbound,omitempty"` + Limiter Listable[string] `json:"limiter,omitempty"` } func (r LogicalRule) IsValid() bool { diff --git a/route/router.go b/route/router.go index c23c8a458e..1e978aa10d 100644 --- a/route/router.go +++ b/route/router.go @@ -20,6 +20,7 @@ import ( "github.com/sagernet/sing-box/common/sniff" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/limiter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/ntp" "github.com/sagernet/sing-box/option" @@ -38,6 +39,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/uot" + "github.com/sagernet/sing/service" ) var _ adapter.Router = (*Router)(nil) @@ -80,6 +82,7 @@ type Router struct { timeService adapter.TimeService clashServer adapter.ClashServer v2rayServer adapter.V2RayServer + limiterManager limiter.Manager platformInterface platform.Interface } @@ -487,6 +490,9 @@ func (r *Router) Start() error { return E.Cause(err, "initialize time service") } } + if limiterManger := service.FromContext[limiter.Manager](r.ctx); limiterManger != nil { + r.limiterManager = limiterManger + } return nil } @@ -688,6 +694,18 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad if !common.Contains(detour.Network(), N.NetworkTCP) { return E.New("missing supported outbound, closing connection") } + + if r.limiterManager != nil { + var limiterTags []string + if matchedRule != nil { + limiterTags = matchedRule.Limiters() + } + limiters := r.limiterManager.LoadLimiters(limiterTags, metadata.User, metadata.Inbound) + if len(limiters) > 0 { + conn = r.limiterManager.NewConnWithLimiters(ctx, conn, limiters) + } + } + if r.clashServer != nil { trackerConn, tracker := r.clashServer.RoutedConnection(ctx, conn, metadata, matchedRule) defer tracker.Leave() diff --git a/route/rule_abstract.go b/route/rule_abstract.go index 38d4d57d41..bbbbe99f02 100644 --- a/route/rule_abstract.go +++ b/route/rule_abstract.go @@ -18,6 +18,7 @@ type abstractDefaultRule struct { allItems []RuleItem invert bool outbound string + limiters []string } func (r *abstractDefaultRule) Type() string { @@ -126,6 +127,10 @@ func (r *abstractDefaultRule) Outbound() string { return r.outbound } +func (r *abstractDefaultRule) Limiters() []string { + return r.limiters +} + func (r *abstractDefaultRule) String() string { if !r.invert { return strings.Join(F.MapToString(r.allItems), " ") @@ -139,6 +144,7 @@ type abstractLogicalRule struct { mode string invert bool outbound string + limiters []string } func (r *abstractLogicalRule) Type() string { @@ -191,6 +197,10 @@ func (r *abstractLogicalRule) Outbound() string { return r.outbound } +func (r *abstractLogicalRule) Limiters() []string { + return r.limiters +} + func (r *abstractLogicalRule) String() string { var op string switch r.mode { diff --git a/route/rule_default.go b/route/rule_default.go index 01322c13aa..780fd8cc7c 100644 --- a/route/rule_default.go +++ b/route/rule_default.go @@ -184,6 +184,9 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + if len(options.Limiter) > 0 { + rule.limiters = append(rule.limiters, options.Limiter...) + } return rule, nil } @@ -216,5 +219,8 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt } r.rules[i] = rule } + if len(options.Limiter) > 0 { + r.limiters = append(r.limiters, options.Limiter...) + } return r, nil }