diff --git a/server/package-lock.json b/server/package-lock.json index 720c9f6c01..19988632d0 100644 --- a/server/package-lock.json +++ b/server/package-lock.json @@ -9,7 +9,6 @@ "version": "3.33.0", "dependencies": { "@sentry/node": "^7.60.1", - "cookie-parser": "^1.4.6", "cross-fetch": "^3.1.8", "express": "^4.19.2", "express-prom-bundle": "^6.6.0", @@ -2992,26 +2991,6 @@ "node": ">= 0.6" } }, - "node_modules/cookie-parser": { - "version": "1.4.6", - "resolved": "https://registry.npmjs.org/cookie-parser/-/cookie-parser-1.4.6.tgz", - "integrity": "sha512-z3IzaNjdwUC2olLIB5/ITd0/setiaFMLYiZJle7xg5Fe9KWAceil7xszYfHHBtDFYLSgJduS2Ty0P1uJdPDJeA==", - "dependencies": { - "cookie": "0.4.1", - "cookie-signature": "1.0.6" - }, - "engines": { - "node": ">= 0.8.0" - } - }, - "node_modules/cookie-parser/node_modules/cookie": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.1.tgz", - "integrity": "sha512-ZwrFkGJxUR3EIoXtO+yVE69Eb7KlixbaeAWfBQB9vVsNn/o+Yw69gBWSSDK825hQNdN+wF8zELf3dFNl/kxkUA==", - "engines": { - "node": ">= 0.6" - } - }, "node_modules/cookie-signature": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", @@ -11096,22 +11075,6 @@ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.2.tgz", "integrity": "sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA==" }, - "cookie-parser": { - "version": "1.4.6", - "resolved": "https://registry.npmjs.org/cookie-parser/-/cookie-parser-1.4.6.tgz", - "integrity": "sha512-z3IzaNjdwUC2olLIB5/ITd0/setiaFMLYiZJle7xg5Fe9KWAceil7xszYfHHBtDFYLSgJduS2Ty0P1uJdPDJeA==", - "requires": { - "cookie": "0.4.1", - "cookie-signature": "1.0.6" - }, - "dependencies": { - "cookie": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.4.1.tgz", - "integrity": "sha512-ZwrFkGJxUR3EIoXtO+yVE69Eb7KlixbaeAWfBQB9vVsNn/o+Yw69gBWSSDK825hQNdN+wF8zELf3dFNl/kxkUA==" - } - } - }, "cookie-signature": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", diff --git a/server/package.json b/server/package.json index e5054a260c..3ded320d13 100644 --- a/server/package.json +++ b/server/package.json @@ -26,7 +26,6 @@ }, "dependencies": { "@sentry/node": "^7.60.1", - "cookie-parser": "^1.4.6", "cross-fetch": "^3.1.8", "express": "^4.19.2", "express-prom-bundle": "^6.6.0", diff --git a/server/src/api-client/index.ts b/server/src/api-client/index.ts index 22c35ac3dc..3beae1a12f 100644 --- a/server/src/api-client/index.ts +++ b/server/src/api-client/index.ts @@ -36,9 +36,7 @@ class APIClient { * Fetch session status * */ - async getSessionStatus( - authHeathers: Record - ): Promise { + async getSessionStatus(authHeathers: HeadersInit): Promise { const sessionsUrl = `${this.gatewayUrl}/notebooks/servers`; logger.debug(`Fetching session status.`); const options = { @@ -55,7 +53,7 @@ class APIClient { */ async kgActivationStatus( projectId: number, - authHeaders: Headers + authHeaders: HeadersInit ): Promise { const headers = new Headers(authHeaders); const activationStatusURL = `${this.gatewayUrl}/projects/${projectId}/graph/status`; diff --git a/server/src/authentication/authentication.types.ts b/server/src/authentication/authentication.types.ts new file mode 100644 index 0000000000..dd83111a67 --- /dev/null +++ b/server/src/authentication/authentication.types.ts @@ -0,0 +1,35 @@ +/*! + * Copyright 2024 - Swiss Data Science Center (SDSC) + * A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and + * Eidgenössische Technische Hochschule Zürich (ETHZ). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +import express from "express"; + +export type User = AnonymousUser | LoggedInUser; + +export type AnonymousUser = { + id: ""; + anonymousId: string; +}; + +export type LoggedInUser = { + id: string; + renkuAuthToken: string; +}; + +export type RequestWithUser = express.Request & { + user?: User | null | undefined; +}; diff --git a/server/src/authentication/authenticator.ts b/server/src/authentication/authenticator.ts new file mode 100644 index 0000000000..c34055b0b0 --- /dev/null +++ b/server/src/authentication/authenticator.ts @@ -0,0 +1,127 @@ +/*! + * Copyright 2024 - Swiss Data Science Center (SDSC) + * A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and + * Eidgenössische Technische Hochschule Zürich (ETHZ). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +import express from "express"; +import { JWT } from "jose"; +import { Client, Issuer } from "openid-client"; + +import config from "../config"; +import logger from "../logger"; +import { getCookieValueByName } from "../utils"; + +import { + AnonymousUser, + LoggedInUser, + RequestWithUser, + User, +} from "./authentication.types"; + +export class Authenticator { + authServerUrl: string; + issuer: Issuer; + + constructor(authServerUrl: string = config.auth.serverUrl) { + this.authServerUrl = authServerUrl; + } + + async init(): Promise { + try { + this.issuer = await Issuer.discover(this.authServerUrl); + logger.info("Authenticator initialized"); + } catch (error) { + logger.error( + "Cannot initialize the auth client. The authentication server may be down or some paramaters may be wrong. " + + "Please check the next log entry for further details." + ); + logger.error(error); + throw error; + } + return true; + } + + async authenticate({ + authHeader, + sessionId = "", + }: { + authHeader: string; + sessionId?: string; + }): Promise { + const anonUser: AnonymousUser = { + id: "", + anonymousId: sessionId ?? "", + }; + + const authToken = authHeader + .toLowerCase() + .startsWith(config.auth.authHeaderPrefix) + ? authHeader.slice(config.auth.authHeaderPrefix.length).trim() + : authHeader.trim(); + + if (!authToken) { + return anonUser; + } + + try { + const issuer = this.issuer; + if (issuer == null) { + logger.error("The authenticator is not ready."); + return anonUser; + } + + const keystore = await issuer.keystore(); + const { payload } = JWT.verify(authToken, keystore, { complete: true }); + const userId = (payload as { sub?: string })["sub"]; + if (userId) { + const user: LoggedInUser = { id: userId, renkuAuthToken: authToken }; + logger.debug(`Authentication: authenticated user ${user.id}`); + return user; + } + } catch (error) { + logger.error("Authentication failed:"); + logger.error(error); + } + return anonUser; + } + + middleware(): ( + req: RequestWithUser, + res: express.Response, + next: express.NextFunction + ) => Promise { + const authenticate: typeof this.authenticate = this.authenticate.bind(this); + + async function authenticationMiddleware( + req: RequestWithUser, + res: express.Response, + next: express.NextFunction + ) { + // Do not re-authenticate the request + if (req.user != null) { + return next(); + } + + const authHeader = req.header(config.auth.authHeaderField); + const sessionId = + getCookieValueByName(req.header("cookie"), config.auth.cookiesKey) ?? + ""; + req.user = await authenticate({ authHeader, sessionId }); + return next(); + } + return authenticationMiddleware; + } +} diff --git a/server/src/authentication/index.ts b/server/src/authentication/index.ts deleted file mode 100644 index b87911b9aa..0000000000 --- a/server/src/authentication/index.ts +++ /dev/null @@ -1,434 +0,0 @@ -/*! - * Copyright 2021 - Swiss Data Science Center (SDSC) - * A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and - * Eidgenössische Technische Hochschule Zürich (ETHZ). - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import * as Sentry from "@sentry/node"; -import express from "express"; -import { Issuer, generators, Client, TokenSet } from "openid-client"; - -import config from "../config"; -import logger from "../logger"; -import { - Storage, - StorageGetOptions, - StorageSaveOptions, - TypeData, -} from "../storage"; -import { sleep } from "../utils"; -import { APIError } from "../utils/apiError"; -import { HttpStatusCode } from "../utils/baseError"; -import jwt from "jsonwebtoken"; - -const verifierSuffix = "-verifier"; -const parametersSuffix = "-parameters"; -const maxAttempts = config.auth.retryConnectionAttempts; - -type GetStorageValueReturn = { - storageKey: string; - value: string | null; -}; - -class Authenticator { - authServerUrl: string; - clientId: string; - clientSecret: string; - callbackUrl: string; - - storage: Storage; - - retryAttempt = 0; - authClient: Client; - ready = false; - private saveStorageOptions: StorageSaveOptions = { - type: TypeData.String, - }; - private getStorageOptions: StorageGetOptions = { - type: TypeData.String, - }; - - constructor( - storage: Storage, - authServerUrl: string = config.auth.serverUrl, - clientId: string = config.auth.clientId, - clientSecret: string = config.auth.clientSecret, - callbackUrl: string = config.server.url + - config.server.prefix + - config.routes.auth + - "/callback" - ) { - // Validate and save parameters - for (const param of [ - "storage", - "authServerUrl", - "clientId", - "clientSecret", - "callbackUrl", - ]) { - if (!param || !param.length) { - const newError = new Error(`The parameter "${param}" is mandatory.`); - logger.error(newError); - throw newError; - } - } - - this.storage = storage; - this.authServerUrl = authServerUrl; - this.clientId = clientId; - this.clientSecret = clientSecret; - this.callbackUrl = callbackUrl; - } - - /** - * Initialize client to interact with the authentication server. - */ - async init(): Promise { - try { - const issuer = await Issuer.discover(this.authServerUrl); - this.authClient = new issuer.Client({ - client_id: this.clientId, - client_secret: this.clientSecret, - redirect_uris: [this.callbackUrl], - response_types: ["code"], - }); - this.ready = true; - logger.info("Authenticator succesfully initialized."); - return true; - } catch (error) { - this.retryAttempt++; - logger.error( - "Cannot initialize the auth client. The authentication server may be down or some paramaters may be wrong. " + - `Attempt number ${this.retryAttempt} of ${maxAttempts} ` + - "Please check the next log entry for further details." - ); - logger.error(error); - if (this.retryAttempt < maxAttempts) { - await sleep(10); - return this.init(); - } - throw error; - } - } - - private checkInit(): boolean { - if (!this.ready) { - const newError = new Error( - "Cannot interact with the authentication server. Did you invoke `await init()`?" - ); - logger.error(newError); - throw newError; - } - return true; - } - - private getVerifierKey(sessionId: string): string { - return sessionId + verifierSuffix; - } - - private getParametersKey(sessionId: string): string { - return sessionId + parametersSuffix; - } - - /** - * Delete a value from storage - * @param storageKey - the key under which the value has been stored - * @param actionDesc - a description of the action, used for error messages - * @returns true if the operation did not fail, false if it did fail - */ - private async deleteStorageValue( - storageKey: string, - actionDesc: string - ): Promise { - const numDeleted = await this.storage.delete(storageKey); - if (numDeleted < 0) { - const errorMessage = `Could not delete ${actionDesc} from storage.`; - logger.error(errorMessage); - Sentry.captureMessage(errorMessage); - return false; - } - return true; - } - - private async getStorageValueAsString( - key: string - ): Promise { - const storageKey = `${config.auth.storagePrefix}${key}`; - const storageValue = await this.storage.get( - storageKey, - this.getStorageOptions - ); - return { storageKey, value: storageValue as string }; - } - - private async saveStorageValueAsString( - key: string, - value: string - ): Promise { - const storageKey = `${config.auth.storagePrefix}${key}`; - return await this.storage.save(storageKey, value, this.saveStorageOptions); - } - - /** - * The parameters for the redirect URL after login need to be temporarily stored. Get the parameter - * string to attach to the final login, and optionally delete the entry from the storage. - * - * @param sessionId - session id - * @param deleteAfter - boolean defaults to true - * @returns url search string, including the initial `?` - */ - async getPostLoginParametersAndDelete( - sessionId: string, - deleteAfter = true - ): Promise { - const parametersKey = this.getParametersKey(sessionId); - const { storageKey, value: parametersString } = - await this.getStorageValueAsString(parametersKey); - if (parametersString == null) return ""; - if (deleteAfter) { - await this.deleteStorageValue( - storageKey, - `login parameters for session ${sessionId}` - ); - } - return parametersString; - } - - /** - * Starts the authentication flow. It saves the code verifier and it returns the url to redirect to. - * - * @param sessionId - session id - */ - async startAuthFlow( - sessionId: string, - redirectParams: string = null - ): Promise { - // ? REF: https://darutk.medium.com/diagrams-of-all-the-openid-connect-flows-6968e3990660 - this.checkInit(); - - // create and store the verifier - const verifier = generators.codeVerifier(); - const challenge = generators.codeChallenge(verifier); - const verifierKey = this.getVerifierKey(sessionId); - if (!(await this.saveStorageValueAsString(verifierKey, verifier))) { - throw new Error("Redis not available to support auth flow."); - } - if (redirectParams) { - const parametersKey = this.getParametersKey(sessionId); - if ( - !(await this.saveStorageValueAsString(parametersKey, redirectParams)) - ) { - throw new Error("Redis not available to support auth flow."); - } - } - - // create and return the login url - const authUrl = this.authClient.authorizationUrl({ - scope: "openid profile email microprofile-jwt", - code_challenge: challenge, - code_challenge_method: "S256", - }); - - return authUrl; - } - - /** - * Starts the authentication flow. It saves the code verifier and it returns the url to redirect to. - * - * @param req - express request containing the code challange - */ - getAuthCode(req: express.Request): string { - this.checkInit(); - - // get the code param - const params = this.authClient.callbackParams(req); - if (params["code"] != null) return params["code"]; - // TODO: return error response when needed - } - - /** - * Complete the authentication flow. It cleans-up the code verifier, which is not needed anymore. - * - * @param sessionId - session id - */ - async finishAuthFlow(sessionId: string, code: string): Promise { - this.checkInit(); - - // get the verifier code and remove it from redis - const verifierKey = this.getVerifierKey(sessionId); - const { storageKey, value: verifier } = await this.getStorageValueAsString( - verifierKey - ); - if (verifier == null) { - const error = - "Code challenge not available. Are you re-loading an old page?"; - throw new APIError( - "Auth callback reloading page error", - HttpStatusCode.INTERNAL_SERVER, - error - ); - } - - await this.deleteStorageValue( - storageKey, - `cleanup verifier for ${sessionId}` - ); - - try { - const tokens = await this.authClient.callback( - this.callbackUrl, - { code }, - { code_verifier: verifier } - ); - if (tokens) return tokens; - return null; - } catch (error) { - throw new APIError( - "Error callback for Authorization Server", - HttpStatusCode.INTERNAL_SERVER, - error - ); - } - } - - /** - * Store stringified token set to the storage. - * - * @param sessionId - session id - * @param tokens - tokens object as received from the authentication server (must contain access and refresh token) - */ - async storeTokens(sessionId: string, tokens: TokenSet): Promise { - this.checkInit(); - - const result = await this.saveStorageValueAsString( - sessionId, - JSON.stringify(tokens) - ); - if (!result) { - const errorMessage = `Could not store refresh tokens for session ${sessionId}`; - logger.error(errorMessage); - Sentry.captureMessage(errorMessage); - } - return result; - } - - /** - * Get token set from the storage. - * - * @param sessionId - session id - * @param autoRefresh - automatically refresh tokens when necessary - * @returns tokens - tokens object as received from the authentication server (must contain access and refresh token) - */ - async getTokens(sessionId: string, autoRefresh = true): Promise { - this.checkInit(); - - // Get tokens from the store - const { value: stringyTokens } = await this.getStorageValueAsString( - sessionId - ); - if (stringyTokens == null) return null; - let tokens = new TokenSet(JSON.parse(stringyTokens) as TokenSet); - - const tokenExpired = this.checkTokenExpiration(tokens); - if (!tokenExpired) return tokens; - if (!autoRefresh) return null; // ? may implement something more useful once it's used - try { - tokens = await this.refreshTokens(sessionId, tokens); - } catch (error) { - if (error.toString().includes("invalid")) - logger.info(`Tokens invalid for session ${sessionId}`); - else logger.error(error); - throw error; - } - return tokens; - } - - /** - * Check tokens expiration. - * - * @param tokens - token set - * @returns number representing the TokenStatus - */ - checkTokenExpiration(tokens: TokenSet): boolean { - // add tolerance - const tokensWithTolerance = { - ...tokens, - expires_at: - tokens.expires_at - (config.auth.tokenExpirationTolerance as number), - }; - - // re-initialize the TokenSet as a proper object. - const tokensObject = new TokenSet(tokensWithTolerance); - const expired = tokensObject.expired(); - return expired; - } - - /** - * delete token set from the storage. - * - * @param sessionId - session id - * @returns true if the delete operation succeeded, false otherwise. Mind that trying to delete an - * already delete key won't make the operation fail. - */ - async deleteTokens(sessionId: string): Promise { - this.checkInit(); - return await this.deleteStorageValue( - `${config.auth.storagePrefix}${sessionId}`, - `tokens for session ${sessionId}` - ); - } - - /** - * Refresh tokens when possible. Otherwise, remove the expired/corrupted credentials. - * - * @param sessionId - session id - */ - async refreshTokens( - sessionId: string, - tokens: TokenSet = null, - removeIfFailed = true - ): Promise { - // get the tokens from the store when not provided. - if (tokens == null) { - tokens = await this.getTokens(sessionId, false); - if (tokens == null) return null; // can't refresh them if they doesn't exist - } - - const refreshedTokens = await this.authClient.refresh(tokens.refresh_token); - if (refreshedTokens != null) - await this.storeTokens(sessionId, refreshedTokens); - else if (removeIfFailed) await this.deleteTokens(sessionId); - - return refreshedTokens; - } -} - -/** - * Return user Id from token - * - * @param authHeader - jwt token using bearer schema - */ -const getUserIdFromToken = (authHeader: string): string => { - if (!authHeader) return undefined; - - const authItems = authHeader.split(" "); - - if (authItems.length <= 1) return undefined; - - const user = jwt.decode(authItems[1]); - return (user.sub as string) || undefined; -}; - -export { Authenticator, getUserIdFromToken }; diff --git a/server/src/authentication/middleware.ts b/server/src/authentication/middleware.ts deleted file mode 100644 index e96347b4d0..0000000000 --- a/server/src/authentication/middleware.ts +++ /dev/null @@ -1,119 +0,0 @@ -/*! - * Copyright 2021 - Swiss Data Science Center (SDSC) - * A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and - * Eidgenössische Technische Hochschule Zürich (ETHZ). - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import express from "express"; -import { TokenSet } from "openid-client"; - -import config from "../config"; -import logger from "../logger"; -import { Authenticator } from "./index"; -import { getOrCreateSessionId } from "./routes"; -import { serializeCookie } from "../utils"; -import { WsMessage } from "../websocket/WsMessages"; - -/** - * Add the authorization header for invoking gateway APIs as an authenticated renku user. - * - * @param res - express response - * @param accessToken - valid access token. - */ -function addAuthToken(req: express.Request, accessToken: string): void { - const value = config.auth.authHeaderPrefix + accessToken; - req.headers[config.auth.authHeaderField] = value; -} - -/** - * Add the nonymous header for invoking gateway APIs as an anonymous renku user. - * - * @param req - express response - * @param value - uid for the anonamous user. - */ -function addAnonymousToken(req: express.Request, value: string): void { - req.headers[config.auth.cookiesAnonymousKey] = value; -} - -/** - * Add the invalid credentials header to signal the need to re-authenticate. - * - * @param res - express response - */ -function addAuthInvalid(req: express.Request): void { - req.headers[config.auth.invalidHeaderField] = - config.auth.invalidHeaderExpired; -} - -function renkuAuth(authenticator: Authenticator) { - return async ( - req: express.Request, - res: express.Response, - next: express.NextFunction - ): Promise => { - // get or create session - const sessionId = getOrCreateSessionId(req, res); - let tokens: TokenSet; - try { - tokens = await authenticator.getTokens(sessionId, true); - } catch (error) { - const stringyError = error.toString(); - const expired = - stringyError.includes("expired") || stringyError.includes("invalid"); - if (expired) { - logger.info(`Adding token expirations info for session ${sessionId}`); - addAuthInvalid(req); - } else { - throw error; - } - } - - if (tokens) addAuthToken(req, tokens.access_token); - else addAnonymousToken(req, sessionId); - - next(); - }; -} - -async function wsRenkuAuth( - authenticator: Authenticator, - sessionId: string -): Promise> { - let tokens: TokenSet; - try { - tokens = await authenticator.getTokens(sessionId, true); - } catch (error) { - const stringyError = error.toString(); - - const expired = - stringyError.includes("expired") || stringyError.includes("invalid"); - if (expired) throw new Error("expired"); - throw error; - } - - if (tokens) { - const value = config.auth.authHeaderPrefix + tokens.access_token; - return { [config.auth.authHeaderField]: value }; - } - - // Anonymous users - const fullAnonId = config.auth.anonPrefix + sessionId; - const newCookies: Array = [ - serializeCookie(config.auth.cookiesAnonymousKey, fullAnonId), - ]; - return { cookie: newCookies.join("; ") }; -} - -export { renkuAuth, addAuthToken, wsRenkuAuth }; diff --git a/server/src/authentication/routes.ts b/server/src/authentication/routes.ts deleted file mode 100644 index 44b4a4a2b4..0000000000 --- a/server/src/authentication/routes.ts +++ /dev/null @@ -1,126 +0,0 @@ -/*! - * Copyright 2021 - Swiss Data Science Center (SDSC) - * A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and - * Eidgenössische Technische Hochschule Zürich (ETHZ). - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import express from "express"; -import { v4 as uuidv4 } from "uuid"; - -import config from "../config"; -import { Authenticator } from "./index"; - -/** - * Get the session id. If not availabe, one is created and set on the response cookies - * - * @param req - express request - * @param res - express response - * @return session id - */ -function getOrCreateSessionId( - req: express.Request, - res: express.Response, - serverPrefix: string = config.server.prefix -): string { - const cookiesKey = config.auth.cookiesKey; - let sessionId: string; - if (req.cookies[cookiesKey] != null) { - sessionId = req.cookies[cookiesKey]; - } else { - sessionId = uuidv4(); - res.cookie(cookiesKey, sessionId, { - secure: true, - httpOnly: true, - path: serverPrefix, - }); - } - return sessionId; -} - -/** - * Extract and return the search string (i.e. the query parameters in the form `?anyvalue`). - * - * @param req - express request containing the url - * @returns search string - */ -function getStringyParams(req: express.Request): string { - const fullUrl = req.url.toLowerCase().startsWith("http") - ? req.url - : config.server.url + req.url; - const urlObject = new URL(fullUrl); - return urlObject.search; -} - -function registerAuthenticationRoutes( - app: express.Application, - authenticator: Authenticator -): void { - const authPrefix = config.server.prefix + config.routes.auth; - - app.get(authPrefix + "/login", async (req, res, next) => { - try { - // start the login using the code flow, preserving query params for later. - const sessionId = getOrCreateSessionId(req, res); - const inputParams = getStringyParams(req); - const loginCodeUrl = await authenticator.startAuthFlow( - sessionId, - inputParams - ); - - res.redirect(loginCodeUrl); - } catch (error) { - next(error); - } - }); - - app.get(authPrefix + "/callback", async (req, res, next) => { - try { - // finish the auth flow, exchanging the auth code with the token set. - const sessionId = getOrCreateSessionId(req, res); - const code = authenticator.getAuthCode(req); - const tokens = await authenticator.finishAuthFlow(sessionId, code); - await authenticator.storeTokens(sessionId, tokens); - - // create the login url, adding the original query params. - const originalParameters = - await authenticator.getPostLoginParametersAndDelete(sessionId); - const backendLoginUrl = - config.deployment.gatewayLoginUrl + originalParameters; - - // ? Do I need to set the access token here? Will this be needed when removing the `session` cookie from gateway? - // ? res.set(config.auth.authHeaderField, config.auth.authHeaderPrefix + tokens["access_token"]); - res.redirect(backendLoginUrl); - } catch (error) { - next(error); - } - }); - - app.get(authPrefix + "/logout", async (req, res, next) => { - try { - // delete token set - const sessionId = getOrCreateSessionId(req, res); - await authenticator.deleteTokens(sessionId); - - // create the logout url - const inputParams = getStringyParams(req); - const backendLoginUrl = config.deployment.gatewayLogoutUrl + inputParams; - res.redirect(backendLoginUrl); - } catch (error) { - next(error); - } - }); -} - -export { registerAuthenticationRoutes, getOrCreateSessionId }; diff --git a/server/src/config.ts b/server/src/config.ts index 5aed94eb14..9e03087eec 100644 --- a/server/src/config.ts +++ b/server/src/config.ts @@ -34,14 +34,6 @@ const gatewayUrl = process.env.GATEWAY_URL || urlJoin(SERVER.url ?? "", "/api"); const DEPLOYMENT = { gatewayUrl, - gatewayLoginUrl: urlJoin( - gatewayUrl, - process.env.GATEWAY_LOGIN_PATH || "/auth/login" - ), - gatewayLogoutUrl: urlJoin( - gatewayUrl, - process.env.GATEWAY_LOGOUT_PATH || "/auth/logout" - ), }; const SENTRY = { @@ -59,18 +51,9 @@ const SENTRY = { const AUTHENTICATION = { serverUrl: process.env.AUTH_SERVER_URL || SERVER.url + "/auth/realms/Renku", - clientId: process.env.AUTH_CLIENT_ID || "renku-ui", - clientSecret: process.env.AUTH_CLIENT_SECRET, - tokenExpirationTolerance: convertType(process.env.AUTH_TOKEN_TOLERANCE) || 10, - cookiesKey: "ui-server-session", - cookiesAnonymousKey: "anon-id", - anonPrefix: "anon-", // ? this MUST start with a letter to prevent k8s limitations + cookiesKey: "_renku_session", authHeaderField: "Authorization", authHeaderPrefix: "bearer ", - invalidHeaderField: "ui-server-auth", - invalidHeaderExpired: "expired", - retryConnectionAttempts: 10, - storagePrefix: "AUTH_", }; const REDIS = { @@ -92,7 +75,6 @@ const DATA = { searchStoragePrefix: StoragePrefix.LAST_SEARCHES, searchDefaultLength: 10, projectsDefaultLength: 20, - userSessionsPrefix: "SESSIONS_", }; const WEBSOCKET = { diff --git a/server/src/index.ts b/server/src/index.ts index 4fa3703f47..bd3c7f3619 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -16,14 +16,11 @@ * limitations under the License. */ -import cookieParser from "cookie-parser"; import express from "express"; import morgan from "morgan"; import ws from "ws"; import APIClient from "./api-client"; -import { Authenticator } from "./authentication"; -import { registerAuthenticationRoutes } from "./authentication/routes"; import config from "./config"; import logger from "./logger"; import routes from "./routes"; @@ -36,6 +33,7 @@ import { initializeSentry, } from "./utils/sentry/sentry"; import { configureWebsocket } from "./websocket"; +import { Authenticator } from "./authentication/authenticator"; const app = express(); const port = config.server.port; @@ -70,22 +68,17 @@ initializePrometheus(app); const storage = new RedisStorage(); // configure authenticator -const authenticator = new Authenticator(storage); +const authenticator = new Authenticator(); const authPromise = authenticator.init(); -authPromise.then(() => { - logger.info("Authenticator started"); - - registerAuthenticationRoutes(app, authenticator); - // The error handler middleware is needed here because the registration of authentication - // routes is asynchronous and the middleware has to be registered after them - app.use(errorHandlerMiddleware); +app.use(authenticator.middleware()); +authPromise.catch(() => { + shutdown(); }); -// register middlewares -app.use(cookieParser()); - // register routes -routes.register(app, prefix, authenticator, storage); +routes.register(app, prefix, storage); + +app.use(errorHandlerMiddleware); // start the Express server const server = app.listen(port, () => { @@ -106,7 +99,7 @@ function createWsServer() { addWebSocketServerContext(wsServer); authPromise.then(() => { logger.info("Configuring WebSocket server"); - configureWebsocket(wsServer, authenticator, storage, apiClient); + configureWebsocket(wsServer, storage, apiClient); }); return wsServer; } diff --git a/server/src/routes/apis.ts b/server/src/routes/apis.ts index 340d09e97b..1ed9bf1b6e 100644 --- a/server/src/routes/apis.ts +++ b/server/src/routes/apis.ts @@ -16,20 +16,20 @@ * limitations under the License. */ -import express from "express"; import fetch from "cross-fetch"; +import express from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; +import type { RequestWithUser } from "../authentication/authentication.types"; import config from "../config"; import logger from "../logger"; -import { Authenticator } from "../authentication"; -import { renkuAuth } from "../authentication/middleware"; +import { Storage } from "../storage"; import { getCookieValueByName, serializeCookie } from "../utils"; -import { validateCSP } from "../utils/url"; import { lastProjectsMiddleware } from "../utils/middlewares/lastProjectsMiddleware"; import { lastSearchQueriesMiddleware } from "../utils/middlewares/lastSearchQueriesMiddleware"; import uploadFileMiddleware from "../utils/middlewares/uploadFileMiddleware"; -import { Storage } from "../storage"; +import { validateCSP } from "../utils/url"; + import { CheckURLResponse } from "./apis.interfaces"; import { getUserData } from "./helperFunctions"; @@ -65,54 +65,30 @@ const proxyMiddleware = createProxyMiddleware({ } } } - // add anon-id to cookies when the proper header is set. - const anonId = clientReq.getHeader(config.auth.cookiesAnonymousKey); - if (anonId) { - // ? the anon-id MUST start with a letter to prevent k8s limitations - const fullAnonId = config.auth.anonPrefix + anonId; - newCookies.push( - serializeCookie(config.auth.cookiesAnonymousKey, fullAnonId) + if (newCookies.length > 0) { + clientReq.setHeader("cookie", newCookies.join("; ")); + } + + // Swap headers for the knowledge graph API + const gitlabAccessToken = clientReq.getHeader("Gitlab-Access-Token"); + if (gitlabAccessToken) { + clientReq.setHeader( + config.auth.authHeaderField, + `${config.auth.authHeaderPrefix}${gitlabAccessToken}` ); + } else { + clientReq.removeHeader(config.auth.authHeaderField); } - if (newCookies.length > 0) - clientReq.setHeader("cookie", newCookies.join("; ")); }, onProxyRes: (clientRes, req: express.Request, res: express.Response) => { // Add CORS for sentry res.setHeader("Access-Control-Allow-Headers", "sentry-trace"); - - // handle auth expiration -- we change the response status to avoid browser caching - const expHeader = req.get(config.auth.invalidHeaderField); - if (expHeader != null) { - clientRes.headers[config.auth.invalidHeaderField] = expHeader; - if (expHeader === config.auth.invalidHeaderExpired) { - // We return a different response to prevent side effects from caching mechanism on 30x responses - logger.warn( - `Authentication expired when trying to reach ${req.originalUrl}. Attaching auth headers.` - ); - res.status(500); - res.setHeader(config.auth.invalidHeaderField, expHeader); - res.json({ error: "Invalid authentication tokens" }); - } - } - - // Prevent gateway from setting anon-id cookies. That's not needed in the UI anymore - const setCookie = null ?? clientRes.headers["set-cookie"]; - if (setCookie == null || !setCookie.length) return; - const allowedSetCookie = []; - for (const cookie of setCookie) { - if (!cookie.startsWith(config.auth.cookiesAnonymousKey)) - allowedSetCookie.push(cookie); - } - if (!allowedSetCookie.length) clientRes.headers["set-cookie"] = null; - else clientRes.headers["set-cookie"] = allowedSetCookie; }, }); function registerApiRoutes( app: express.Application, prefix: string, - authenticator: Authenticator, storage: Storage ): void { // Locally defined APIs @@ -156,17 +132,17 @@ function registerApiRoutes( app.get( prefix + "/last-projects/:length", - renkuAuth(authenticator), - async (req, res) => { - const token = req.headers[config.auth.authHeaderField] as string; - if (!token) { + async (req: RequestWithUser, res) => { + const user = req.user; + if (!user?.id) { res.json({ error: "User not authenticated" }); return; } + const length = parseInt(req.params["length"]) || 0; const data = await getUserData( config.data.projectsStoragePrefix, - token, + user.id, storage, length ); @@ -176,17 +152,17 @@ function registerApiRoutes( app.get( prefix + "/last-searches/:length", - renkuAuth(authenticator), - async (req, res) => { - const token = req.headers[config.auth.authHeaderField] as string; - if (!token) { + async (req: RequestWithUser, res) => { + const user = req.user; + if (!user?.id) { res.json({ error: "User not authenticated" }); return; } + const length = parseInt(req.params["length"]) || 0; const data = await getUserData( config.data.searchStoragePrefix, - token, + user.id, storage, length ); @@ -212,26 +188,26 @@ function registerApiRoutes( */ app.get( prefix + "/projects/:projectName", - [renkuAuth(authenticator), lastProjectsMiddleware(storage)], + [lastProjectsMiddleware(storage)], proxyMiddleware ); app.post( prefix + "/renku/cache.files_upload", - [renkuAuth(authenticator), uploadFileMiddleware], + [uploadFileMiddleware], proxyMiddleware ); app.get( prefix + "/kg/entities", - [renkuAuth(authenticator), lastSearchQueriesMiddleware(storage)], + [lastSearchQueriesMiddleware(storage)], proxyMiddleware ); - app.delete(prefix + "/*", renkuAuth(authenticator), proxyMiddleware); - app.get(prefix + "/*", renkuAuth(authenticator), proxyMiddleware); - app.head(prefix + "/*", renkuAuth(authenticator), proxyMiddleware); - app.options(prefix + "/*", renkuAuth(authenticator), proxyMiddleware); - app.patch(prefix + "/*", renkuAuth(authenticator), proxyMiddleware); - app.post(prefix + "/*", renkuAuth(authenticator), proxyMiddleware); - app.put(prefix + "/*", renkuAuth(authenticator), proxyMiddleware); + app.delete(prefix + "/*", proxyMiddleware); + app.get(prefix + "/*", proxyMiddleware); + app.head(prefix + "/*", proxyMiddleware); + app.options(prefix + "/*", proxyMiddleware); + app.patch(prefix + "/*", proxyMiddleware); + app.post(prefix + "/*", proxyMiddleware); + app.put(prefix + "/*", proxyMiddleware); } export default registerApiRoutes; diff --git a/server/src/routes/helperFunctions.ts b/server/src/routes/helperFunctions.ts index 79235ed9e9..3fe6191fdd 100644 --- a/server/src/routes/helperFunctions.ts +++ b/server/src/routes/helperFunctions.ts @@ -16,24 +16,22 @@ * limitations under the License. */ -import { getUserIdFromToken } from "../authentication"; import { Storage, StorageGetOptions, TypeData } from "../storage"; /** - * Get data from the Storage by user token + * Get data from the Storage by user ID * - * @param {string} prefix - the data prefix (StoragePrefix) - * @param {string} token - jwt token using bearer schema - * @param {Storage} storage - storage api - * @param {number} length - number of records, if the value <= 0 it will return all the user's records + * @param prefix - the data prefix (StoragePrefix) + * @param userId - user ID in Renku + * @param storage - storage api + * @param length - number of records, if the value <= 0 it will return all the user's records */ export async function getUserData( prefix: string, - token: string, + userId: string, storage: Storage, length = 0 ): Promise { - const userId = getUserIdFromToken(token); let data: string[] = []; const stop = length - 1; // -1 would bring all records const options: StorageGetOptions = { @@ -42,8 +40,9 @@ export async function getUserData( stop, }; - if (userId) + if (userId) { data = (await storage.get(`${prefix}${userId}`, options)) as string[]; + } return data; } diff --git a/server/src/routes/index.ts b/server/src/routes/index.ts index 4ae88e961b..32145b7b35 100644 --- a/server/src/routes/index.ts +++ b/server/src/routes/index.ts @@ -19,25 +19,24 @@ import express from "express"; import config from "../config"; -import registerInternalRoutes from "./internal"; -import registerApiRoutes from "./apis"; -import { Authenticator } from "../authentication"; import { Storage } from "../storage"; +import registerApiRoutes from "./apis"; +import registerInternalRoutes from "./internal"; + function register( app: express.Application, prefix: string, - authenticator: Authenticator, storage: Storage ): void { - registerInternalRoutes(app, authenticator); + registerInternalRoutes(app, storage); // Testing ingress app.get(prefix, (req, res) => { res.send("UI server up and running"); }); - registerApiRoutes(app, prefix + config.routes.api, authenticator, storage); + registerApiRoutes(app, prefix + config.routes.api, storage); } export default { register }; diff --git a/server/src/routes/internal.ts b/server/src/routes/internal.ts index 82739c4c4f..4ba9c6a107 100644 --- a/server/src/routes/internal.ts +++ b/server/src/routes/internal.ts @@ -17,15 +17,15 @@ */ import express from "express"; -import logger from "../logger"; -import { Authenticator } from "../authentication"; +import logger from "../logger"; +import { Storage } from "../storage"; let storageFailures = 0; function registerInternalRoutes( app: express.Application, - authenticator: Authenticator + storage: Storage ): void { // define a route handler for the default home page app.get("/", (req, res) => { @@ -40,20 +40,20 @@ function registerInternalRoutes( // define a route handler for the liveness probe app.get("/liveness", async (req, res) => { // Check storage status - const storageStatus = authenticator.storage.getStatus(); + const storageStatus = storage.getStatus(); if (storageStatus !== "ready") storageFailures++; else if (storageFailures !== 0) storageFailures = 0; if (storageFailures >= 5) { logger.error( - `Authentication storage failed ${storageFailures} times in a row. Sending a kill signal to k8s.` + `Storage failed ${storageFailures} times in a row. Sending a kill signal to k8s.` ); - res.status(503).send("Authentication storage failed."); + res.status(503).send("Storage failed."); return; } if (storageFailures >= 1) logger.warn( - `Authentication storage is failing. This is the attempt #${storageFailures}` + `Storage is failing. This is the attempt #${storageFailures}` ); res.send("live"); @@ -62,13 +62,13 @@ function registerInternalRoutes( // define a route handler for the startup probe app.get("/startup", (req, res) => { // check if storage is ready - if (!authenticator.storage.ready) + if (!storage.ready) { res.status(503).send("Storage (i.e. Redis) not ready"); - // check if authenticator is ready - else if (!authenticator.ready) - res.status(503).send("Authenticator not ready"); + } // if nothing bad happened so far... all must be working fine! - else res.send("live"); + else { + res.send("live"); + } }); } diff --git a/server/src/utils/middlewares/lastProjectsMiddleware.ts b/server/src/utils/middlewares/lastProjectsMiddleware.ts index 569c659b39..90ce74041a 100644 --- a/server/src/utils/middlewares/lastProjectsMiddleware.ts +++ b/server/src/utils/middlewares/lastProjectsMiddleware.ts @@ -19,7 +19,7 @@ import * as Sentry from "@sentry/node"; import express from "express"; -import { getUserIdFromToken } from "../../authentication"; +import type { RequestWithUser } from "../../authentication/authentication.types"; import config from "../../config"; import logger from "../../logger"; import { Storage, TypeData } from "../../storage"; @@ -31,11 +31,11 @@ function projectNameIsId(projectName: string): boolean { const lastProjectsMiddleware = (storage: Storage) => ( - req: express.Request, + req: RequestWithUser, res: express.Response, next: express.NextFunction ): void => { - const token = req.headers[config.auth.authHeaderField] as string; + const user = req.user; const projectName = req.params["projectName"]; // Ignore projects that are ids -- these will be re-accessed as namespace/name anyway if (projectNameIsId(projectName)) { @@ -45,17 +45,16 @@ const lastProjectsMiddleware = if (req.query?.doNotTrack !== "true") { res.on("finish", function () { - if (res.statusCode >= 400 || !token) { + if (res.statusCode >= 400 || !user?.id) { next(); return; } - const userId = getUserIdFromToken(token); const normalizedProjectName = projectName.toLowerCase(); // Save as ordered collection storage .save( - `${config.data.projectsStoragePrefix}${userId}`, + `${config.data.projectsStoragePrefix}${user.id}`, normalizedProjectName, { type: TypeData.Collections, @@ -65,7 +64,7 @@ const lastProjectsMiddleware = ) .then((value) => { if (!value) { - const errorMessage = `Error saving project ${projectName} for user ${userId}`; + const errorMessage = `Error saving project ${projectName} for user ${user.id}`; logger.error(errorMessage); Sentry.captureMessage(errorMessage); } @@ -78,4 +77,4 @@ const lastProjectsMiddleware = next(); }; -export { lastProjectsMiddleware, getUserIdFromToken }; +export { lastProjectsMiddleware }; diff --git a/server/src/utils/middlewares/lastSearchQueriesMiddleware.ts b/server/src/utils/middlewares/lastSearchQueriesMiddleware.ts index 57b2bcd588..d84f449ade 100644 --- a/server/src/utils/middlewares/lastSearchQueriesMiddleware.ts +++ b/server/src/utils/middlewares/lastSearchQueriesMiddleware.ts @@ -19,7 +19,7 @@ import * as Sentry from "@sentry/node"; import express from "express"; -import { getUserIdFromToken } from "../../authentication"; +import type { RequestWithUser } from "../../authentication/authentication.types"; import config from "../../config"; import logger from "../../logger"; import { Storage, TypeData } from "../../storage"; @@ -27,30 +27,29 @@ import { Storage, TypeData } from "../../storage"; const lastSearchQueriesMiddleware = (storage: Storage) => ( - req: express.Request, + req: RequestWithUser, res: express.Response, next: express.NextFunction ): void => { - const token = req.headers[config.auth.authHeaderField] as string; + const user = req.user; const query = req.query["query"]; const phrase = query ? (query as string).trim() : ""; if (req.query?.doNotTrack !== "true" && phrase) { res.on("finish", function () { - if (res.statusCode >= 400 || !token) { + if (res.statusCode >= 400 || !user?.id) { next(); return; } - const userId = getUserIdFromToken(token); storage - .save(`${config.data.searchStoragePrefix}${userId}`, phrase, { + .save(`${config.data.searchStoragePrefix}${user.id}`, phrase, { type: TypeData.Collections, limit: config.data.searchDefaultLength, score: Date.now(), }) .then((value) => { if (!value) { - const errorMessage = `Error saving search query for user ${userId}`; + const errorMessage = `Error saving search query for user ${user.id}`; logger.error(errorMessage); Sentry.captureMessage(errorMessage); } diff --git a/server/src/websocket/handlers/activationKgStatus.ts b/server/src/websocket/handlers/activationKgStatus.ts index d2f97ce554..0344419c7c 100644 --- a/server/src/websocket/handlers/activationKgStatus.ts +++ b/server/src/websocket/handlers/activationKgStatus.ts @@ -16,11 +16,12 @@ * limitations under the License. */ -import { Channel } from "../index"; -import { WsMessage } from "../WsMessages"; import APIClient from "../../api-client"; -import { AsyncSemaphore } from "../../utils/asyncSemaphore"; import logger from "../../logger"; +import { AsyncSemaphore } from "../../utils/asyncSemaphore"; +import { WsMessage } from "../WsMessages"; + +import type { Channel, WebSocketHandlerArgs } from "./handlers.types"; type ActivationMetadata = Record; @@ -47,7 +48,7 @@ function getActivationStatus( id: number, channel: Channel, apiClient: APIClient, - authHeaders: Headers + authHeaders: HeadersInit ) { return apiClient .kgActivationStatus(id, authHeaders) @@ -116,7 +117,7 @@ async function getAllActivationStatus( projectIds: ActivationMetadata, channel: Channel, apiClient: APIClient, - authHeaders: Headers + authHeaders: HeadersInit ): Promise { const semaphore = new AsyncSemaphore(5); const ids = Object.keys(projectIds); @@ -146,11 +147,11 @@ async function handlerRequestActivationKgStatus( } } -async function heartbeatRequestActivationKgStatus( - channel: Channel, - apiClient: APIClient, - authHeaders: Headers -): Promise { +async function heartbeatRequestActivationKgStatus({ + channel, + apiClient, + headers, +}: WebSocketHandlerArgs): Promise { const projectsIds = channel.data.get("projectsIds") as ActivationMetadata; if (projectsIds) { const previousStatuses = channel.data.get( @@ -161,7 +162,7 @@ async function heartbeatRequestActivationKgStatus( projectsIds, channel ); - getAllActivationStatus(ids, channel, apiClient, authHeaders); + getAllActivationStatus(ids, channel, apiClient, headers); } } diff --git a/server/src/websocket/handlers/clientVersion.ts b/server/src/websocket/handlers/clientVersion.ts index fb7bf697ab..ace3a9e964 100644 --- a/server/src/websocket/handlers/clientVersion.ts +++ b/server/src/websocket/handlers/clientVersion.ts @@ -22,6 +22,8 @@ import config from "../../config"; import { Channel } from "../index"; import { WsMessage } from "../WsMessages"; +import type { WebSocketHandlerArgs } from "./handlers.types"; + function handlerRequestServerVersion( data: Record, channel: Channel, @@ -59,7 +61,9 @@ function handlerRequestServerVersion( } } -function heartbeatRequestServerVersion(channel: Channel): void { +function heartbeatRequestServerVersion({ + channel, +}: WebSocketHandlerArgs): void { if (channel.data.get("requestServerVersion")) { const currentSha = process.env.RENKU_UI_SHORT_SHA ? process.env.RENKU_UI_SHORT_SHA diff --git a/server/src/websocket/handlers/handlers.types.ts b/server/src/websocket/handlers/handlers.types.ts new file mode 100644 index 0000000000..635ce5f661 --- /dev/null +++ b/server/src/websocket/handlers/handlers.types.ts @@ -0,0 +1,36 @@ +/*! + * Copyright 2024 - Swiss Data Science Center (SDSC) + * A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and + * Eidgenössische Technische Hochschule Zürich (ETHZ). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License + */ + +import type { WebSocket } from "ws"; + +import APIClient from "../../api-client"; + +export type WebSocketHandler = ( + args: WebSocketHandlerArgs +) => void | Promise; + +export interface WebSocketHandlerArgs { + channel: Channel; + apiClient: APIClient; + headers: HeadersInit; +} + +export interface Channel { + sockets: Array; + data: Map; +} diff --git a/server/src/websocket/handlers/sessions.ts b/server/src/websocket/handlers/sessions.ts index 73166e291b..87bd088844 100644 --- a/server/src/websocket/handlers/sessions.ts +++ b/server/src/websocket/handlers/sessions.ts @@ -16,13 +16,14 @@ * limitations under the License. */ -import logger from "../../logger"; -import { Channel } from "../index"; -import APIClient from "../../api-client"; import * as util from "util"; -import { WsMessage } from "../WsMessages"; + +import logger from "../../logger"; import { simpleHash, sortObjectProperties } from "../../utils"; +import { WsMessage } from "../WsMessages"; +import type { Channel, WebSocketHandlerArgs } from "./handlers.types"; + interface SessionsResult { servers: Record; } @@ -45,7 +46,7 @@ interface Session { } function handlerRequestSessionStatus( - data: Record, + _data: Record, channel: Channel ): void { channel.data.set("sessionStatus", null); @@ -56,14 +57,14 @@ function sendMessage(data: string, channel: Channel) { channel.sockets.forEach((socket) => socket.send(info.toString())); } -function heartbeatRequestSessionStatus( - channel: Channel, - apiClient: APIClient, - authHeathers: Record -): void { +function heartbeatRequestSessionStatus({ + channel, + apiClient, + headers, +}: WebSocketHandlerArgs): void { const previousStatuses = channel.data.get("sessionStatus") as string; apiClient - .getSessionStatus(authHeathers) + .getSessionStatus(headers) .then((response) => { const statusFetched = response as unknown as SessionsResult; const servers = statusFetched?.servers ?? {}; diff --git a/server/src/websocket/index.ts b/server/src/websocket/index.ts index adc8e0b72c..fc117d64e8 100644 --- a/server/src/websocket/index.ts +++ b/server/src/websocket/index.ts @@ -16,17 +16,16 @@ * limitations under the License. */ -import ws from "ws"; import * as SentryLib from "@sentry/node"; +import ws from "ws"; import APIClient from "../api-client"; -import { Authenticator } from "../authentication"; -import { wsRenkuAuth } from "../authentication/middleware"; import config from "../config"; import logger from "../logger"; import { Storage } from "../storage"; import { getCookieValueByName } from "../utils"; import { errorHandler } from "../utils/errorHandler"; + import { WsClientMessage, WsMessage, checkWsClientMessage } from "./WsMessages"; import { handlerRequestActivationKgStatus, @@ -40,15 +39,11 @@ import { handlerRequestSessionStatus, heartbeatRequestSessionStatus, } from "./handlers/sessions"; +import type { Channel, WebSocketHandler } from "./handlers/handlers.types"; // *** Channels *** // No need to store data in Redis since it's used only locally. We can modify this if necessary. -interface Channel { - sockets: Array; - data: Map; -} - const channels = new Map(); // *** Accepted messages *** @@ -98,10 +93,10 @@ const acceptedMessages: Record> = { // *** Heartbeats functions *** -// eslint-disable-next-line @typescript-eslint/ban-types -const longLoopFunctions: Array = [heartbeatRequestServerVersion]; -// eslint-disable-next-line @typescript-eslint/ban-types -const shortLoopFunctions: Array = [ +const longLoopFunctions: Array = [ + heartbeatRequestServerVersion, +]; +const shortLoopFunctions: Array = [ heartbeatRequestSessionStatus, heartbeatRequestActivationKgStatus, ]; @@ -110,13 +105,11 @@ const shortLoopFunctions: Array = [ * Long loop for each user -- executed every few minutes. * It automatically either reschedules when at least one channel is active, or close. * @param sessionId - user session ID - * @param authenticator - auth component * @param storage - storage component * @param apiClient - api to fetch data */ async function channelLongLoop( sessionId: string, - authenticator: Authenticator, storage: Storage, apiClient: APIClient ) { @@ -130,35 +123,25 @@ async function channelLongLoop( } // checking authentication - const timeoutLength = (config.websocket.longIntervalSec as number) * 1000; - if (!authenticator.ready) { + const timeoutLength = (config.websocket.longIntervalSec as number) * 1_000; + if (!storage.ready) { logger.info( - `${infoPrefix} Authenticator not ready yet, skipping to the next loop` + `${infoPrefix} Storage not ready yet, skipping to the next loop` ); setTimeout( - () => channelLongLoop(sessionId, authenticator, storage, apiClient), + () => channelLongLoop(sessionId, storage, apiClient), timeoutLength ); return false; } // get the auth headers - const authHeaders = await getAuthHeaders( - authenticator, - sessionId, - infoPrefix - ); - if (authHeaders instanceof WsMessage && authHeaders.data.expired) { - // ? here authHeaders is an error message - channel.sockets.forEach((socket) => socket.send(authHeaders.toString())); - channels.delete(sessionId); - return false; - } + const headers = { Cookie: `${config.auth.cookiesKey}=${sessionId}` }; for (const longLoopFunction of longLoopFunctions) { // execute the loop function try { - longLoopFunction(channel, apiClient, authHeaders); + longLoopFunction({ channel, apiClient, headers }); } catch (error) { const info = `Unexpected error while executing the function '${longLoopFunction.name}'.`; logger.error(`${infoPrefix} ${info}`); @@ -171,7 +154,7 @@ async function channelLongLoop( // Ping to keep the socket alive, then reschedule loop channel.sockets.forEach((socket) => socket.ping()); setTimeout( - () => channelLongLoop(sessionId, authenticator, storage, apiClient), + () => channelLongLoop(sessionId, storage, apiClient), timeoutLength ); } @@ -180,13 +163,11 @@ async function channelLongLoop( * Short loop for each user -- executed every few seconds. * It automatically either reschedules when at least one channel is active, or close. * @param sessionId - user session ID - * @param authenticator - auth component * @param storage - storage component * @param apiClient - api client */ async function channelShortLoop( sessionId: string, - authenticator: Authenticator, storage: Storage, apiClient: APIClient ) { @@ -200,35 +181,25 @@ async function channelShortLoop( } // checking authentication - const timeoutLength = (config.websocket.shortIntervalSec as number) * 1000; - if (!authenticator.ready) { + const timeoutLength = (config.websocket.shortIntervalSec as number) * 1_000; + if (!storage.ready) { logger.info( - `${infoPrefix} Authenticator not ready yet, skipping to the next loop` + `${infoPrefix} Storage not ready yet, skipping to the next loop` ); setTimeout( - () => channelShortLoop(sessionId, authenticator, storage, apiClient), + () => channelShortLoop(sessionId, storage, apiClient), timeoutLength ); return; } // get the auth headers - const authHeaders = await getAuthHeaders( - authenticator, - sessionId, - infoPrefix - ); - if (authHeaders instanceof WsMessage && authHeaders.data.expired) { - // ? here authHeaders is an error message - channel.sockets.forEach((socket) => socket.send(authHeaders.toString())); - channels.delete(sessionId); - return false; - } + const headers = { Cookie: `${config.auth.cookiesKey}=${sessionId}` }; for (const shortLoopFunction of shortLoopFunctions) { // execute the loop function try { - shortLoopFunction(channel, apiClient, authHeaders); + shortLoopFunction({ channel, apiClient, headers }); } catch (error) { const info = `Unexpected error while executing the function '${shortLoopFunction.name}'.`; logger.error(`${infoPrefix} ${info}`); @@ -241,7 +212,7 @@ async function channelShortLoop( // Ping to keep the socket alive, then reschedule loop channel.sockets.forEach((socket) => socket.ping()); setTimeout( - () => channelShortLoop(sessionId, authenticator, storage, apiClient), + () => channelShortLoop(sessionId, storage, apiClient), timeoutLength ); } @@ -253,13 +224,11 @@ async function channelShortLoop( /** * Configure WebSocket by setting up events and starting loops. * @param server - main wss server - * @param authenticator - auth component * @param storage - storage component * @param apiClient - api client */ function configureWebsocket( server: ws.Server, - authenticator: Authenticator, storage: Storage, apiClient: APIClient ): void { @@ -280,13 +249,13 @@ function configureWebsocket( SentryLib.setContext("WebSocket Initial Request", requestData); } - // get the user id + // get the session id const sessionId = getCookieValueByName( request.headers.cookie, config.auth.cookiesKey ); if (!sessionId) { - logger.error("No ID for the user, session won't be saved."); + logger.error("No session ID, session won't be saved."); const info = "The request does not contain a valid session ID." + " Are you reaching the WebSocket from an external source?"; @@ -306,7 +275,7 @@ function configureWebsocket( const channel = channels.get(sessionId); if (channel) { logger.debug( - `Adding a new socket to the channel for user ${sessionId}. Total of ${ + `Adding a new socket to the channel for the session ${sessionId}. Total of ${ channel.sockets.length + 1 }` ); @@ -315,17 +284,17 @@ function configureWebsocket( sockets: [...channel.sockets, socket], }); } else { - logger.debug(`Creating new channel for user ${sessionId}`); + logger.debug(`Creating new channel for the session ${sessionId}`); channels.set(sessionId, { sockets: [socket], data: new Map() }); // add a buffer before starting the loop, so we can receive setup messages setTimeout(() => { - channelShortLoop(sessionId, authenticator, storage, apiClient); + channelShortLoop(sessionId, storage, apiClient); // add a tiny buffer, in case authentication fails and channel is cleaned up -- no need to overlap setTimeout(() => { - channelLongLoop(sessionId, authenticator, storage, apiClient); - }, 1000); - }, config.websocket.delayStartSec * 1000); + channelLongLoop(sessionId, storage, apiClient); + }, 1_000); + }, config.websocket.delayStartSec * 1_000); } // event: close the socket @@ -333,14 +302,16 @@ function configureWebsocket( // (code, reason) might be used here // Verify session if (!sessionId) { - logger.debug("Nothing to cleanup for a user without ID."); + logger.debug("Nothing to cleanup when there is no session ID."); return false; } // Identify channel const channel = channels.get(sessionId); if (!channel) { - logger.warn(`No channel for user ${sessionId}. That is unexpected...`); + logger.warn( + `No channel for the session ${sessionId}. That is unexpected...` + ); return false; } @@ -349,8 +320,8 @@ function configureWebsocket( const remainingSockets = channel.sockets.length - 1; const remainingText = remainingSockets === 0 - ? `There are no channels left for the user ` - : `THere are other ${remainingSockets} socket(s) for the user `; + ? `There are no channels left for the sesssion ` + : `There are other ${remainingSockets} socket(s) for the session `; logger.debug(`Removing the channel. ${remainingText} ${sessionId}.`); const index = channel.sockets.indexOf(socket); if (index >= 0) @@ -361,7 +332,7 @@ function configureWebsocket( else logger.error("Socket not found."); } else { logger.info( - `Last socket for user ${sessionId}. Deleting the channel...` + `Last socket for the session ${sessionId}. Deleting the channel...` ); channels.delete(sessionId); } @@ -409,10 +380,6 @@ function configureWebsocket( } }); - // check auth - const head = await getAuthHeaders(authenticator, sessionId); - if (head instanceof WsMessage && head.data?.expired) - socket.send(head.toString()); socket.send( new WsMessage("Connection established.", "user", "init").toString() ); @@ -467,55 +434,4 @@ function getWsClientMessageHandler( return `Could not find a proper handler; data is wrong for a '${clientMessage.type}' instruction.`; } -/** - * Get auhtentication headers - * @param authenticator - auth component - * @param sessionId - user session ID - * @param infoPrefix - this is for the logger - * @returns error with WsMessage or headers - */ -async function getAuthHeaders( - authenticator: Authenticator, - sessionId: string, - infoPrefix = "" -): Promise> { - try { - const authHeaders = await wsRenkuAuth(authenticator, sessionId); - if (!authHeaders) - // user is anonymous - return null; - return authHeaders; - } catch (error) { - const data = { message: "authentication not valid" }; - let expiredMessage: WsMessage; - if (error.message.toString().includes("expired")) { - // Try to refresh tokens automatically - try { - logger.debug(`${infoPrefix} try to refresh tokens.`); - await authenticator.refreshTokens(sessionId); - const authHeaders = await wsRenkuAuth(authenticator, sessionId); - if (!authHeaders) - throw new Error("Cannot find auth headers after refreshing"); - logger.debug(`${infoPrefix} tokens refreshed.`); - return authHeaders; - } catch (internalError) { - logger.warn(`${infoPrefix} auth expired.`); - expiredMessage = new WsMessage( - { ...data, expired: true }, - "user", - "authentication" - ); - } - } else { - logger.warn(`${infoPrefix} auth invalid.`); - expiredMessage = new WsMessage( - { ...data, invalid: true }, - "user", - "authentication" - ); - } - return expiredMessage; - } -} - export { Channel, MessageData, configureWebsocket, getWsClientMessageHandler }; diff --git a/server/tests/storage/index.test.ts b/server/tests/storage/index.test.ts.old similarity index 100% rename from server/tests/storage/index.test.ts rename to server/tests/storage/index.test.ts.old