diff --git a/broker/broker.go b/broker/broker.go index 928af664..d4268ee3 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -26,8 +26,32 @@ type Cacheable interface { ToCacheEntry() ([]byte, error) } +type PresenceInfo struct { + // Total number of present clients (uniq) + Total int + // Presence records + Records []interface{} +} + +// We can extend the presence read functionality in the future +// (e.g., add pagination, filtering, etc.) +type PresenceInfoOptions struct { + ReturnRecords bool +} + +type PresenceInfoOption func(*PresenceInfoOptions) + +func WithPresenceInfoOptions(opts *PresenceInfoOptions) PresenceInfoOption { + return func(o *PresenceInfoOptions) { + if opts != nil { + *o = *opts + } + } +} + // Broker is responsible for: // - Managing streams history. +// - Managing presence information. // - Keeping client states for recovery. // - Distributing broadcasts across nodes. // @@ -56,6 +80,17 @@ type Broker interface { RestoreSession(from string) ([]byte, error) // Marks session as finished (for cache expiration) FinishSession(sid string) error + + // Adds a new presence record for the stream. Returns true if that's the first + // presence record for the presence ID (pid, a unique user presence identifier). + PresenceAdd(stream string, sid string, pid string, info interface{}) error + + // Removes a presence record for the stream. Returns true if that was the last + // record for the presence ID (pid). + PresenceRemove(stream string, sid string, pid string) error + + // Retrieves presence information for the stream (counts, records, etc. depending on the options) + PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) } // LocalBroker is a single-node broker that can used to store streams data locally @@ -193,3 +228,15 @@ func (LegacyBroker) RestoreSession(from string) ([]byte, error) { func (LegacyBroker) FinishSession(sid string) error { return nil } + +func (LegacyBroker) PresenceAdd(stream string, sid string, pid string, info interface{}) error { + return errors.New("presence not supported") +} + +func (LegacyBroker) PresenceRemove(stream string, sid string, pid string) error { + return errors.New("presence not supported") +} + +func (LegacyBroker) PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) { + return nil, errors.New("presence not supported") +} diff --git a/broker/config.go b/broker/config.go index e058147d..d0780ea6 100644 --- a/broker/config.go +++ b/broker/config.go @@ -14,6 +14,8 @@ type Config struct { HistoryLimit int `toml:"history_limit"` // Sessions cache TTL in seconds (after disconnect) SessionsTTL int64 `toml:"sessions_ttl"` + // Presence expire TTL in seconds (after disconnect) + PresenceTTL int64 `toml:"presence_ttl"` } func NewConfig() Config { @@ -24,6 +26,8 @@ func NewConfig() Config { HistoryLimit: 100, // 5 minutes by default SessionsTTL: 5 * 60, + // 15 seconds by default + PresenceTTL: 15, } } @@ -46,6 +50,9 @@ func (c Config) ToToml() string { result.WriteString("# For how long to store sessions state for resumeability (seconds)\n") result.WriteString(fmt.Sprintf("sessions_ttl = %d\n", c.SessionsTTL)) + result.WriteString("# For how long to keep presence information after session disconnect (seconds)\n") + result.WriteString(fmt.Sprintf("presence_ttl = %d\n", c.PresenceTTL)) + result.WriteString("\n") return result.String() diff --git a/broker/memory.go b/broker/memory.go index 59f6b105..f0ed9f61 100644 --- a/broker/memory.go +++ b/broker/memory.go @@ -170,6 +170,50 @@ type expireSessionEntry struct { sid string } +type presenceSessionEntry struct { + // stream -> pid + streams map[string]string + deadline int64 +} + +type presenceEntry struct { + info interface{} + sessions []string +} + +func (pe *presenceEntry) remove(sid string) bool { + i := -1 + + for idx, s := range pe.sessions { + if s == sid { + i = idx + break + } + } + + if i == -1 { + return false + } + + pe.sessions = append(pe.sessions[:i], pe.sessions[i+1:]...) + + return len(pe.sessions) == 0 +} + +type presenceState struct { + streams map[string]map[string]*presenceEntry + sessions map[string]*presenceSessionEntry + + mu sync.RWMutex +} + +func newPresenceState() *presenceState { + return &presenceState{ + streams: make(map[string]map[string]*presenceEntry), + sessions: make(map[string]*presenceSessionEntry), + } +} + type Memory struct { broadcaster Broadcaster config *Config @@ -179,6 +223,8 @@ type Memory struct { sessions map[string]*sessionEntry expireSessions []*expireSessionEntry + presence *presenceState + streamsMu sync.RWMutex sessionsMu sync.RWMutex epochMu sync.RWMutex @@ -195,6 +241,7 @@ func NewMemoryBroker(node Broadcaster, config *Config) *Memory { tracker: NewStreamsTracker(), streams: make(map[string]*memstream), sessions: make(map[string]*sessionEntry), + presence: newPresenceState(), epoch: epoch, } } @@ -376,18 +423,137 @@ func (b *Memory) RestoreSession(from string) ([]byte, error) { func (b *Memory) FinishSession(sid string) error { b.sessionsMu.Lock() - defer b.sessionsMu.Unlock() - if _, ok := b.sessions[sid]; ok { b.expireSessions = append( b.expireSessions, &expireSessionEntry{sid: sid, deadline: time.Now().Unix() + b.config.SessionsTTL}, ) } + b.sessionsMu.Unlock() + + b.presence.mu.Lock() + + if sp, ok := b.presence.sessions[sid]; ok { + sp.deadline = time.Now().Unix() + b.config.PresenceTTL + } + + b.presence.mu.Unlock() return nil } +func (b *Memory) PresenceAdd(stream string, sid string, pid string, info interface{}) error { + b.presence.mu.Lock() + defer b.presence.mu.Unlock() + + if _, ok := b.presence.streams[stream]; !ok { + b.presence.streams[stream] = make(map[string]*presenceEntry) + } + + streamPresence := b.presence.streams[stream] + + if _, ok := streamPresence[pid]; !ok { + streamPresence[pid] = &presenceEntry{ + info: info, + sessions: []string{}, + } + } + + streamSessionPresence := streamPresence[pid] + + newPresence := len(streamSessionPresence.sessions) == 0 + + streamSessionPresence.sessions = append( + streamSessionPresence.sessions, + sid, + ) + + if _, ok := b.presence.sessions[sid]; !ok { + b.presence.sessions[sid] = &presenceSessionEntry{ + streams: make(map[string]string), + } + } + + b.presence.sessions[sid].streams[stream] = pid + + if newPresence { + b.broadcaster.Broadcast(&common.StreamMessage{ + Stream: stream, + Data: common.PresenceJoinMessage(pid, info), + }) + } + + return nil +} + +func (b *Memory) PresenceRemove(stream string, sid string, pid string) error { + b.presence.mu.Lock() + defer b.presence.mu.Unlock() + + if _, ok := b.presence.streams[stream]; !ok { + return nil + } + + streamPresence := b.presence.streams[stream] + + if _, ok := streamPresence[pid]; !ok { + return nil + } + + streamSessionPresence := streamPresence[pid] + + empty := streamSessionPresence.remove(sid) + + if empty { + delete(streamPresence, pid) + } + + if len(streamPresence) == 0 { + delete(b.presence.streams, stream) + } + + if _, ok := b.presence.sessions[sid]; ok { + delete(b.presence.sessions[sid].streams, stream) + } + + if empty { + b.broadcaster.Broadcast(&common.StreamMessage{ + Stream: stream, + Data: common.PresenceLeaveMessage(pid), + }) + } + + return nil +} + +func (b *Memory) PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) { + options := &PresenceInfoOptions{} + for _, opt := range opts { + opt(options) + } + + b.presence.mu.RLock() + defer b.presence.mu.RUnlock() + + if _, ok := b.presence.streams[stream]; !ok { + return &PresenceInfo{Total: 0}, nil + } + + streamPresence := b.presence.streams[stream] + + info := &PresenceInfo{Total: len(streamPresence)} + + if options.ReturnRecords { + info.Records = make([]interface{}, 0, len(streamPresence)) + + for _, entry := range streamPresence { + info.Records = append(info.Records, entry.info) + } + } + + return info, nil +} + func (b *Memory) add(name string, data string) uint64 { b.streamsMu.Lock() @@ -457,4 +623,62 @@ func (b *Memory) expire() { b.expireSessions = b.expireSessions[i:] b.sessionsMu.Unlock() + + // presence expiration + b.expirePresence() +} + +func (b *Memory) expirePresence() { + b.presence.mu.Lock() + + now := time.Now().Unix() + toDelete := []string{} + + for sid, sp := range b.presence.sessions { + if sp.deadline < now { + toDelete = append(toDelete, sid) + } + } + + leaveMessages := []common.StreamMessage{} + + for _, sid := range toDelete { + entry := b.presence.sessions[sid] + + for stream, pid := range entry.streams { + if _, ok := b.presence.streams[stream]; !ok { + continue + } + + if _, ok := b.presence.streams[stream][pid]; !ok { + continue + } + + streamSessionPresence := b.presence.streams[stream][pid] + + empty := streamSessionPresence.remove(sid) + + if empty { + delete(b.presence.streams[stream], pid) + + leaveMessages = append(leaveMessages, common.StreamMessage{ + Stream: stream, + Data: common.PresenceLeaveMessage(pid), + }) + + if len(b.presence.streams[stream]) == 0 { + delete(b.presence.streams, stream) + } + } + } + + delete(b.presence.sessions, sid) + } + + b.presence.mu.Unlock() + + // TODO: batch broadcast? + for _, msg := range leaveMessages { + b.broadcaster.Broadcast(&msg) + } } diff --git a/broker/nats.go b/broker/nats.go index 46153ae5..ce2d09a5 100644 --- a/broker/nats.go +++ b/broker/nats.go @@ -482,6 +482,18 @@ func (n *NATS) Reset() error { return nil } +func (n *NATS) PresenceAdd(stream string, sid string, pid string, info interface{}) error { + return errors.New("presence not supported") +} + +func (n *NATS) PresenceRemove(stream string, sid string, pid string) error { + return errors.New("presence not supported") +} + +func (n *NATS) PresenceInfo(stream string, opts ...PresenceInfoOption) (*PresenceInfo, error) { + return nil, errors.New("presence not supported") +} + func (n *NATS) add(stream string, data string) (uint64, error) { err := n.ensureStreamExists(stream) diff --git a/common/common.go b/common/common.go index 671bbd98..2f63ec63 100644 --- a/common/common.go +++ b/common/common.go @@ -65,6 +65,11 @@ const ( HistoryConfirmedType = "confirm_history" HistoryRejectedType = "reject_history" + PresenceJoinType = "join" + PresenceLeaveType = "leave" + PresenceInfoType = "presence" + PresenceErrorType = "presence_error" + WhisperType = "whisper" ) @@ -79,7 +84,8 @@ const ( // Reserver state fields const ( - WHISPER_STREAM_STATE = "$w" + WHISPER_STREAM_STATE = "$w" + PRESENCE_STREAM_STATE = "$p" ) // SessionEnv represents the underlying HTTP connection data: @@ -217,6 +223,12 @@ func (c *ConnectResult) ToCallResult() *CallResult { return &res } +type PresenceInfo struct { + Type string `json:"type,omitempty"` + Info interface{} `json:"info,omitempty"` + ID string `json:"id"` +} + // CommandResult is a result of performing controller action, // which contains informations about streams to subscribe, // messages to sent and broadcast. @@ -228,6 +240,7 @@ type CommandResult struct { StoppedStreams []string Transmissions []string Broadcasts []*StreamMessage + Presence *PresenceInfo CState map[string]string IState map[string]string DisconnectInterest int @@ -301,6 +314,7 @@ type Message struct { Identifier string `json:"identifier"` Data interface{} `json:"data,omitempty"` History HistoryRequest `json:"history,omitempty"` + Presence *PresenceInfo `json:"presence,omitempty"` } func (m *Message) LogValue() slog.Value { @@ -464,17 +478,18 @@ func NewDisconnectMessage(reason string, reconnect bool) *DisconnectMessage { // Reply represents an outgoing client message type Reply struct { - Type string `json:"type,omitempty"` - Identifier string `json:"identifier,omitempty"` - Message interface{} `json:"message,omitempty"` - Reason string `json:"reason,omitempty"` - Reconnect bool `json:"reconnect,omitempty"` - StreamID string `json:"stream_id,omitempty"` - Epoch string `json:"epoch,omitempty"` - Offset uint64 `json:"offset,omitempty"` - Sid string `json:"sid,omitempty"` - Restored bool `json:"restored,omitempty"` - RestoredIDs []string `json:"restored_ids,omitempty"` + Type string `json:"type,omitempty"` + Identifier string `json:"identifier,omitempty"` + Message interface{} `json:"message,omitempty"` + Presence *PresenceInfo `json:"presence,omitempty"` + Reason string `json:"reason,omitempty"` + Reconnect bool `json:"reconnect,omitempty"` + StreamID string `json:"stream_id,omitempty"` + Epoch string `json:"epoch,omitempty"` + Offset uint64 `json:"offset,omitempty"` + Sid string `json:"sid,omitempty"` + Restored bool `json:"restored,omitempty"` + RestoredIDs []string `json:"restored_ids,omitempty"` } func (r *Reply) LogValue() slog.Value { @@ -561,3 +576,13 @@ func RejectionMessage(identifier string) string { func DisconnectionMessage(reason string, reconnect bool) string { return string(utils.ToJSON(DisconnectMessage{Type: DisconnectType, Reason: reason, Reconnect: reconnect})) } + +// PresenceJoinMessage returns a presence message for the specified event and data +func PresenceJoinMessage(id string, info interface{}) string { + return string(utils.ToJSON(Reply{Type: PresenceJoinType, Presence: &PresenceInfo{ID: id, Info: info}})) +} + +// PresenceLeaveMessage returns a presence message for the specified event and data +func PresenceLeaveMessage(id string) string { + return string(utils.ToJSON(Reply{Type: PresenceLeaveType, Presence: &PresenceInfo{ID: id}})) +} diff --git a/mocks/Broker.go b/mocks/Broker.go index 6ce183e4..ada5a4ca 100644 --- a/mocks/Broker.go +++ b/mocks/Broker.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v2.50.0. DO NOT EDIT. package mocks @@ -16,10 +16,14 @@ type Broker struct { mock.Mock } -// Announce provides a mock function with given fields: +// Announce provides a mock function with no fields func (_m *Broker) Announce() string { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Announce") + } + var r0 string if rf, ok := ret.Get(0).(func() string); ok { r0 = rf() @@ -34,6 +38,10 @@ func (_m *Broker) Announce() string { func (_m *Broker) CommitSession(sid string, session broker.Cacheable) error { ret := _m.Called(sid, session) + if len(ret) == 0 { + panic("no return value specified for CommitSession") + } + var r0 error if rf, ok := ret.Get(0).(func(string, broker.Cacheable) error); ok { r0 = rf(sid, session) @@ -48,6 +56,10 @@ func (_m *Broker) CommitSession(sid string, session broker.Cacheable) error { func (_m *Broker) FinishSession(sid string) error { ret := _m.Called(sid) + if len(ret) == 0 { + panic("no return value specified for FinishSession") + } + var r0 error if rf, ok := ret.Get(0).(func(string) error); ok { r0 = rf(sid) @@ -72,6 +84,10 @@ func (_m *Broker) HandleCommand(msg *common.RemoteCommandMessage) { func (_m *Broker) HistoryFrom(stream string, epoch string, offset uint64) ([]common.StreamMessage, error) { ret := _m.Called(stream, epoch, offset) + if len(ret) == 0 { + panic("no return value specified for HistoryFrom") + } + var r0 []common.StreamMessage var r1 error if rf, ok := ret.Get(0).(func(string, string, uint64) ([]common.StreamMessage, error)); ok { @@ -98,6 +114,10 @@ func (_m *Broker) HistoryFrom(stream string, epoch string, offset uint64) ([]com func (_m *Broker) HistorySince(stream string, ts int64) ([]common.StreamMessage, error) { ret := _m.Called(stream, ts) + if len(ret) == 0 { + panic("no return value specified for HistorySince") + } + var r0 []common.StreamMessage var r1 error if rf, ok := ret.Get(0).(func(string, int64) ([]common.StreamMessage, error)); ok { @@ -120,10 +140,87 @@ func (_m *Broker) HistorySince(stream string, ts int64) ([]common.StreamMessage, return r0, r1 } +// PresenceAdd provides a mock function with given fields: stream, sid, pid, info +func (_m *Broker) PresenceAdd(stream string, sid string, pid string, info interface{}) error { + ret := _m.Called(stream, sid, pid, info) + + if len(ret) == 0 { + panic("no return value specified for PresenceAdd") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, string, interface{}) error); ok { + r0 = rf(stream, sid, pid, info) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PresenceInfo provides a mock function with given fields: stream, opts +func (_m *Broker) PresenceInfo(stream string, opts ...broker.PresenceInfoOption) (*broker.PresenceInfo, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, stream) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for PresenceInfo") + } + + var r0 *broker.PresenceInfo + var r1 error + if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) (*broker.PresenceInfo, error)); ok { + return rf(stream, opts...) + } + if rf, ok := ret.Get(0).(func(string, ...broker.PresenceInfoOption) *broker.PresenceInfo); ok { + r0 = rf(stream, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*broker.PresenceInfo) + } + } + + if rf, ok := ret.Get(1).(func(string, ...broker.PresenceInfoOption) error); ok { + r1 = rf(stream, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PresenceRemove provides a mock function with given fields: stream, sid, pid +func (_m *Broker) PresenceRemove(stream string, sid string, pid string) error { + ret := _m.Called(stream, sid, pid) + + if len(ret) == 0 { + panic("no return value specified for PresenceRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, string) error); ok { + r0 = rf(stream, sid, pid) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // RestoreSession provides a mock function with given fields: from func (_m *Broker) RestoreSession(from string) ([]byte, error) { ret := _m.Called(from) + if len(ret) == 0 { + panic("no return value specified for RestoreSession") + } + var r0 []byte var r1 error if rf, ok := ret.Get(0).(func(string) ([]byte, error)); ok { @@ -150,6 +247,10 @@ func (_m *Broker) RestoreSession(from string) ([]byte, error) { func (_m *Broker) Shutdown(ctx context.Context) error { ret := _m.Called(ctx) + if len(ret) == 0 { + panic("no return value specified for Shutdown") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context) error); ok { r0 = rf(ctx) @@ -164,6 +265,10 @@ func (_m *Broker) Shutdown(ctx context.Context) error { func (_m *Broker) Start(done chan error) error { ret := _m.Called(done) + if len(ret) == 0 { + panic("no return value specified for Start") + } + var r0 error if rf, ok := ret.Get(0).(func(chan error) error); ok { r0 = rf(done) @@ -178,6 +283,10 @@ func (_m *Broker) Start(done chan error) error { func (_m *Broker) Subscribe(stream string) string { ret := _m.Called(stream) + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + var r0 string if rf, ok := ret.Get(0).(func(string) string); ok { r0 = rf(stream) @@ -192,6 +301,10 @@ func (_m *Broker) Subscribe(stream string) string { func (_m *Broker) Unsubscribe(stream string) string { ret := _m.Called(stream) + if len(ret) == 0 { + panic("no return value specified for Unsubscribe") + } + var r0 string if rf, ok := ret.Get(0).(func(string) string); ok { r0 = rf(stream) @@ -202,13 +315,12 @@ func (_m *Broker) Unsubscribe(stream string) string { return r0 } -type mockConstructorTestingTNewBroker interface { +// NewBroker creates a new instance of Broker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewBroker(t interface { mock.TestingT Cleanup(func()) -} - -// NewBroker creates a new instance of Broker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewBroker(t mockConstructorTestingTNewBroker) *Broker { +}) *Broker { mock := &Broker{} mock.Mock.Test(t) diff --git a/node/node.go b/node/node.go index a1f0a7b3..b363d9d4 100644 --- a/node/node.go +++ b/node/node.go @@ -2,6 +2,7 @@ package node import ( "context" + "encoding/json" "errors" "fmt" "log/slog" @@ -175,6 +176,8 @@ func (n *Node) HandleCommand(s *Session, msg *common.Message) (err error) { _, err = n.Perform(s, msg) case "history": err = n.History(s, msg) + case "presence": + err = n.Presence(s, msg) case "whisper": err = n.Whisper(s, msg) default: @@ -471,8 +474,12 @@ func (n *Node) Subscribe(s *Session, msg *common.Message) (*common.CommandResult if err := n.History(s, msg); err != nil { s.Log.Warn("couldn't retrieve history", "identifier", msg.Identifier, "error", err) } + } - return res, nil + if msg.Presence != nil { + if err := n.handlePresenceReply(s, msg.Identifier, common.PresenceJoinType, msg.Presence); err != nil { + s.Log.Warn("couldn't process presence join", "identifier", msg.Identifier, "error", err) + } } } @@ -670,6 +677,60 @@ func (n *Node) Whisper(s *Session, msg *common.Message) error { return nil } +// Presence returns the presence information for the specified identifier +func (n *Node) Presence(s *Session, msg *common.Message) error { + s.smu.Lock() + + if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { + s.smu.Unlock() + return fmt.Errorf("unknown subscription %s", msg.Identifier) + } + + // Check that the presence stream is configured (thus, the feature is enabled) + env := s.GetEnv() + if env == nil { + s.smu.Unlock() + return errors.New("session environment is missing") + } + + stream := env.GetChannelStateField(msg.Identifier, common.PRESENCE_STREAM_STATE) + + if stream == "" { + s.smu.Unlock() + return fmt.Errorf("presence stream not found for identifier: %s", msg.Identifier) + } + + s.smu.Unlock() + + var options *broker.PresenceInfoOptions + + if msg.Data != nil { + buf, err := json.Marshal(&msg.Data) + if err != nil { + json.Unmarshal(buf, &options) // nolint:errcheck + } + } + + info, err := n.broker.PresenceInfo(stream, broker.WithPresenceInfoOptions(options)) + + if err != nil { + s.Send(&common.Reply{ + Type: common.PresenceErrorType, + Identifier: msg.Identifier, + }) + + return err + } + + s.Send(&common.Reply{ + Type: common.PresenceInfoType, + Identifier: msg.Identifier, + Message: info, + }) + + return nil +} + // Broadcast message to stream (locally) func (n *Node) Broadcast(msg *common.StreamMessage) { n.metrics.CounterIncrement(metricsBroadcastMsg) @@ -820,6 +881,10 @@ func (n *Node) handleCommandReply(s *Session, msg *common.Message, reply *common } isConnectionDirty := n.handleCallReply(s, reply.ToCallResult()) + + // TODO: RPC-driven presence + // n.handlePresenceReply(s, reply.Presence) + return isDirty || isConnectionDirty } @@ -847,6 +912,38 @@ func (n *Node) handleCallReply(s *Session, reply *common.CallResult) bool { return isDirty } +func (n *Node) handlePresenceReply(s *Session, identifier string, event string, presence *common.PresenceInfo) error { + if presence == nil { + return nil + } + + // Check that the presence stream is configured (thus, the feature is enabled) + env := s.GetEnv() + if env == nil { + return errors.New("session environment is missing") + } + + stream := env.GetChannelStateField(identifier, common.PRESENCE_STREAM_STATE) + + if stream == "" { + return fmt.Errorf("presence stream not found for identifier: %s", identifier) + } + + sid := s.GetID() + + var err error + + if event == common.PresenceJoinType { // nolint:gocritic + err = n.broker.PresenceAdd(stream, sid, presence.ID, presence.Info) + } else if event == common.PresenceLeaveType { + err = n.broker.PresenceRemove(stream, sid, presence.ID) + } else { + return fmt.Errorf("unknown presence event: %s", event) + } + + return err +} + // disconnectScheduler controls how quickly to disconnect sessions type disconnectScheduler interface { // This method is called when a session is ready to be disconnected, diff --git a/streams/config.go b/streams/config.go index aff615bf..db695ef1 100644 --- a/streams/config.go +++ b/streams/config.go @@ -17,6 +17,9 @@ type Config struct { // Whisper determines if whispering is enabled for pub/sub streams Whisper bool `toml:"whisper"` + // Presence determines if presence is enabled for pub/sub streams + Presence bool `toml:"presence"` + // PubSubChannel is the channel name used for direct pub/sub PubSubChannel string `toml:"pubsub_channel"` @@ -80,6 +83,13 @@ func (c Config) ToToml() string { result.WriteString("# whisper = true\n") } + result.WriteString("# Enable presence support for pub/sub streams\n") + if c.Presence { + result.WriteString("presence = true\n") + } else { + result.WriteString("# presence = true\n") + } + result.WriteString("# Name of the channel used for pub/sub\n") result.WriteString(fmt.Sprintf("pubsub_channel = \"%s\"\n", c.PubSubChannel)) diff --git a/streams/controller.go b/streams/controller.go index 72f09139..9e297cb1 100644 --- a/streams/controller.go +++ b/streams/controller.go @@ -15,7 +15,8 @@ type SubscribeRequest struct { StreamName string `json:"stream_name"` SignedStreamName string `json:"signed_stream_name"` - whisper bool + whisper bool + presence bool } func (r *SubscribeRequest) IsPresent() bool { @@ -108,10 +109,14 @@ func (c *Controller) Subscribe(sid string, env *common.SessionEnv, ids string, i c.log.With("identifier", identifier).Debug("verified", "stream", stream) } - var state map[string]string + state := map[string]string{} if request.whisper { - state = map[string]string{common.WHISPER_STREAM_STATE: stream} + state[common.WHISPER_STREAM_STATE] = stream + } + + if request.presence { + state[common.PRESENCE_STREAM_STATE] = stream } return &common.CommandResult{ @@ -144,6 +149,7 @@ func NewStreamsController(conf *Config, l *slog.Logger) *Controller { key := conf.Secret allowPublic := conf.Public whispers := conf.Whisper + presence := conf.Presence resolver := func(identifier string) (*SubscribeRequest, error) { var request SubscribeRequest @@ -160,6 +166,10 @@ func NewStreamsController(conf *Config, l *slog.Logger) *Controller { request.whisper = true } + if presence || (request.StreamName != "") { + request.presence = true + } + return &request, nil }