diff --git a/bun.lockb b/bun.lockb index c5e23c0..3f8c38c 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/package.json b/package.json index 1ef3e4f..5307482 100644 --- a/package.json +++ b/package.json @@ -3,16 +3,8 @@ "version": "2.4.0", "description": "A strategy to use and implement OAuth2 framework for authentication with federated services like Google, Facebook, GitHub, etc.", "license": "MIT", - "funding": [ - "https://github.com/sponsors/sergiodxa" - ], - "keywords": [ - "remix", - "remix-auth", - "auth", - "authentication", - "strategy" - ], + "funding": ["https://github.com/sponsors/sergiodxa"], + "keywords": ["remix", "remix-auth", "auth", "authentication", "strategy"], "author": { "name": "Sergio Xalambrí", "email": "hello+oss@sergiodxa.com", @@ -36,53 +28,34 @@ "sideEffects": false, "type": "module", "engines": { - "node": "^18.0.0 || ^20.0.0 || >=20.0.0" + "node": "^20.0.0 || >=20.0.0" }, - "files": [ - "build", - "package.json", - "README.md" - ], + "files": ["build", "package.json", "README.md"], "exports": { ".": "./build/index.js", "./package.json": "./package.json" }, "dependencies": { - "@oslojs/crypto": "^0.6.2", + "@mjackson/headers": "^0.8.0", + "@oslojs/crypto": "^1.0.1", "@oslojs/encoding": "^1.0.0", "@oslojs/oauth2": "^0.5.0", - "debug": "^4.3.4" + "arctic": "^2.3.0", + "debug": "^4.3.7" }, "peerDependencies": { - "@remix-run/cloudflare": "^1.0.0 || ^2.0.0", - "@remix-run/node": "^1.0.0 || ^2.0.0", - "@remix-run/deno": "^1.0.0 || ^2.0.0", - "remix-auth": "^3.6.0" - }, - "peerDependenciesMeta": { - "@remix-run/cloudflare": { - "optional": true - }, - "@remix-run/node": { - "optional": true - }, - "@remix-run/deno": { - "optional": true - } + "remix-auth": "^4.0.0" }, "devDependencies": { "@arethetypeswrong/cli": "^0.17.0", - "@biomejs/biome": "^1.8.3", - "@remix-run/node": "^2.8.1", - "@remix-run/server-runtime": "^2.8.1", + "@biomejs/biome": "^1.9.4", "@total-typescript/tsconfig": "^1.0.4", - "@types/bun": "^1.0.12", + "@types/bun": "^1.1.13", "@types/debug": "^4.1.12", - "consola": "^3.2.3", - "msw": "^2.2.13", - "remix-auth": "^3.6.0", - "typedoc": "^0.26.2", - "typedoc-plugin-mdn-links": "^3.1.25", - "typescript": "^5.4.5" + "msw": "^2.6.6", + "remix-auth": "^4.0.0", + "typedoc": "^0.26.11", + "typedoc-plugin-mdn-links": "^4.0.1", + "typescript": "^5.7.2" } } diff --git a/src/index.test.ts b/src/index.test.ts index 3e628d1..559f6e5 100644 --- a/src/index.test.ts +++ b/src/index.test.ts @@ -7,43 +7,26 @@ import { mock, test, } from "bun:test"; -import { createCookieSessionStorage, redirect } from "@remix-run/node"; -import { AuthenticateOptions, AuthorizationError } from "remix-auth"; -import { - OAuth2Error, - OAuth2Profile, - OAuth2Strategy, - OAuth2StrategyOptions, - OAuth2StrategyVerifyParams, -} from "."; +import { Cookie, SetCookie } from "@mjackson/headers"; +import { http, HttpResponse } from "msw"; +import { setupServer } from "msw/native"; +import { OAuth2Strategy } from "."; import { catchResponse } from "./test/helpers"; -import { server } from "./test/mock"; - -beforeAll(() => { - server.listen(); -}); - -afterEach(() => { - server.resetHandlers(); -}); - -afterAll(() => { - server.close(); -}); - -const BASE_OPTIONS: AuthenticateOptions = { - name: "form", - sessionKey: "user", - sessionErrorKey: "error", - sessionStrategyKey: "strategy", -}; +const server = setupServer( + http.post("https://example.app/token", async () => { + return HttpResponse.json({ + access_token: "mocked", + expires_in: 3600, + refresh_token: "mocked", + scope: ["user:email", "user:profile"].join(" "), + token_type: "Bearer", + }); + }), +); describe(OAuth2Strategy.name, () => { let verify = mock(); - let sessionStorage = createCookieSessionStorage({ - cookie: { secrets: ["s3cr3t"] }, - }); let options = Object.freeze({ authorizationEndpoint: "https://example.app/authorize", @@ -52,39 +35,41 @@ describe(OAuth2Strategy.name, () => { clientSecret: "MY_CLIENT_SECRET", redirectURI: "https://example.com/callback", scopes: ["user:email", "user:profile"], - codeChallengeMethod: "plain", - } satisfies OAuth2StrategyOptions); + } satisfies OAuth2Strategy.ConstructorOptions); interface User { id: string; } - interface TestProfile extends OAuth2Profile { - provider: "oauth2"; - } + beforeAll(() => { + server.listen(); + }); + + afterEach(() => { + server.resetHandlers(); + }); + + afterAll(() => { + server.close(); + }); test("should have the name `oauth2`", () => { - let strategy = new OAuth2Strategy(options, verify); + let strategy = new OAuth2Strategy(options, verify); expect(strategy.name).toBe("oauth2"); }); test("redirects to authorization url if there's no state", async () => { - let strategy = new OAuth2Strategy(options, verify); + let strategy = new OAuth2Strategy(options, verify); let request = new Request("https://remix.auth/login"); - let response = await catchResponse( - strategy.authenticate(request, sessionStorage, BASE_OPTIONS), - ); + let response = await catchResponse(strategy.authenticate(request)); // biome-ignore lint/style/noNonNullAssertion: This is a test let redirect = new URL(response.headers.get("location")!); - let session = await sessionStorage.getSession( - response.headers.get("set-cookie"), - ); - - expect(response.status).toBe(302); + let setCookie = new SetCookie(response.headers.get("set-cookie") ?? ""); + let params = new URLSearchParams(setCookie.value); expect(redirect.pathname).toBe("/authorize"); expect(redirect.searchParams.get("response_type")).toBe("code"); @@ -93,281 +78,88 @@ describe(OAuth2Strategy.name, () => { expect(redirect.searchParams.has("state")).toBeTruthy(); expect(redirect.searchParams.get("scope")).toBe(options.scopes.join(" ")); - expect(session.get("oauth2:state")).toBe( - redirect.searchParams.get("state"), - ); + expect(params.get("state")).toBe(redirect.searchParams.get("state")); - expect(session.get("oauth2:codeVerifier")).toBe( - redirect.searchParams.get("code_challenge"), - ); + // expect(params.get("codeVerifier")).toBe( + // redirect.searchParams.get("code_challenge"), + // ); - expect(redirect.searchParams.get("code_challenge_method")).toBe("plain"); + expect(redirect.searchParams.get("code_challenge_method")).toBe("S256"); }); test("throws if there's no state in the session", async () => { - let strategy = new OAuth2Strategy(options, verify); + let strategy = new OAuth2Strategy(options, verify); let request = new Request( "https://example.com/callback?state=random-state&code=random-code", ); - let response = await catchResponse( - strategy.authenticate(request, sessionStorage, BASE_OPTIONS), + expect(strategy.authenticate(request)).rejects.toThrowError( + new ReferenceError("Missing state on cookie."), ); - - expect(response.status).toBe(401); - await expect(response.json()).resolves.toEqual({ - message: "Missing state on session.", - }); }); test("throws if the state in the url doesn't match the state in the session", async () => { - let strategy = new OAuth2Strategy(options, verify); + let strategy = new OAuth2Strategy(options, verify); - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); + let cookie = new Cookie(); + cookie.set( + "oauth2", + new URLSearchParams({ state: "random-state" }).toString(), + ); let request = new Request( "https://example.com/callback?state=another-state&code=random-code", - { headers: { cookie: await sessionStorage.commitSession(session) } }, + { headers: { Cookie: cookie.toString() } }, ); - let response = await catchResponse( - strategy.authenticate(request, sessionStorage, BASE_OPTIONS), + expect(strategy.authenticate(request)).rejects.toThrowError( + new ReferenceError("State in URL doesn't match state in cookie."), ); - - expect(response.status).toBe(401); - - let data = await response.json(); - - expect(data).toEqual({ - message: "State in URL doesn't match state in session.", - }); }); - test("calls verify with the tokens, user profile, context and request", async () => { - let strategy = new OAuth2Strategy(options, verify); - - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); + test("calls verify with the tokens and request", async () => { + let strategy = new OAuth2Strategy(options, verify); - let request = new Request( - "https://example.com/callback?state=random-state&code=random-code", - { - headers: { cookie: await sessionStorage.commitSession(session) }, - }, + let cookie = new Cookie(); + cookie.set( + "oauth2", + new URLSearchParams({ + state: "random-state", + codeVerifier: "random-code-verifier", + }).toString(), ); - let context = { test: "it works" }; - await strategy - .authenticate(request, sessionStorage, { - ...BASE_OPTIONS, - context, - }) - .catch((error) => error); - - expect(verify).toHaveBeenLastCalledWith({ - tokens: { - access_token: "mocked", - expires_in: 3600, - refresh_token: "mocked", - scope: "user:email user:profile", - token_type: "Bearer", - }, - profile: { provider: "oauth2" }, - context, - request, - } satisfies OAuth2StrategyVerifyParams< - OAuth2Profile, - Record - >); - }); - - test("returns the result of verify", async () => { - let user = { id: "123" }; - verify.mockResolvedValueOnce(user); - - let strategy = new OAuth2Strategy(options, verify); - - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); - let request = new Request( "https://example.com/callback?state=random-state&code=random-code", - { headers: { cookie: await sessionStorage.commitSession(session) } }, + { headers: { cookie: cookie.toString() } }, ); - let response = await strategy.authenticate( - request, - sessionStorage, - BASE_OPTIONS, - ); + await strategy.authenticate(request); - expect(response).toEqual(user); + expect(verify).toHaveBeenCalled(); }); - test("throws a response with user in session and redirect to /", async () => { + test("returns the result of verify", () => { let user = { id: "123" }; verify.mockResolvedValueOnce(user); - let strategy = new OAuth2Strategy(options, verify); - - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); - - let request = new Request( - "https://example.com/callback?state=random-state&code=random-code", - { - headers: { cookie: await sessionStorage.commitSession(session) }, - }, - ); - - let response = await catchResponse( - strategy.authenticate(request, sessionStorage, { - ...BASE_OPTIONS, - successRedirect: "/", - }), - ); + let strategy = new OAuth2Strategy(options, verify); - session = await sessionStorage.getSession( - response.headers.get("Set-Cookie"), + let cookie = new Cookie(); + cookie.set( + "oauth2", + new URLSearchParams({ + state: "random-state", + codeVerifier: "random-code-verifier", + }).toString(), ); - expect(response.headers.get("Location")).toBe("/"); - expect(session.get("user")).toEqual(user); - }); - - test("pass error as cause on failure", async () => { - verify.mockRejectedValueOnce(new TypeError("Invalid credentials")); - - let strategy = new OAuth2Strategy(options, verify); - - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); - let request = new Request( "https://example.com/callback?state=random-state&code=random-code", - { - headers: { cookie: await sessionStorage.commitSession(session) }, - }, - ); - - let result = await strategy - .authenticate(request, sessionStorage, { - ...BASE_OPTIONS, - throwOnError: true, - }) - .catch((error) => error); - - expect(result).toEqual(new AuthorizationError("Invalid credentials")); - expect((result as AuthorizationError).cause).toEqual( - new TypeError("Invalid credentials"), + { headers: { cookie: cookie.toString() } }, ); - }); - - test("pass generate error from string on failure", async () => { - verify.mockRejectedValueOnce("Invalid credentials"); - - let strategy = new OAuth2Strategy(options, verify); - - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); - - let request = new Request( - "https://example.com/callback?state=random-state&code=random-code", - { - headers: { cookie: await sessionStorage.commitSession(session) }, - }, - ); - - let result = await strategy - .authenticate(request, sessionStorage, { - ...BASE_OPTIONS, - throwOnError: true, - }) - .catch((error) => error); - - expect(result).toEqual(new AuthorizationError("Invalid credentials")); - expect((result as AuthorizationError).cause).toEqual( - new Error("Invalid credentials"), - ); - }); - - test("creates Unknown error if thrown value is not Error or string", async () => { - verify.mockRejectedValueOnce({ message: "Invalid email address" }); - - let strategy = new OAuth2Strategy(options, verify); - - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); - let request = new Request( - "https://example.com/callback?state=random-state&code=random-code", - { - headers: { cookie: await sessionStorage.commitSession(session) }, - }, - ); - - let result = await strategy - .authenticate(request, sessionStorage, { - ...BASE_OPTIONS, - throwOnError: true, - }) - .catch((error) => error); - - expect(result).toEqual(new AuthorizationError("Unknown error")); - expect((result as AuthorizationError).cause).toEqual( - new Error(JSON.stringify({ message: "Invalid email address" }, null, 2)), - ); - }); - - test("thrown response in verify callback should pass-through", async () => { - verify.mockRejectedValueOnce(redirect("/test")); - - let strategy = new OAuth2Strategy(options, verify); - - let session = await sessionStorage.getSession(); - session.set("oauth2:state", "random-state"); - - let request = new Request( - "https://example.com/callback?state=random-state&code=random-code", - { headers: { cookie: await sessionStorage.commitSession(session) } }, - ); - - let response = await strategy - .authenticate(request, sessionStorage, BASE_OPTIONS) - .then(() => { - throw new Error("Should have failed."); - }) - .catch((error: unknown) => { - if (error instanceof Response) return error; - throw error; - }); - - expect(response.status).toEqual(302); - expect(response.headers.get("location")).toEqual("/test"); - }); - - test("throws if there's an error in the url", async () => { - let strategy = new OAuth2Strategy(options, verify); - - let request = new Request( - "https://example.com/callback?error=invalid_request", - ); - - expect(() => - strategy.authenticate(request, sessionStorage, { - ...BASE_OPTIONS, - throwOnError: true, - }), - ).toThrowError( - // @ts-expect-error - This is a test - new AuthorizationError("Error on authentication", { - cause: new OAuth2Error(request, { - error: "invalid_request", - error_description: undefined, - }), - }), - ); + expect(strategy.authenticate(request)).resolves.toEqual(user); }); }); diff --git a/src/index.ts b/src/index.ts index f593521..3b3e1cd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,333 +1,155 @@ import { - type AppLoadContext, - type SessionStorage, - redirect, -} from "@remix-run/server-runtime"; -import createDebug from "debug"; + Cookie, + type CookieInit, + SetCookie, + type SetCookieInit, +} from "@mjackson/headers"; import { - type AuthenticateOptions, - Strategy, - type StrategyVerifyCallback, -} from "remix-auth"; -import { AuthorizationCode } from "./lib/authorization-code.js"; -import { Generator } from "./lib/generator.js"; -import { OAuth2Request } from "./lib/request.js"; -import { Token } from "./lib/token.js"; + CodeChallengeMethod, + OAuth2Client, + OAuth2RequestError, + type OAuth2Tokens, + generateCodeVerifier, + generateState, +} from "arctic"; +import createDebug from "debug"; +import { Strategy } from "remix-auth/strategy"; +import { redirect } from "./lib/redirect.js"; let debug = createDebug("OAuth2Strategy"); -export interface OAuth2Profile { - provider: string; - id?: string; - displayName?: string; - name?: { - familyName?: string; - givenName?: string; - middleName?: string; - }; - emails?: Array<{ - value: string; - type?: string; - }>; - photos?: Array<{ value: string }>; -} - type URLConstructor = ConstructorParameters[0]; -export interface OAuth2StrategyOptions { - /** - * This is the Client ID of your application, provided to you by the Identity - * Provider you're using to authenticate users. - */ - clientId: string; - /** - * This is the Client Secret of your application, provided to you by the - * Identity Provider you're using to authenticate users. - */ - clientSecret: string; - - /** - * The endpoint the Identity Provider asks you to send users to log in, or - * authorize your application. - */ - authorizationEndpoint: URLConstructor; - /** - * The endpoint the Identity Provider uses to let's you exchange an access - * code for an access and refresh token. - */ - tokenEndpoint: URLConstructor; - /** - * The URL of your application where the Identity Provider will redirect the - * user after they've logged in or authorized your application. - */ - redirectURI: URLConstructor; - - /** - * The endpoint the Identity Provider uses to revoke an access or refresh - * token, this can be useful to log out the user. - */ - tokenRevocationEndpoint?: URLConstructor; - - /** - * The scopes you want to request from the Identity Provider, this is a list - * of strings that represent the permissions you want to request from the - * user. - */ - scopes?: string[]; - - /** - * The code challenge method to use when sending the authorization request. - * This is used when the Identity Provider requires a code challenge to be - * sent with the authorization request. - * @default "S256" - */ - codeChallengeMethod?: "S256" | "plain"; - - /** - * The method to use to authenticate with the Identity Provider, this can be - * either `http_basic_auth` or `request_body`. - * @default "request_body" - */ - authenticateWith?: "http_basic_auth" | "request_body"; -} - -export interface OAuth2StrategyVerifyParams< - Profile extends OAuth2Profile, - ExtraTokenParams extends Record = Record, -> { - tokens: Token.Response.Body & ExtraTokenParams; - profile: Profile; - request: Request; - context?: AppLoadContext; -} - -export class OAuth2Strategy< +export class OAuth2Strategy extends Strategy< User, - Profile extends OAuth2Profile, - ExtraParams extends Record = Record, -> extends Strategy> { - name = "oauth2"; + OAuth2Strategy.VerifyOptions +> { + override name = "oauth2"; - protected sessionStateKey = "oauth2:state"; - protected sessionCodeVerifierKey = "oauth2:codeVerifier"; - protected options: OAuth2StrategyOptions; + protected client: OAuth2Client; constructor( - options: OAuth2StrategyOptions, - verify: StrategyVerifyCallback< - User, - OAuth2StrategyVerifyParams - >, + protected options: OAuth2Strategy.ConstructorOptions, + verify: Strategy.VerifyFunction, ) { super(verify); - this.options = { - codeChallengeMethod: "S256", - authenticateWith: "request_body", - ...options, - }; + + this.client = new OAuth2Client( + options.clientId, + options.clientSecret, + options.redirectURI.toString(), + ); } - async authenticate( - request: Request, - sessionStorage: SessionStorage, - options: AuthenticateOptions, - ): Promise { - debug("Request URL", request.url); + private get cookieName() { + if (typeof this.options.cookie === "string") { + return this.options.cookie || "oauth2"; + } + return this.options.cookie?.name ?? "oauth2"; + } - let url = new URL(request.url); + private get cookieOptions() { + if (typeof this.options.cookie !== "object") return {}; + return this.options.cookie ?? {}; + } - if (url.searchParams.has("error")) { - return this.failure( - "Error on authentication", - request, - sessionStorage, - options, - new OAuth2Error(request, { - error: url.searchParams.get("error") ?? undefined, - error_description: - url.searchParams.get("error_description") ?? undefined, - error_uri: url.searchParams.get("error_uri") ?? undefined, - }), - ); - } + override async authenticate(request: Request): Promise { + debug("Request URL", request.url); - let session = await sessionStorage.getSession( - request.headers.get("Cookie"), - ); + let url = new URL(request.url); let stateUrl = url.searchParams.get("state"); + let error = url.searchParams.get("error"); + + if (error) { + let description = url.searchParams.get("error_description"); + let uri = url.searchParams.get("error_uri"); + throw new OAuth2RequestError(error, description, uri, stateUrl); + } if (!stateUrl) { debug("No state found in the URL, redirecting to authorization endpoint"); - let state = Generator.state(); - session.set(this.sessionStateKey, state); + let { state, codeVerifier, url } = this.createAuthorizationURL(); debug("State", state); - - let codeVerifier = Generator.codeVerifier(); - session.set(this.sessionCodeVerifierKey, codeVerifier); - debug("Code verifier", codeVerifier); - let authorizationURL = new AuthorizationCode.AuthorizationURL( - this.options.authorizationEndpoint.toString(), - this.options.clientId, - ); - - authorizationURL.setRedirectURI(this.options.redirectURI.toString()); - authorizationURL.setState(state); - - if (this.options.scopes) - authorizationURL.addScopes(...this.options.scopes); - - if (this.options.codeChallengeMethod === "S256") { - authorizationURL.setS256CodeChallenge(codeVerifier); - } else if (this.options.codeChallengeMethod === "plain") { - authorizationURL.setPlainCodeChallenge(codeVerifier); - } - - // Extend authorization URL with extra non-standard params - authorizationURL.search = this.authorizationParams( - authorizationURL.searchParams, + url.search = this.authorizationParams( + url.searchParams, request, ).toString(); - debug("Authorization URL", authorizationURL.toString()); + debug("Authorization URL", url.toString()); + + let header = new SetCookie({ + name: this.cookieName, + value: new URLSearchParams({ state, codeVerifier }).toString(), + httpOnly: true, // Prevents JavaScript from accessing the cookie + maxAge: 60 * 5, // 5 minutes + path: "/", // Allow the cookie to be sent to any path + sameSite: "Lax", // Prevents it from being sent in cross-site requests + ...this.cookieOptions, + }); - throw redirect(authorizationURL.toString(), { - headers: { - "Set-Cookie": await sessionStorage.commitSession(session), - }, + throw redirect(url.toString(), { + headers: { "Set-Cookie": header.toString() }, }); } let code = url.searchParams.get("code"); - let codeVerifier = session.get(this.sessionCodeVerifierKey); - if (!code && url.searchParams.has("error")) { - return this.failure( - "Error during authentication", - request, - sessionStorage, - options, - new OAuth2Error(request, { - error: url.searchParams.get("error") ?? undefined, - error_description: - url.searchParams.get("error_description") ?? undefined, - error_uri: url.searchParams.get("error_uri") ?? undefined, - }), - ); - } + if (!code) throw new ReferenceError("Missing code in the URL"); - if (!code) { - return this.failure( - "Missing code in the URL", - request, - sessionStorage, - options, - new ReferenceError("Missing code in the URL"), - ); + let cookie = new Cookie(request.headers.get("cookie") ?? ""); + let params = new URLSearchParams(cookie.get(this.cookieName)); + + if (!params.has("state")) { + throw new ReferenceError("Missing state on cookie."); } - let stateSession = session.get(this.sessionStateKey); - debug("State from session", stateSession); - if (!stateSession) { - return await this.failure( - "Missing state on session.", - request, - sessionStorage, - options, - new ReferenceError("Missing state on session."), - ); + if (params.get("state") !== stateUrl) { + throw new RangeError("State in URL doesn't match state in cookie."); } - if (stateSession === stateUrl) { - debug("State is valid"); - session.unset(this.sessionStateKey); - } else { - return await this.failure( - "State in URL doesn't match state in session.", - request, - sessionStorage, - options, - new RangeError("State in URL doesn't match state in session."), - ); + if (!params.has("codeVerifier")) { + throw new ReferenceError("Missing code verifier on cookie."); } - try { - debug("Validating authorization code"); - let context = new Token.Request.Context(code); - - context.setRedirectURI(this.options.redirectURI.toString()); - context.setCodeVerifier(codeVerifier); - - if (this.options.authenticateWith === "http_basic_auth") { - context.authenticateWithHTTPBasicAuth( - this.options.clientId, - this.options.clientSecret, - ); - } else if (this.options.authenticateWith === "request_body") { - context.authenticateWithRequestBody( - this.options.clientId, - this.options.clientSecret, - ); - } - - let tokens = await Token.Request.send( - this.options.tokenEndpoint.toString(), - context, - { signal: request.signal }, - ); - - debug("Fetching the user profile"); - let profile = await this.userProfile(tokens); - - debug("Verifying the user profile"); - let user = await this.verify({ - tokens, - profile, - context: options.context, - request, - }); + debug("Validating authorization code"); + let tokens = await this.validateAuthorizationCode( + code, + params.get("codeVerifier") as string, // We checked above this is defined + ); - debug("User authenticated"); - return this.success(user, request, sessionStorage, options); - } catch (error) { - // Allow responses to pass-through - if (error instanceof Response) throw error; - - debug("Failed to verify user", error); - if (error instanceof Error) { - return await this.failure( - error.message, - request, - sessionStorage, - options, - error, - ); - } - if (typeof error === "string") { - return await this.failure( - error, - request, - sessionStorage, - options, - new Error(error), - ); - } - return await this.failure( - "Unknown error", - request, - sessionStorage, - options, - new Error(JSON.stringify(error, null, 2)), - ); - } + debug("Verifying the user profile"); + let user = await this.verify({ request, tokens }); + + debug("User authenticated"); + return user; + } + + protected createAuthorizationURL() { + let state = generateState(); + let codeVerifier = generateCodeVerifier(); + + let url = this.client.createAuthorizationURLWithPKCE( + this.options.authorizationEndpoint.toString(), + state, + this.options.codeChallengeMethod ?? CodeChallengeMethod.S256, + codeVerifier, + this.options.scopes ?? [], + ); + + return { state, codeVerifier, url }; } - protected async userProfile(tokens: Token.Response.Body): Promise { - return { provider: "oauth2" } as Profile; + protected validateAuthorizationCode(code: string, codeVerifier: string) { + return this.client.validateAuthorizationCode( + this.options.tokenEndpoint.toString(), + code, + codeVerifier, + ); } /** @@ -346,98 +168,89 @@ export class OAuth2Strategy< return new URLSearchParams(params); } - /** - * Get new tokens using a refresh token. - * @param refreshToken The refresh token to use - * @param options Optional options to override the default strategy options - * @returns A promise that resolves to the new tokens - */ - public refreshToken( - refreshToken: string, - options: Partial> & { - signal?: AbortSignal; - } = {}, - ) { - let scopes = options.scopes ?? this.options.scopes ?? []; - - let context = new Token.RefreshRequest.Context(refreshToken); - - context.addScopes(...scopes); - - if (this.options.authenticateWith === "http_basic_auth") { - context.authenticateWithHTTPBasicAuth( - this.options.clientId, - this.options.clientSecret, - ); - } else if (this.options.authenticateWith === "request_body") { - context.authenticateWithRequestBody( - this.options.clientId, - this.options.clientSecret, - ); - } - - return Token.Request.send( + public refreshToken(refreshToken: string) { + return this.client.refreshAccessToken( this.options.tokenEndpoint.toString(), - context, - { signal: options.signal }, + refreshToken, + this.options.scopes ?? [], ); } - public async revokeToken( - token: string, - options: { - signal?: AbortSignal; - tokenType?: "access_token" | "refresh_token"; - } = {}, - ) { - if (this.options.tokenRevocationEndpoint === undefined) { - throw new Error("Token revocation endpoint is not set"); - } - - let context = new Token.RevocationRequest.Context(token); - - if (options.tokenType) context.setTokenTypeHint(options.tokenType); - - if (this.options.authenticateWith === "http_basic_auth") { - context.authenticateWithHTTPBasicAuth( - this.options.clientId, - this.options.clientSecret, - ); - } else if (this.options.authenticateWith === "request_body") { - context.authenticateWithRequestBody( - this.options.clientId, - this.options.clientSecret, - ); - } - - await Token.RevocationRequest.send( - this.options.tokenRevocationEndpoint, - context, - { signal: options.signal }, - ); + public revokeToken(token: string) { + let endpoint = this.options.tokenRevocationEndpoint; + if (!endpoint) throw new Error("Token revocation endpoint is not set."); + return this.client.revokeToken(endpoint.toString(), token); } } -export interface TokenErrorResponseBody { - error: string; - error_description?: string; - error_uri?: string; -} - -export class OAuth2Error extends Error { - override name = "OAuth2Error"; - - public request: Request; - public description: string | null; - public uri: string | null; +export namespace OAuth2Strategy { + export interface VerifyOptions { + /** The request that triggered the verification flow */ + request: Request; + /** The OAuth2 tokens retrivied from the identity provider */ + tokens: OAuth2Tokens; + } - constructor(request: Request, body: Partial) { - super(body.error ?? ""); - this.request = request; - this.description = body.error_description ?? null; - this.uri = body.error_uri ?? null; + export interface ConstructorOptions { + /** + * The name of the cookie used to keep state and code verifier around. + * + * The OAuth2 flow requires generating a random state and code verifier, and + * then checking that the state matches when the user is redirected back to + * the application. This is done to prevent CSRF attacks. + * + * The state and code verifier are stored in a cookie, and this option + * allows you to customize the name of that cookie if needed. + * @default "oauth2" + */ + cookie?: string | (Omit & { name: string }); + + /** + * This is the Client ID of your application, provided to you by the Identity + * Provider you're using to authenticate users. + */ + clientId: string; + /** + * This is the Client Secret of your application, provided to you by the + * Identity Provider you're using to authenticate users. + */ + clientSecret: string; + + /** + * The endpoint the Identity Provider asks you to send users to log in, or + * authorize your application. + */ + authorizationEndpoint: URLConstructor; + /** + * The endpoint the Identity Provider uses to let's you exchange an access + * code for an access and refresh token. + */ + tokenEndpoint: URLConstructor; + /** + * The URL of your application where the Identity Provider will redirect the + * user after they've logged in or authorized your application. + */ + redirectURI: URLConstructor; + + /** + * The endpoint the Identity Provider uses to revoke an access or refresh + * token, this can be useful to log out the user. + */ + tokenRevocationEndpoint?: URLConstructor; + + /** + * The scopes you want to request from the Identity Provider, this is a list + * of strings that represent the permissions you want to request from the + * user. + */ + scopes?: string[]; + + /** + * The code challenge method to use when sending the authorization request. + * This is used when the Identity Provider requires a code challenge to be + * sent with the authorization request. + * @default "CodeChallengeMethod.S256" + */ + codeChallengeMethod?: CodeChallengeMethod; } } - -export const OAuth2RequestError = OAuth2Request.Error; -export type TokenResponseBody = Token.Response.Body; diff --git a/src/lib/authorization-code.ts b/src/lib/authorization-code.ts deleted file mode 100644 index 1751324..0000000 --- a/src/lib/authorization-code.ts +++ /dev/null @@ -1,48 +0,0 @@ -/** - * A lot of the code here was originally implemented by @pilcrowOnPaper for a - * previous version of `@oslojs/oauth2`, as Pilcrow decided to change the - * direction of the library to focus on response parsing, I decided to copy the - * old code and adapt it to the new structure of the library. - */ -import { sha256 } from "@oslojs/crypto/sha2"; -import { encodeBase64urlNoPadding } from "@oslojs/encoding"; - -export namespace AuthorizationCode { - export class AuthorizationURL extends URL { - constructor(authorizationEndpoint: string, clientId: string) { - super(authorizationEndpoint); - this.searchParams.set("response_type", "code"); - this.searchParams.set("client_id", clientId); - } - - public setRedirectURI(redirectURI: string): void { - this.searchParams.set("redirect_uri", redirectURI); - } - - public addScopes(...scopes: string[]): void { - if (scopes.length < 1) { - return; - } - let scopeValue = scopes.join(" "); - const existingScopes = this.searchParams.get("scope"); - if (existingScopes !== null) scopeValue = ` ${existingScopes}`; - this.searchParams.set("scope", scopeValue); - } - - public setState(state: string): void { - this.searchParams.set("state", state); - } - - public setS256CodeChallenge(codeVerifier: string): void { - const codeChallengeBytes = sha256(new TextEncoder().encode(codeVerifier)); - const codeChallenge = encodeBase64urlNoPadding(codeChallengeBytes); - this.searchParams.set("code_challenge", codeChallenge); - this.searchParams.set("code_challenge_method", "S256"); - } - - public setPlainCodeChallenge(codeVerifier: string): void { - this.searchParams.set("code_challenge", codeVerifier); - this.searchParams.set("code_challenge_method", "plain"); - } - } -} diff --git a/src/lib/generator.ts b/src/lib/generator.ts deleted file mode 100644 index e4fd70a..0000000 --- a/src/lib/generator.ts +++ /dev/null @@ -1,21 +0,0 @@ -/** - * A lot of the code here was originally implemented by @pilcrowOnPaper for a - * previous version of `@oslojs/oauth2`, as Pilcrow decided to change the - * direction of the library to focus on response parsing, I decided to copy the - * old code and adapt it to the new structure of the library. - */ -import { encodeBase64urlNoPadding } from "@oslojs/encoding"; - -export namespace Generator { - export function codeVerifier(): string { - const randomValues = new Uint8Array(32); - crypto.getRandomValues(randomValues); - return encodeBase64urlNoPadding(randomValues); - } - - export function state(): string { - const randomValues = new Uint8Array(32); - crypto.getRandomValues(randomValues); - return encodeBase64urlNoPadding(randomValues); - } -} diff --git a/src/lib/redirect.ts b/src/lib/redirect.ts new file mode 100644 index 0000000..fd1508c --- /dev/null +++ b/src/lib/redirect.ts @@ -0,0 +1,14 @@ +export function redirect(url: string, init: ResponseInit | number = 302) { + let responseInit = init; + + if (typeof responseInit === "number") { + responseInit = { status: responseInit }; + } else if (typeof responseInit.status === "undefined") { + responseInit.status = 302; + } + + let headers = new Headers(responseInit.headers); + headers.set("Location", url); + + return new Response(null, { ...responseInit, headers }); +} diff --git a/src/lib/request.ts b/src/lib/request.ts deleted file mode 100644 index 1e0d56a..0000000 --- a/src/lib/request.ts +++ /dev/null @@ -1,76 +0,0 @@ -/** - * A lot of the code here was originally implemented by @pilcrowOnPaper for a - * previous version of `@oslojs/oauth2`, as Pilcrow decided to change the - * direction of the library to focus on response parsing, I decided to copy the - * old code and adapt it to the new structure of the library. - */ -import { encodeBase64 } from "@oslojs/encoding"; - -export namespace OAuth2Request { - export abstract class Context { - public method: string; - public body = new URLSearchParams(); - public headers = new Headers(); - - constructor(method: string) { - this.method = method; - this.headers.set("Content-Type", "application/x-www-form-urlencoded"); - this.headers.set("Accept", "application/json"); - this.headers.set("User-Agent", "oslo"); - } - - public setClientId(clientId: string): void { - this.body.set("client_id", clientId); - } - - public authenticateWithRequestBody( - clientId: string, - clientSecret: string, - ): void { - this.setClientId(clientId); - this.body.set("client_secret", clientSecret); - } - - public authenticateWithHTTPBasicAuth( - clientId: string, - clientSecret: string, - ): void { - const authorizationHeader = `Basic ${encodeBase64( - new TextEncoder().encode(`${clientId}:${clientSecret}`), - )}`; - this.headers.set("Authorization", authorizationHeader); - } - - toRequest(url: ConstructorParameters["0"]) { - return new Request(url, { - method: this.method, - body: this.body, - headers: this.headers, - }); - } - } - - // biome-ignore lint/suspicious/noShadowRestrictedNames: It's namespaced - export class Error extends globalThis.Error { - public request: Request; - public context: OAuth2Request.Context; - public description: string | null; - public uri: string | null; - public responseHeaders: Headers; - - constructor( - message: string, - request: Request, - context: OAuth2Request.Context, - responseHeaders: Headers, - options?: { description?: string; uri?: string }, - ) { - super(message); - this.request = request; - this.context = context; - this.responseHeaders = responseHeaders; - this.description = options?.description ?? null; - this.uri = options?.uri ?? null; - } - } -} diff --git a/src/lib/token.ts b/src/lib/token.ts deleted file mode 100644 index be8f177..0000000 --- a/src/lib/token.ts +++ /dev/null @@ -1,159 +0,0 @@ -/** - * A lot of the code here was originally implemented by @pilcrowOnPaper for a - * previous version of `@oslojs/oauth2`, as Pilcrow decided to change the - * direction of the library to focus on response parsing, I decided to copy the - * old code and adapt it to the new structure of the library. - */ -import { OAuth2RequestResult, TokenRequestResult } from "@oslojs/oauth2"; -import { OAuth2Request } from "./request.js"; - -type URLConstructor = ConstructorParameters[0]; - -export namespace Token { - export namespace Response { - export interface Body { - access_token: string; - token_type: string; - expires_in?: number; - refresh_token?: string; - scope?: string; - } - - export interface ErrorBody { - error: string; - error_description?: string; - } - } - - export namespace Request { - export class Context extends OAuth2Request.Context { - constructor(authorizationCode: string) { - super("POST"); - this.body.set("grant_type", "authorization_code"); - this.body.set("code", authorizationCode); - } - - public setCodeVerifier(codeVerifier: string): void { - this.body.set("code_verifier", codeVerifier); - } - - public setRedirectURI(redirectURI: string): void { - this.body.set("redirect_uri", redirectURI); - } - } - - export async function send>( - endpoint: URLConstructor, - context: OAuth2Request.Context, - options?: { signal?: AbortSignal }, - ): Promise { - let request = context.toRequest(endpoint); - let response = await fetch(request, { signal: options?.signal }); - let body = await response.json(); - - let result = new Result(body); - - if (result.hasErrorCode()) { - throw new OAuth2Request.Error( - result.errorCode(), - request, - context, - response.headers, - { - description: result.hasErrorDescription() - ? result.errorDescription() - : undefined, - uri: result.hasErrorURI() ? result.errorURI() : undefined, - }, - ); - } - - return result.toJSON(); - } - - export class Result< - ExtraParams extends Record, - > extends TokenRequestResult { - toJSON(): Response.Body & ExtraParams { - return { - ...this.body, - access_token: this.accessToken(), - token_type: this.tokenType(), - ...("expires_in" in this.body && { - expires_in: this.accessTokenExpiresInSeconds(), - }), - ...(this.hasScopes() && { scope: this.scopes().join(" ") }), - ...(this.hasRefreshToken() && { refresh_token: this.refreshToken() }), - } as Response.Body & ExtraParams; - } - } - } - - export namespace RevocationRequest { - export class Context extends OAuth2Request.Context { - constructor(token: string) { - super("POST"); - this.body.set("token", token); - } - - public setTokenTypeHint( - tokenType: "access_token" | "refresh_token", - ): void { - if (tokenType === "access_token") { - this.body.set("token_type_hint", "access_token"); - } else if (tokenType === "refresh_token") { - this.body.set("token_type_hint", "refresh_token"); - } - } - } - - export async function send( - endpoint: URLConstructor, - context: OAuth2Request.Context, - options?: { signal?: AbortSignal }, - ) { - let request = context.toRequest(endpoint); - let response = await fetch(request, { signal: options?.signal }); - let body = await response.json(); - - let result = new OAuth2RequestResult(body); - - if (result.hasErrorCode()) { - throw new OAuth2Request.Error( - result.errorCode(), - request, - context, - response.headers, - { - description: result.hasErrorDescription() - ? result.errorDescription() - : undefined, - uri: result.hasErrorURI() ? result.errorURI() : undefined, - }, - ); - } - } - } - - export namespace RefreshRequest { - export class Context extends OAuth2Request.Context { - constructor(refreshToken: string) { - super("POST"); - this.body.set("grant_type", "refresh_token"); - this.body.set("refresh_token", refreshToken); - } - - public addScopes(...scopes: string[]): void { - if (scopes.length < 1) { - return; - } - let scopeValue = scopes.join(" "); - const existingScopes = this.body.get("scope"); - if (existingScopes !== null) { - scopeValue = `${scopeValue} ${existingScopes}`; - } - this.body.set("scope", scopeValue); - } - } - } -} diff --git a/src/test/mock.ts b/src/test/mock.ts deleted file mode 100644 index 295556e..0000000 --- a/src/test/mock.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { http, HttpResponse } from "msw"; -import { setupServer } from "msw/node"; - -export const server = setupServer( - http.post("https://example.app/token", async () => { - return HttpResponse.json({ - access_token: "mocked", - expires_in: 3600, - refresh_token: "mocked", - scope: ["user:email", "user:profile"].join(" "), - token_type: "Bearer", - }); - }), -); diff --git a/tsconfig.json b/tsconfig.json index d04fd9d..daa6f12 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,6 +1,7 @@ { "extends": "@total-typescript/tsconfig/tsc/dom/library", - "include": ["src/index.ts", "src/lib/**/*.ts"], + "include": ["src/**/*.ts"], + "exclude": ["src/**/*.test.ts"], "compilerOptions": { "outDir": "./build" }