diff --git a/hub.go b/hub.go index 3b91fefd..bb87fd0e 100644 --- a/hub.go +++ b/hub.go @@ -120,6 +120,7 @@ func init() { } type Hub struct { + version string events AsyncEvents upgrader websocket.Upgrader cookie *securecookie.SecureCookie @@ -300,7 +301,8 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer } hub := &Hub{ - events: events, + version: version, + events: events, upgrader: websocket.Upgrader{ ReadBufferSize: websocketReadBufferSize, WriteBufferSize: websocketWriteBufferSize, @@ -2626,7 +2628,11 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { addr := h.getRealUserIP(r) agent := r.Header.Get("User-Agent") - conn, err := h.upgrader.Upgrade(w, r, nil) + header := http.Header{} + header.Set("Server", "nextcloud-spreed-signaling/"+h.version) + header.Set("X-Spreed-Signaling-Features", strings.Join(h.info.Features, ", ")) + + conn, err := h.upgrader.Upgrade(w, r, header) if err != nil { log.Printf("Could not upgrade request from %s: %s", addr, err) return diff --git a/hub_test.go b/hub_test.go index 78e0be2c..e29193b3 100644 --- a/hub_test.go +++ b/hub_test.go @@ -793,6 +793,46 @@ func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup { return &wg } +func TestWebsocketFeatures(t *testing.T) { + t.Parallel() + CatchLogForTest(t) + _, _, _, server := CreateHubForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + conn, response, err := websocket.DefaultDialer.DialContext(ctx, getWebsocketUrl(server.URL), nil) + if err != nil { + t.Fatal(err) + } + defer conn.Close() // nolint + + if server := response.Header.Get("Server"); !strings.HasPrefix(server, "nextcloud-spreed-signaling/") { + t.Errorf("expected valid server header, got \"%s\"", server) + } + features := response.Header.Get("X-Spreed-Signaling-Features") + featuresList := make(map[string]bool) + for _, f := range strings.Split(features, ",") { + f = strings.TrimSpace(f) + if f != "" { + if _, found := featuresList[f]; found { + t.Errorf("duplicate feature id \"%s\" in \"%s\"", f, features) + } + featuresList[f] = true + } + } + if len(featuresList) <= 1 { + t.Errorf("expected valid features header, got \"%s\"", features) + } + if _, found := featuresList["hello-v2"]; !found { + t.Errorf("expected feature \"hello-v2\", got \"%s\"", features) + } + + if err := conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}); err != nil { + t.Errorf("could not write close message: %s", err) + } +} + func TestInitialWelcome(t *testing.T) { t.Parallel() CatchLogForTest(t) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index b59d5af9..68b77e25 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -69,6 +69,14 @@ const ( maxTokenAge = 5 * time.Minute remotePublisherTimeout = 5 * time.Second + + ProxyFeatureRemoteStreams = "remote-streams" +) + +var ( + defaultProxyFeatures = []string{ + ProxyFeatureRemoteStreams, + } ) type ContextKey string @@ -93,6 +101,7 @@ type ProxyServer struct { version string country string welcomeMessage string + welcomeMsg *signaling.WelcomeServerMessage config *goconf.ConfigFile url string @@ -314,7 +323,12 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* version: version, country: country, welcomeMessage: string(welcomeMessage) + "\n", - config: config, + welcomeMsg: &signaling.WelcomeServerMessage{ + Version: version, + Country: country, + Features: defaultProxyFeatures, + }, + config: config, shutdownChannel: make(chan struct{}), @@ -611,7 +625,10 @@ func (s *ProxyServer) welcomeHandler(w http.ResponseWriter, r *http.Request) { func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { addr := signaling.GetRealUserIP(r, s.trustedProxies.Load()) - conn, err := s.upgrader.Upgrade(w, r, nil) + header := http.Header{} + header.Set("Server", "nextcloud-spreed-signaling-proxy/"+s.version) + header.Set("X-Spreed-Signaling-Features", strings.Join(s.welcomeMsg.Features, ", ")) + conn, err := s.upgrader.Upgrade(w, r, header) if err != nil { log.Printf("Could not upgrade request from %s: %s", addr, err) return @@ -760,10 +777,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { Hello: &signaling.HelloProxyServerMessage{ Version: signaling.HelloVersionV1, SessionId: session.PublicId(), - Server: &signaling.WelcomeServerMessage{ - Version: s.version, - Country: s.country, - }, + Server: s.welcomeMsg, }, } client.SendMessage(response) diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 25a9a579..9af71c7d 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -22,18 +22,22 @@ package main import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "net" + "net/http/httptest" "os" + "strings" "testing" "time" "github.com/dlintw/goconf" "github.com/golang-jwt/jwt/v4" "github.com/gorilla/mux" + "github.com/gorilla/websocket" signaling "github.com/strukturag/nextcloud-spreed-signaling" ) @@ -42,12 +46,22 @@ const ( TokenIdForTest = "foo" ) -func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey) { +func getWebsocketUrl(url string) string { + if strings.HasPrefix(url, "http://") { + return "ws://" + url[7:] + "/proxy" + } else if strings.HasPrefix(url, "https://") { + return "wss://" + url[8:] + "/proxy" + } else { + panic("Unsupported URL: " + url) + } +} + +func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httptest.Server) { tempdir := t.TempDir() - var server *ProxyServer + var proxy *ProxyServer t.Cleanup(func() { - if server != nil { - server.Stop() + if proxy != nil { + proxy.Stop() } }) @@ -87,15 +101,21 @@ func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey) { config := goconf.NewConfigFile() config.AddOption("tokens", TokenIdForTest, pubkey.Name()) - if server, err = NewProxyServer(r, "0.0", config); err != nil { - t.Fatalf("could not create server: %s", err) + if proxy, err = NewProxyServer(r, "0.0", config); err != nil { + t.Fatalf("could not create proxy server: %s", err) } - return server, key + + server := httptest.NewServer(r) + t.Cleanup(func() { + server.Close() + }) + + return proxy, key, server } func TestTokenValid(t *testing.T) { signaling.CatchLogForTest(t) - server, key := newProxyServerForTest(t) + proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -113,7 +133,7 @@ func TestTokenValid(t *testing.T) { Version: "1.0", Token: tokenString, } - session, err := server.NewSession(hello) + session, err := proxy.NewSession(hello) if session != nil { defer session.Close() } else if err != nil { @@ -123,7 +143,7 @@ func TestTokenValid(t *testing.T) { func TestTokenNotSigned(t *testing.T) { signaling.CatchLogForTest(t) - server, _ := newProxyServerForTest(t) + proxy, _, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -141,7 +161,7 @@ func TestTokenNotSigned(t *testing.T) { Version: "1.0", Token: tokenString, } - session, err := server.NewSession(hello) + session, err := proxy.NewSession(hello) if session != nil { defer session.Close() t.Errorf("should not have created session") @@ -152,7 +172,7 @@ func TestTokenNotSigned(t *testing.T) { func TestTokenUnknown(t *testing.T) { signaling.CatchLogForTest(t) - server, key := newProxyServerForTest(t) + proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -170,7 +190,7 @@ func TestTokenUnknown(t *testing.T) { Version: "1.0", Token: tokenString, } - session, err := server.NewSession(hello) + session, err := proxy.NewSession(hello) if session != nil { defer session.Close() t.Errorf("should not have created session") @@ -181,7 +201,7 @@ func TestTokenUnknown(t *testing.T) { func TestTokenInFuture(t *testing.T) { signaling.CatchLogForTest(t) - server, key := newProxyServerForTest(t) + proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -199,7 +219,7 @@ func TestTokenInFuture(t *testing.T) { Version: "1.0", Token: tokenString, } - session, err := server.NewSession(hello) + session, err := proxy.NewSession(hello) if session != nil { defer session.Close() t.Errorf("should not have created session") @@ -210,7 +230,7 @@ func TestTokenInFuture(t *testing.T) { func TestTokenExpired(t *testing.T) { signaling.CatchLogForTest(t) - server, key := newProxyServerForTest(t) + proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -228,7 +248,7 @@ func TestTokenExpired(t *testing.T) { Version: "1.0", Token: tokenString, } - session, err := server.NewSession(hello) + session, err := proxy.NewSession(hello) if session != nil { defer session.Close() t.Errorf("should not have created session") @@ -271,3 +291,39 @@ func TestPublicIPs(t *testing.T) { } } } + +func TestWebsocketFeatures(t *testing.T) { + signaling.CatchLogForTest(t) + _, _, server := newProxyServerForTest(t) + + conn, response, err := websocket.DefaultDialer.DialContext(context.Background(), getWebsocketUrl(server.URL), nil) + if err != nil { + t.Fatal(err) + } + defer conn.Close() // nolint + + if server := response.Header.Get("Server"); !strings.HasPrefix(server, "nextcloud-spreed-signaling-proxy/") { + t.Errorf("expected valid server header, got \"%s\"", server) + } + features := response.Header.Get("X-Spreed-Signaling-Features") + featuresList := make(map[string]bool) + for _, f := range strings.Split(features, ",") { + f = strings.TrimSpace(f) + if f != "" { + if _, found := featuresList[f]; found { + t.Errorf("duplicate feature id \"%s\" in \"%s\"", f, features) + } + featuresList[f] = true + } + } + if len(featuresList) == 0 { + t.Errorf("expected valid features header, got \"%s\"", features) + } + if _, found := featuresList["remote-streams"]; !found { + t.Errorf("expected feature \"remote-streams\", got \"%s\"", features) + } + + if err := conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}); err != nil { + t.Errorf("could not write close message: %s", err) + } +}