Skip to content

Commit

Permalink
feat: add NTP configuration and improve middleware handling
Browse files Browse the repository at this point in the history
  • Loading branch information
divyam234 committed Jan 1, 2025
1 parent 055cc01 commit a296b36
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 332 deletions.
1 change: 1 addition & 0 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func NewRun() *cobra.Command {
runCmd.Flags().StringVar(&config.TG.LangPack, "tg-lang-pack", "webk", "Language pack")
runCmd.Flags().StringVar(&config.TG.Proxy, "tg-proxy", "", "HTTP OR SOCKS5 proxy URL")
runCmd.Flags().BoolVar(&config.TG.DisableStreamBots, "tg-disable-stream-bots", false, "Disable Stream bots")
runCmd.Flags().BoolVar(&config.TG.Ntp, "tg-ntp", false, "Use NTP server time")
runCmd.Flags().BoolVar(&config.TG.EnableLogging, "tg-enable-logging", false, "Enable telegram client logging")
runCmd.Flags().StringVar(&config.TG.Uploads.EncryptionKey, "tg-uploads-encryption-key", "", "Uploads encryption key")
runCmd.Flags().IntVar(&config.TG.Uploads.Threads, "tg-uploads-threads", 8, "Uploads threads")
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type TGConfig struct {
LangCode string
SystemLangCode string
LangPack string
Ntp bool
SessionFile string
DisableStreamBots bool
BgBotsCheckInterval time.Duration
Expand Down
3 changes: 2 additions & 1 deletion internal/tgc/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ func GetMediaContent(ctx context.Context, client *tg.Client, location tg.InputFi

func GetBotInfo(ctx context.Context, KV kv.KV, config *config.TGConfig, token string) (*types.BotInfo, error) {
var user *tg.User
client, _ := BotClient(ctx, KV, config, token, Middlewares(config, 5)...)
middlewares := NewMiddleware(config, WithFloodWait(), WithRateLimit())
client, _ := BotClient(ctx, KV, config, token, middlewares...)
err := RunWithAuth(ctx, client, token, func(ctx context.Context) error {
user, _ = client.Self(ctx)
return nil
Expand Down
62 changes: 49 additions & 13 deletions internal/tgc/tgc.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa
logger = logging.FromContext(ctx).Named("td")

}
c, err := clock.NewNTP()
if err != nil {
return nil, errors.Wrap(err, "create clock")
}

opts := telegram.Options{
Resolver: dcs.Plain(dcs.PlainOptions{
Expand All @@ -66,7 +62,14 @@ func New(ctx context.Context, config *config.TGConfig, handler telegram.UpdateHa
Middlewares: middlewares,
UpdateHandler: handler,
Logger: logger,
Clock: c,
}
if config.Ntp {
c, err := clock.NewNTP()
if err != nil {
return nil, errors.Wrap(err, "create clock")
}
opts.Clock = c

}

return telegram.NewClient(config.AppId, config.AppHash, opts), nil
Expand Down Expand Up @@ -106,17 +109,50 @@ func BotClient(ctx context.Context, KV kv.KV, config *config.TGConfig, token str

}

func Middlewares(config *config.TGConfig, retries int) []telegram.Middleware {
middlewares := []telegram.Middleware{
floodwait.NewSimpleWaiter(),
recovery.New(context.Background(), newBackoff(config.ReconnectTimeout)),
retry.New(retries),
type middlewareOption func(*middlewareConfig)

type middlewareConfig struct {
config *config.TGConfig
middlewares []telegram.Middleware
}

func NewMiddleware(config *config.TGConfig, opts ...middlewareOption) []telegram.Middleware {
mc := &middlewareConfig{
config: config,
middlewares: []telegram.Middleware{},
}
if config.RateLimit {
middlewares = append(middlewares, ratelimit.New(rate.Every(time.Millisecond*time.Duration(config.Rate)), config.RateBurst))
for _, opt := range opts {
opt(mc)
}
return middlewares
return mc.middlewares
}

func WithFloodWait() middlewareOption {
return func(mc *middlewareConfig) {
mc.middlewares = append(mc.middlewares, floodwait.NewSimpleWaiter())
}
}

func WithRecovery(ctx context.Context) middlewareOption {
return func(mc *middlewareConfig) {
mc.middlewares = append(mc.middlewares,
recovery.New(ctx, newBackoff(mc.config.ReconnectTimeout)))
}
}

func WithRetry(retries int) middlewareOption {
return func(mc *middlewareConfig) {
mc.middlewares = append(mc.middlewares, retry.New(retries))
}
}

func WithRateLimit() middlewareOption {
return func(mc *middlewareConfig) {
if mc.config.RateLimit {
mc.middlewares = append(mc.middlewares,
ratelimit.New(rate.Every(time.Millisecond*time.Duration(mc.config.Rate)), mc.config.RateBurst))
}
}
}

func newBackoff(timeout time.Duration) backoff.BackOff {
Expand Down
146 changes: 0 additions & 146 deletions internal/tgc/workers.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
package tgc

import (
"context"
"strings"
"sync"
"time"

"github.com/gotd/td/telegram"
"github.com/tgdrive/teldrive/internal/config"
"github.com/tgdrive/teldrive/internal/kv"
"go.uber.org/zap"
)

type BotWorker struct {
Expand Down Expand Up @@ -43,141 +35,3 @@ func (w *BotWorker) Next(channelID int64) (string, int) {
w.currIdx[channelID] = (index + 1) % len(bots)
return bots[index], index
}

type ClientStatus int

const (
StatusIdle ClientStatus = iota
StatusBusy
)

type Client struct {
Tg *telegram.Client
Stop StopFunc
Status ClientStatus
UserID string
}

type StreamWorker struct {
mu sync.RWMutex
clients map[string]*Client
currIdx map[int64]int
channelBots map[int64][]string
cnf *config.TGConfig
kv kv.KV
ctx context.Context
logger *zap.SugaredLogger
activeStreams int
cancel context.CancelFunc
}

func NewStreamWorker(cnf *config.Config, kv kv.KV, logger *zap.SugaredLogger) *StreamWorker {
ctx, cancel := context.WithCancel(context.Background())
worker := &StreamWorker{
cnf: &cnf.TG,
kv: kv,
ctx: ctx,
clients: make(map[string]*Client),
currIdx: make(map[int64]int),
channelBots: make(map[int64][]string),
logger: logger,
cancel: cancel,
}
go worker.startIdleClientMonitor()
return worker

}

func (w *StreamWorker) Set(bots []string, channelID int64) {
w.mu.Lock()
defer w.mu.Unlock()
w.channelBots[channelID] = bots
w.currIdx[channelID] = 0
}

func (w *StreamWorker) Next(channelID int64) (*Client, error) {
w.mu.Lock()
defer w.mu.Unlock()

bots := w.channelBots[channelID]
index := w.currIdx[channelID]
token := bots[index]
userID := strings.Split(token, ":")[0]

client, err := w.getOrCreateClient(userID, token)
if err != nil {
return nil, err
}

w.currIdx[channelID] = (index + 1) % len(bots)

return client, nil
}

func (w *StreamWorker) IncActiveStream() error {
w.mu.Lock()
defer w.mu.Unlock()

w.activeStreams++
return nil
}

func (w *StreamWorker) DecActiveStreams() error {
w.mu.Lock()
defer w.mu.Unlock()

if w.activeStreams == 0 {
return nil
}
w.activeStreams--
return nil
}

func (w *StreamWorker) getOrCreateClient(userID, token string) (*Client, error) {
client, ok := w.clients[userID]
if !ok || (client.Status == StatusIdle && client.Stop == nil) {
middlewares := Middlewares(w.cnf, 5)
tgClient, _ := BotClient(w.ctx, w.kv, w.cnf, token, middlewares...)
client = &Client{Tg: tgClient, Status: StatusIdle, UserID: userID}
w.clients[userID] = client
stop, err := Connect(client.Tg, WithBotToken(token))
if err != nil {
return nil, err
}
client.Stop = stop
client.Status = StatusBusy
w.logger.Debug("started bg client: ", userID)
}
return client, nil
}

func (w *StreamWorker) startIdleClientMonitor() {
ticker := time.NewTicker(w.cnf.BgBotsCheckInterval)
defer ticker.Stop()

for {
select {
case <-ticker.C:
w.checkIdleClients()
case <-w.ctx.Done():
return
}
}
}

func (w *StreamWorker) checkIdleClients() {
w.mu.Lock()
defer w.mu.Unlock()
if w.activeStreams == 0 {
for _, client := range w.clients {
if client.Status == StatusBusy && client.Stop != nil {
client.Stop()
client.Stop = nil
client.Tg = nil
client.Status = StatusIdle
w.logger.Debug("stopped bg client: ", client.UserID)
}
}
}

}
15 changes: 9 additions & 6 deletions pkg/services/api_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

"github.com/go-faster/errors"
"github.com/gotd/td/telegram"
"github.com/ogen-go/ogen/ogenerrors"

ht "github.com/ogen-go/ogen/http"
Expand All @@ -17,11 +18,12 @@ import (
)

type apiService struct {
db *gorm.DB
cnf *config.Config
cache cache.Cacher
kv kv.KV
worker *tgc.BotWorker
db *gorm.DB
cnf *config.Config
cache cache.Cacher
kv kv.KV
worker *tgc.BotWorker
middlewares []telegram.Middleware
}

func (a *apiService) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
Expand Down Expand Up @@ -50,7 +52,8 @@ func NewApiService(db *gorm.DB,
cache cache.Cacher,
kv kv.KV,
worker *tgc.BotWorker) *apiService {
return &apiService{db: db, cnf: cnf, cache: cache, kv: kv, worker: worker}
return &apiService{db: db, cnf: cnf, cache: cache, kv: kv, worker: worker,
middlewares: tgc.NewMiddleware(&cnf.TG, tgc.WithFloodWait(), tgc.WithRateLimit())}
}

type extendedService struct {
Expand Down
Loading

0 comments on commit a296b36

Please sign in to comment.