Skip to content

Commit

Permalink
Add configurable trusted slurp domains (#422)
Browse files Browse the repository at this point in the history
  • Loading branch information
whyrusleeping authored Nov 20, 2023
2 parents 76185bf + 9f01575 commit 9e0e90f
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 18 deletions.
13 changes: 8 additions & 5 deletions bgs/bgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,23 +596,26 @@ func (bgs *BGS) EventsHandler(c echo.Context) error {
consumerID := bgs.registerConsumer(&consumer)
defer bgs.cleanupConsumer(consumerID)

log.Infow("new consumer",
logger := log.With(
"consumer_id", consumerID,
"remote_addr", consumer.RemoteAddr,
"user_agent", consumer.UserAgent,
"cursor", since,
"consumer_id", consumerID,
)

logger.Infow("new consumer", "cursor", since)

header := events.EventHeader{Op: events.EvtKindMessage}
for {
select {
case evt, ok := <-evts:
if !ok {
logger.Error("event stream closed unexpectedly")
return nil
}

wc, err := conn.NextWriter(websocket.BinaryMessage)
if err != nil {
log.Errorf("failed to get next writer: %s", err)
logger.Errorf("failed to get next writer: %s", err)
return err
}

Expand Down Expand Up @@ -650,7 +653,7 @@ func (bgs *BGS) EventsHandler(c echo.Context) error {
}

if err := wc.Close(); err != nil {
log.Warnf("failed to flush-close our event write: %s", err)
logger.Warnf("failed to flush-close our event write: %s", err)
return nil
}

Expand Down
80 changes: 79 additions & 1 deletion bgs/fedmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"math/rand"
"strings"
"sync"
"time"

Expand All @@ -16,6 +17,7 @@ import (
"golang.org/x/time/rate"

"github.com/gorilla/websocket"
pq "github.com/lib/pq"
"gorm.io/gorm"
)

Expand All @@ -36,6 +38,7 @@ type Slurper struct {
DefaultCrawlLimit rate.Limit

newSubsDisabled bool
trustedDomains []string

shutdownChan chan bool
shutdownResult chan []error
Expand Down Expand Up @@ -171,6 +174,7 @@ func (s *Slurper) loadConfig() error {
}

s.newSubsDisabled = sc.NewSubsDisabled
s.trustedDomains = sc.TrustedDomains

return nil
}
Expand All @@ -179,6 +183,7 @@ type SlurpConfig struct {
gorm.Model

NewSubsDisabled bool
TrustedDomains pq.StringArray `gorm:"type:text[]"`
}

func (s *Slurper) SetNewSubsDisabled(dis bool) error {
Expand All @@ -199,13 +204,86 @@ func (s *Slurper) GetNewSubsDisabledState() bool {
return s.newSubsDisabled
}

// AddTrustedDomain appends domain to the persisted trusted-domain list and,
// once the database update succeeds, to the in-memory copy held by the Slurper.
//
// The update targets the SlurpConfig row with id = 1 and uses the SQL
// array_append function so the append happens inside the database. If the
// update fails, the error is returned and the in-memory list is left untouched.
func (s *Slurper) AddTrustedDomain(domain string) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	appendExpr := gorm.Expr("array_append(trusted_domains, ?)", domain)
	res := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", appendExpr)
	if res.Error != nil {
		return res.Error
	}

	// Only mirror the change locally after the database accepted it.
	s.trustedDomains = append(s.trustedDomains, domain)
	return nil
}

// RemoveTrustedDomain removes domain from the persisted trusted-domain list
// and from the in-memory copy held by the Slurper.
//
// The update targets the SlurpConfig row with id = 1 and uses the SQL
// array_remove function. A gorm.ErrRecordNotFound from the update is treated
// as success (nothing to remove); any other error is returned and the
// in-memory list is left untouched.
func (s *Slurper) RemoveTrustedDomain(domain string) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", gorm.Expr("array_remove(trusted_domains, ?)", domain)).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil
		}
		return err
	}

	// array_remove drops every element equal to domain, so drop every match
	// here as well; removing only the first one would leave the in-memory
	// list out of sync with the database when the domain was added twice.
	// Filtering into a fresh slice (instead of splicing in place) also avoids
	// rewriting a backing array that earlier readers may still reference.
	kept := make([]string, 0, len(s.trustedDomains))
	for _, d := range s.trustedDomains {
		if d != domain {
			kept = append(kept, d)
		}
	}
	if len(kept) != len(s.trustedDomains) {
		s.trustedDomains = kept
	}

	return nil
}

// SetTrustedDomains replaces the whole trusted-domain list, first in the
// SlurpConfig row with id = 1 and then, if that succeeds, in memory.
//
// The Slurper keeps its own copy of domains, so the caller remains free to
// reuse or modify the slice it passed in.
func (s *Slurper) SetTrustedDomains(domains []string) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	// NOTE(review): domains is handed to gorm as a plain []string. Confirm the
	// driver encodes it as a text[] value for this column; wrapping it in
	// pq.StringArray may be required.
	if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", domains).Error; err != nil {
		return err
	}

	// Copy at the API boundary: storing the caller's slice would let later
	// caller-side writes change lock-protected state without holding the lock.
	var owned []string
	if domains != nil {
		owned = make([]string, len(domains))
		copy(owned, domains)
	}
	s.trustedDomains = owned
	return nil
}

// GetTrustedDomains returns a snapshot of the current trusted-domain list.
//
// The result is a copy taken while holding the Slurper lock, so callers can
// read or modify it after the lock is released without touching the Slurper's
// own state. A nil internal list is returned as nil.
func (s *Slurper) GetTrustedDomains() []string {
	s.lk.Lock()
	defer s.lk.Unlock()

	if s.trustedDomains == nil {
		return nil
	}

	snapshot := make([]string, len(s.trustedDomains))
	copy(snapshot, s.trustedDomains)
	return snapshot
}

// ErrNewSubsDisabled is returned when a subscription request is rejected
// because new subscriptions are currently disabled for the requested host.
var ErrNewSubsDisabled = errors.New("new subscriptions temporarily disabled")

// canSlurpHost reports whether host may be subscribed to.
//
// A host listed in the trusted domains is always allowed. An entry of the
// form "*.domain.com" is a wildcard that matches any host ending in
// ".domain.com"; every other entry must equal host exactly. Hosts that match
// no entry are allowed only while new subscriptions are enabled.
//
// Must be called with the slurper lock held.
func (s *Slurper) canSlurpHost(host string) bool {
	for _, trusted := range s.trustedDomains {
		if !strings.HasPrefix(trusted, "*.") {
			// Plain entry: exact match only.
			if trusted == host {
				return true
			}
			continue
		}

		// Wildcard entry: drop the leading "*" so ".domain.com" is matched
		// as a suffix of the host.
		if suffix := trusted[1:]; strings.HasSuffix(host, suffix) {
			return true
		}
	}

	return !s.newSubsDisabled
}

func (s *Slurper) SubscribeToPds(ctx context.Context, host string, reg bool) error {
// TODO: for performance, lock on the hostname instead of global
s.lk.Lock()
defer s.lk.Unlock()
if s.newSubsDisabled {

if !s.canSlurpHost(host) {
return ErrNewSubsDisabled
}

Expand Down
45 changes: 33 additions & 12 deletions bgs/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"

Expand All @@ -17,7 +18,6 @@ import (
"github.com/bluesky-social/indigo/mst"
"gorm.io/gorm"

"github.com/bluesky-social/indigo/util"
"github.com/bluesky-social/indigo/xrpc"
"github.com/ipfs/go-cid"
cbor "github.com/ipfs/go-ipld-cbor"
Expand Down Expand Up @@ -108,40 +108,61 @@ func (s *BGS) handleComAtprotoSyncRequestCrawl(ctx context.Context, body *comatp
return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname")
}

if strings.HasPrefix(host, "https://") || strings.HasPrefix(host, "http://") {
return echo.NewHTTPError(http.StatusBadRequest, "must pass domain without protocol scheme")
if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") {
if s.ssl {
host = "https://" + host
} else {
host = "http://" + host
}
}

norm, err := util.NormalizeHostname(host)
u, err := url.Parse(host)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "failed to normalize hostname")
return echo.NewHTTPError(http.StatusBadRequest, "failed to parse hostname")
}

if u.Scheme == "http" && s.ssl {
return echo.NewHTTPError(http.StatusBadRequest, "this server requires https")
}

if u.Scheme == "https" && !s.ssl {
return echo.NewHTTPError(http.StatusBadRequest, "this server does not support https")
}

if u.Path != "" {
return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without path")
}

if u.Query().Encode() != "" {
return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without query")
}

host = u.Host // potentially hostname:port

banned, err := s.domainIsBanned(ctx, host)
if banned {
return echo.NewHTTPError(http.StatusUnauthorized, "domain is banned")
}

log.Warnf("TODO: better host validation for crawl requests")

clientHost := fmt.Sprintf("%s://%s", u.Scheme, host)

c := &xrpc.Client{
Host: "https://" + host,
Host: clientHost,
Client: http.DefaultClient, // not using the client that auto-retries
}

if !s.ssl {
c.Host = "http://" + host
}

desc, err := atproto.ServerDescribeServer(ctx, c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "requested host failed to respond to describe request")
errMsg := fmt.Sprintf("requested host (%s) failed to respond to describe request", clientHost)
return echo.NewHTTPError(http.StatusBadRequest, errMsg)
}

// Maybe we could do something with this response later
_ = desc

return s.slurper.SubscribeToPds(ctx, norm, true)
return s.slurper.SubscribeToPds(ctx, host, true)
}

func (s *BGS) handleComAtprotoSyncNotifyOfUpdate(ctx context.Context, body *comatprototypes.SyncNotifyOfUpdate_Input) error {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ require (
github.com/labstack/echo/v4 v4.11.1
github.com/labstack/gommon v0.4.0
github.com/lestrrat-go/jwx/v2 v2.0.12
github.com/lib/pq v1.10.9
github.com/minio/sha256-simd v1.0.0
github.com/mitchellh/go-homedir v1.1.0
github.com/mr-tron/base58 v1.2.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ github.com/lestrrat-go/jwx/v2 v2.0.12/go.mod h1:Mq4KN1mM7bp+5z/W5HS8aCNs5RKZ911G
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM=
github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8=
github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg=
Expand Down

0 comments on commit 9e0e90f

Please sign in to comment.