Skip to content

Commit

Permalink
feat: limiter in service
Browse files Browse the repository at this point in the history
  • Loading branch information
zakuwaki committed Jun 29, 2023
1 parent 62af47a commit 0bd5ffd
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 41 deletions.
7 changes: 3 additions & 4 deletions box.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, E.Cause(err, "create log factory")
}
if len(options.Limiters) > 0 {
ctx = limiter.WithDefault(ctx, logFactory.NewLogger("limiter"), options.Limiters)
}
router, err := route.NewRouter(
ctx,
logFactory,
Expand Down Expand Up @@ -135,10 +138,6 @@ func New(options Options) (*Box, error) {
if err != nil {
return nil, err
}
err = limiter.New(ctx, logFactory.NewLogger("limiter"), options.Limiters)
if err != nil {
return nil, err
}
if options.PlatformInterface != nil {
err = options.PlatformInterface.Initialize(ctx, router)
if err != nil {
Expand Down
60 changes: 37 additions & 23 deletions limiter/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package limiter
import (
"context"
"fmt"
"net"
"sync"

"github.com/dustin/go-humanize"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/service"
)

const (
Expand All @@ -17,19 +19,30 @@ const (
limiterInbound = "inbound"
)

var m sync.Map
var _ Manager = (*defaultManager)(nil)

func New(ctx context.Context, logger log.ContextLogger, options []option.Limiter) (err error) {
for _, option := range options {
err = new(ctx, logger, option)
if err != nil {
return
type defaultManager struct {
mp *sync.Map
}

func WithDefault(ctx context.Context, logger log.ContextLogger, options []option.Limiter) context.Context {
m := &defaultManager{mp: &sync.Map{}}
for i, option := range options {
if err := m.createLimiter(ctx, option); err != nil {
logger.ErrorContext(ctx, fmt.Sprintf("id=%d, %s", i, err))
} else {
logger.InfoContext(ctx, fmt.Sprintf("id=%d, tag=%s, users=%v, inbounds=%v, download=%s, upload=%s",
i, option.Tag, option.AuthUser, option.Inbound, option.Download, option.Upload))
}
}
return
return service.ContextWith[Manager](ctx, m)
}

func buildKey(prefix string, tag string) string {
return fmt.Sprintf("%s-%s", prefix, tag)
}

func new(ctx context.Context, logger log.ContextLogger, option option.Limiter) (err error) {
func (m *defaultManager) createLimiter(ctx context.Context, option option.Limiter) (err error) {
var download, upload uint64
if len(option.Download) > 0 {
download, err = humanize.ParseBytes(option.Download)
Expand All @@ -41,49 +54,50 @@ func new(ctx context.Context, logger log.ContextLogger, option option.Limiter) (
upload, err = humanize.ParseBytes(option.Upload)
}
if download == 0 && upload == 0 {
return E.New("limiter bandwith must be set")
return E.New("bandwith must be set")
}
l := newLimiter(download, upload)
valid := false
if len(option.Tag) > 0 {
valid = true
m.Store(buildKey(limiterDefault, option.Tag), newLimiter(download, upload))
m.mp.Store(buildKey(limiterDefault, option.Tag), newLimiter(download, upload))
}
if len(option.AuthUser) > 0 {
valid = true
for _, user := range option.AuthUser {
m.Store(buildKey(limiterUser, user), l)
m.mp.Store(buildKey(limiterUser, user), l)
}
}
if len(option.Inbound) > 0 {
valid = true
for _, inbound := range option.Inbound {
m.Store(buildKey(limiterInbound, inbound), l)
m.mp.Store(buildKey(limiterInbound, inbound), l)
}
}
if !valid {
return E.New("limiter tag or constraint must be set")
return E.New("tag or constraint must be set")
}
logger.InfoContext(ctx, fmt.Sprintf("limiter created, download:%s, upload:%s, tag:%s, users:%v, inbounds:%v",
option.Download, option.Upload, option.Tag, option.AuthUser, option.Inbound))
return
}

func buildKey(prefix string, tag string) string {
return fmt.Sprintf("%s-%s", prefix, tag)
}

func LoadLimiters(tags []string, user, inbound string) (limiters []*limiter) {
func (m *defaultManager) LoadLimiters(tags []string, user, inbound string) (limiters []*limiter) {
for _, t := range tags {
if v, ok := m.Load(buildKey(limiterDefault, t)); ok {
if v, ok := m.mp.Load(buildKey(limiterDefault, t)); ok {
limiters = append(limiters, v.(*limiter))
}
}
if v, ok := m.Load(buildKey(limiterUser, user)); ok {
if v, ok := m.mp.Load(buildKey(limiterUser, user)); ok {
limiters = append(limiters, v.(*limiter))
}
if v, ok := m.Load(buildKey(limiterInbound, inbound)); ok {
if v, ok := m.mp.Load(buildKey(limiterInbound, inbound)); ok {
limiters = append(limiters, v.(*limiter))
}
return
}

func (m *defaultManager) NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn {
for _, limiter := range limiters {
conn = &connWithLimiter{Conn: conn, limiter: limiter, ctx: ctx}
}
return conn
}
7 changes: 0 additions & 7 deletions limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ type connWithLimiter struct {
ctx context.Context
}

func NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn {
for _, limiter := range limiters {
conn = &connWithLimiter{Conn: conn, limiter: limiter, ctx: ctx}
}
return conn
}

func (conn *connWithLimiter) Read(p []byte) (n int, err error) {
if conn.limiter.downloadLimiter == nil {
return conn.Conn.Read(p)
Expand Down
11 changes: 11 additions & 0 deletions limiter/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package limiter

import (
"context"
"net"
)

type Manager interface {
LoadLimiters(tags []string, user, inbound string) []*limiter
NewConnWithLimiters(ctx context.Context, conn net.Conn, limiters []*limiter) net.Conn
}
21 changes: 14 additions & 7 deletions route/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)

var _ adapter.Router = (*Router)(nil)
Expand Down Expand Up @@ -81,6 +82,7 @@ type Router struct {
timeService adapter.TimeService
clashServer adapter.ClashServer
v2rayServer adapter.V2RayServer
limiterManager limiter.Manager
platformInterface platform.Interface
}

Expand Down Expand Up @@ -488,6 +490,9 @@ func (r *Router) Start() error {
return E.Cause(err, "initialize time service")
}
}
if limiterManger := service.FromContext[limiter.Manager](r.ctx); limiterManger != nil {
r.limiterManager = limiterManger
}
return nil
}

Expand Down Expand Up @@ -690,13 +695,15 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
return E.New("missing supported outbound, closing connection")
}

var limiterTags []string
if matchedRule != nil {
limiterTags = matchedRule.Limiters()
}
limiters := limiter.LoadLimiters(limiterTags, metadata.User, metadata.Inbound)
if len(limiters) > 0 {
conn = limiter.NewConnWithLimiters(ctx, conn, limiters)
if r.limiterManager != nil {
var limiterTags []string
if matchedRule != nil {
limiterTags = matchedRule.Limiters()
}
limiters := r.limiterManager.LoadLimiters(limiterTags, metadata.User, metadata.Inbound)
if len(limiters) > 0 {
conn = r.limiterManager.NewConnWithLimiters(ctx, conn, limiters)
}
}

if r.clashServer != nil {
Expand Down

0 comments on commit 0bd5ffd

Please sign in to comment.