From e78b836d7f66ab13e0251f65cff07c6f712cca3b Mon Sep 17 00:00:00 2001 From: Hong Minhee Date: Thu, 11 Apr 2024 22:34:58 +0900 Subject: [PATCH] Authorized fetch for actor/collection dispatchers --- CHANGES.md | 14 +++- federation/callback.ts | 17 ++++ federation/handler.test.ts | 165 +++++++++++++++++++++++++++++++++++++ federation/handler.ts | 75 ++++++++--------- federation/middleware.ts | 89 +++++++++++++++++++- 5 files changed, 312 insertions(+), 48 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 9b2978b0..b2b14420 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,8 +10,18 @@ To be released. - Added `PUBLIC_COLLECTION` constant for [public addressing]. - - Added `RequestContext.getSignedKey()` method for [authorized fetch] - (also known as secure mode). + - `Federation` now supports [authorized fetch] for actor dispatcher and + collection dispatchers. + + - Added `ActorCallbackSetters.authorize()` method. + - Added `CollectionCallbackSetters.authorize()` method. + - Added `AuthorizedPredicate` type. + - Added `RequestContext.getSignedKey()` method. + - Added `FederationFetchOptions.onUnauthorized` option for handling + unauthorized fetches. + + - The default implementation of `FederationFetchOptions.onNotAcceptable` + option now responds with `Vary: Accept, Signature` header. [public addressing]: https://www.w3.org/TR/activitypub/#public-addressing [authorized fetch]: https://swicg.github.io/activitypub-http-signature/#authorized-fetch diff --git a/federation/callback.ts b/federation/callback.ts index 8b402647..272e87c4 100644 --- a/federation/callback.ts +++ b/federation/callback.ts @@ -98,3 +98,20 @@ export type OutboxErrorHandler = ( error: Error, activity: Activity | null, ) => void | Promise; + +/** + * A callback that determines if a request is authorized or not. + * + * @typeParam TContextData The context data to pass to the {@link Context}. + * @param context The request context. + * @param handle The handle of the actor that is being requested. + * @param signedKey The key that was used to sign the request, or `null` if + * the request was not signed or the signature was invalid. + * @returns `true` if the request is authorized, `false` otherwise. + * @since 0.7.0 + */ +export type AuthorizePredicate = ( + context: RequestContext, + handle: string, + signedKey: CryptographicKey | null, +) => boolean | Promise; diff --git a/federation/handler.test.ts b/federation/handler.test.ts index 10d244cf..98899be6 100644 --- a/federation/handler.test.ts +++ b/federation/handler.test.ts @@ -1,6 +1,7 @@ import { assert, assertEquals, assertFalse } from "@std/assert"; import { createRequestContext } from "../testing/context.ts"; import { mockDocumentLoader } from "../testing/docloader.ts"; +import { publicKey2 } from "../testing/keys.ts"; import { type Activity, Create, Note, Person } from "../vocab/vocab.ts"; import type { ActorDispatcher, @@ -71,6 +72,11 @@ Deno.test("handleActor()", async () => { onNotAcceptableCalled = request; return new Response("Not acceptable", { status: 406 }); }; + let onUnauthorizedCalled: Request | null = null; + const onUnauthorized = (request: Request) => { + onUnauthorizedCalled = request; + return new Response("Unauthorized", { status: 401 }); + }; let response = await handleActor( context.request, { @@ -78,11 +84,13 @@ Deno.test("handleActor()", async () => { handle: "someone", onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 404); assertEquals(onNotFoundCalled, context.request); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); onNotFoundCalled = null; response = await handleActor( @@ -93,11 +101,13 @@ Deno.test("handleActor()", async () => { actorDispatcher, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 406); assertEquals(onNotFoundCalled, null); assertEquals(onNotAcceptableCalled, context.request); + assertEquals(onUnauthorizedCalled, null); onNotAcceptableCalled = null; response = await handleActor( @@ -108,11 +118,13 @@ Deno.test("handleActor()", async () => { actorDispatcher, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 404); assertEquals(onNotFoundCalled, context.request); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); onNotFoundCalled = null; context = createRequestContext({ @@ -131,6 +143,7 @@ Deno.test("handleActor()", async () => { actorDispatcher, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 200); @@ -160,6 +173,7 @@ Deno.test("handleActor()", async () => { }); assertEquals(onNotFoundCalled, null); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); response = await handleActor( context.request, @@ -169,11 +183,85 @@ Deno.test("handleActor()", async () => { actorDispatcher, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 404); assertEquals(onNotFoundCalled, context.request); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); + + onNotFoundCalled = null; + context = createRequestContext({ + ...context, + request: new Request(context.url, { + headers: { + Accept: "application/activity+json", + }, + }), + }); + response = await handleActor( + context.request, + { + context, + handle: "someone", + actorDispatcher, + authorizePredicate: (_ctx, _handle, signedKey) => signedKey != null, + onNotFound, + onNotAcceptable, + onUnauthorized, + }, + ); + assertEquals(response.status, 401); + assertEquals(onNotFoundCalled, null); + assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, context.request); + + onUnauthorizedCalled = null; + context = createRequestContext({ + ...context, + getSignedKey: () => Promise.resolve(publicKey2), + }); + response = await handleActor( + context.request, + { + context, + handle: "someone", + actorDispatcher, + authorizePredicate: (_ctx, _handle, signedKey) => signedKey != null, + onNotFound, + onNotAcceptable, + onUnauthorized, + }, + ); + assertEquals(response.status, 200); + assertEquals( + response.headers.get("Content-Type"), + "application/activity+json", + ); + assertEquals(await response.json(), { + "@context": [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1", + { + manuallyApprovesFollowers: "as:manuallyApprovesFollowers", + discoverable: "toot:discoverable", + indexable: "toot:indexable", + memorial: "toot:memorial", + suspended: "toot:suspended", + toot: "http://joinmastodon.org/ns#", + schema: "http://schema.org#", + PropertyValue: "schema:PropertyValue", + value: "schema:value", + }, + ], + id: "https://example.com/users/someone", + type: "Person", + name: "Someone", + }); + assertEquals(onNotFoundCalled, null); + assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); }); Deno.test("handleCollection()", async () => { @@ -221,6 +309,11 @@ Deno.test("handleCollection()", async () => { onNotAcceptableCalled = request; return new Response("Not acceptable", { status: 406 }); }; + let onUnauthorizedCalled: Request | null = null; + const onUnauthorized = (request: Request) => { + onUnauthorizedCalled = request; + return new Response("Unauthorized", { status: 401 }); + }; let response = await handleCollection( context.request, { @@ -228,11 +321,13 @@ Deno.test("handleCollection()", async () => { handle: "someone", onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 404); assertEquals(onNotFoundCalled, context.request); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); onNotFoundCalled = null; response = await handleCollection( @@ -243,11 +338,13 @@ Deno.test("handleCollection()", async () => { collectionCallbacks: { dispatcher }, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 406); assertEquals(onNotFoundCalled, null); assertEquals(onNotAcceptableCalled, context.request); + assertEquals(onUnauthorizedCalled, null); onNotAcceptableCalled = null; response = await handleCollection( @@ -258,11 +355,13 @@ Deno.test("handleCollection()", async () => { collectionCallbacks: { dispatcher }, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 404); assertEquals(onNotFoundCalled, context.request); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); onNotFoundCalled = null; context = createRequestContext({ @@ -281,11 +380,13 @@ Deno.test("handleCollection()", async () => { collectionCallbacks: { dispatcher }, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 404); assertEquals(onNotFoundCalled, context.request); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); onNotFoundCalled = null; response = await handleCollection( @@ -296,6 +397,63 @@ Deno.test("handleCollection()", async () => { collectionCallbacks: { dispatcher }, onNotFound, onNotAcceptable, + onUnauthorized, + }, + ); + assertEquals(response.status, 200); + assertEquals( + response.headers.get("Content-Type"), + "application/activity+json", + ); + assertEquals(await response.json(), { + "@context": "https://www.w3.org/ns/activitystreams", + type: "OrderedCollection", + items: [ + { type: "Create", id: "https://example.com/activities/1" }, + { type: "Create", id: "https://example.com/activities/2" }, + { type: "Create", id: "https://example.com/activities/3" }, + ], + }); + assertEquals(onNotFoundCalled, null); + assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); + + response = await handleCollection( + context.request, + { + context, + handle: "someone", + collectionCallbacks: { + dispatcher, + authorizePredicate: (_ctx, _handle, key) => key != null, + }, + onNotFound, + onNotAcceptable, + onUnauthorized, + }, + ); + assertEquals(response.status, 401); + assertEquals(onNotFoundCalled, null); + assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, context.request); + + onUnauthorizedCalled = null; + context = createRequestContext({ + ...context, + getSignedKey: () => Promise.resolve(publicKey2), + }); + response = await handleCollection( + context.request, + { + context, + handle: "someone", + collectionCallbacks: { + dispatcher, + authorizePredicate: (_ctx, _handle, key) => key != null, + }, + onNotFound, + onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 200); @@ -314,6 +472,7 @@ Deno.test("handleCollection()", async () => { }); assertEquals(onNotFoundCalled, null); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); response = await handleCollection( context.request, @@ -328,6 +487,7 @@ Deno.test("handleCollection()", async () => { }, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 200); @@ -344,6 +504,7 @@ Deno.test("handleCollection()", async () => { }); assertEquals(onNotFoundCalled, null); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); let url = new URL("https://example.com/?cursor=0"); context = createRequestContext({ @@ -368,6 +529,7 @@ Deno.test("handleCollection()", async () => { }, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 200); @@ -387,6 +549,7 @@ Deno.test("handleCollection()", async () => { }); assertEquals(onNotFoundCalled, null); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); url = new URL("https://example.com/?cursor=2"); context = createRequestContext({ @@ -411,6 +574,7 @@ Deno.test("handleCollection()", async () => { }, onNotFound, onNotAcceptable, + onUnauthorized, }, ); assertEquals(response.status, 200); @@ -430,6 +594,7 @@ Deno.test("handleCollection()", async () => { }); assertEquals(onNotFoundCalled, null); assertEquals(onNotAcceptableCalled, null); + assertEquals(onUnauthorizedCalled, null); }); Deno.test("respondWithObject()", async () => { diff --git a/federation/handler.ts b/federation/handler.ts index d674634b..bce9c3e9 100644 --- a/federation/handler.ts +++ b/federation/handler.ts @@ -11,6 +11,7 @@ import { } from "../vocab/vocab.ts"; import type { ActorDispatcher, + AuthorizePredicate, CollectionCounter, CollectionCursor, CollectionDispatcher, @@ -35,6 +36,8 @@ export interface ActorHandlerParameters { handle: string; context: RequestContext; actorDispatcher?: ActorDispatcher; + authorizePredicate?: AuthorizePredicate; + onUnauthorized(request: Request): Response | Promise; onNotFound(request: Request): Response | Promise; onNotAcceptable(request: Request): Response | Promise; } @@ -45,8 +48,10 @@ export async function handleActor( handle, context, actorDispatcher, + authorizePredicate, onNotFound, onNotAcceptable, + onUnauthorized, }: ActorHandlerParameters, ): Promise { if (actorDispatcher == null) { @@ -55,13 +60,13 @@ export async function handleActor( } const key = await context.getActorKey(handle); const actor = await actorDispatcher(context, handle, key); - if (actor == null) { - const response = onNotFound(request); - return response instanceof Promise ? await response : response; - } - if (!acceptsJsonLd(request)) { - const response = onNotAcceptable(request); - return response instanceof Promise ? await response : response; + if (actor == null) return await onNotFound(request); + if (!acceptsJsonLd(request)) return await onNotAcceptable(request); + if (authorizePredicate != null) { + const key = await context.getSignedKey(); + if (!await authorizePredicate(context, handle, key)) { + return await onUnauthorized(request); + } } const jsonLd = await actor.toJsonLd(context); return new Response(JSON.stringify(jsonLd), { @@ -95,12 +100,18 @@ export interface CollectionCallbacks { * A callback that returns the last cursor for a collection. */ lastCursor?: CollectionCursor; + + /** + * A callback that determines if a request is authorized to access the collection. + */ + authorizePredicate?: AuthorizePredicate; } export interface CollectionHandlerParameters { handle: string; context: RequestContext; collectionCallbacks?: CollectionCallbacks; + onUnauthorized(request: Request): Response | Promise; onNotFound(request: Request): Response | Promise; onNotAcceptable(request: Request): Response | Promise; } @@ -114,51 +125,34 @@ export async function handleCollection< handle, context, collectionCallbacks, + onUnauthorized, onNotFound, onNotAcceptable, }: CollectionHandlerParameters, ): Promise { - if (collectionCallbacks == null) { - const response = onNotFound(request); - return response instanceof Promise ? await response : response; - } + if (collectionCallbacks == null) return await onNotFound(request); const url = new URL(request.url); const cursor = url.searchParams.get("cursor"); let collection: OrderedCollection | OrderedCollectionPage; if (cursor == null) { - const firstCursorPromise = collectionCallbacks.firstCursor?.( + const firstCursor = await collectionCallbacks.firstCursor?.( context, handle, ); - const firstCursor = firstCursorPromise instanceof Promise - ? await firstCursorPromise - : firstCursorPromise; - const totalItemsPromise = collectionCallbacks.counter?.(context, handle); - const totalItems = totalItemsPromise instanceof Promise - ? await totalItemsPromise - : totalItemsPromise; + const totalItems = await collectionCallbacks.counter?.(context, handle); if (firstCursor == null) { - const pagePromise = collectionCallbacks.dispatcher(context, handle, null); - const page = pagePromise instanceof Promise - ? await pagePromise - : pagePromise; - if (page == null) { - const response = onNotFound(request); - return response instanceof Promise ? await response : response; - } + const page = await collectionCallbacks.dispatcher(context, handle, null); + if (page == null) return await onNotFound(request); const { items } = page; collection = new OrderedCollection({ totalItems: totalItems == null ? null : Number(totalItems), items, }); } else { - const lastCursorPromise = collectionCallbacks.lastCursor?.( + const lastCursor = await collectionCallbacks.lastCursor?.( context, handle, ); - const lastCursor = lastCursorPromise instanceof Promise - ? await lastCursorPromise - : lastCursorPromise; const first = new URL(context.url); first.searchParams.set("cursor", firstCursor); let last = null; @@ -173,14 +167,8 @@ export async function handleCollection< }); } } else { - const pagePromise = collectionCallbacks.dispatcher(context, handle, cursor); - const page = pagePromise instanceof Promise - ? await pagePromise - : pagePromise; - if (page == null) { - const response = onNotFound(request); - return response instanceof Promise ? await response : response; - } + const page = await collectionCallbacks.dispatcher(context, handle, cursor); + if (page == null) return await onNotFound(request); const { items, prevCursor, nextCursor } = page; let prev = null; if (prevCursor != null) { @@ -196,9 +184,12 @@ export async function handleCollection< partOf.searchParams.delete("cursor"); collection = new OrderedCollectionPage({ prev, next, items, partOf }); } - if (!acceptsJsonLd(request)) { - const response = onNotAcceptable(request); - return response instanceof Promise ? await response : response; + if (!acceptsJsonLd(request)) return await onNotAcceptable(request); + if (collectionCallbacks.authorizePredicate != null) { + const key = await context.getSignedKey(); + if (!await collectionCallbacks.authorizePredicate(context, handle, key)) { + return await onUnauthorized(request); + } } const jsonLd = await collection.toJsonLd(context); return new Response(JSON.stringify(jsonLd), { diff --git a/federation/middleware.ts b/federation/middleware.ts index 462255a8..2f9eff00 100644 --- a/federation/middleware.ts +++ b/federation/middleware.ts @@ -15,6 +15,7 @@ import { handleWebFinger } from "../webfinger/handler.ts"; import type { ActorDispatcher, ActorKeyPairDispatcher, + AuthorizePredicate, CollectionCounter, CollectionCursor, CollectionDispatcher, @@ -470,12 +471,14 @@ export class Federation { const callbacks: ActorCallbacks = { dispatcher }; this.#actorCallbacks = callbacks; const setters: ActorCallbackSetters = { - setKeyPairDispatcher: ( - dispatcher: ActorKeyPairDispatcher, - ) => { + setKeyPairDispatcher(dispatcher: ActorKeyPairDispatcher) { callbacks.keyPairDispatcher = dispatcher; return setters; }, + authorize(predicate: AuthorizePredicate) { + callbacks.authorizePredicate = predicate; + return setters; + }, }; return setters; } @@ -533,6 +536,10 @@ export class Federation { callbacks.lastCursor = cursor; return setters; }, + authorize(predicate: AuthorizePredicate) { + callbacks.authorizePredicate = predicate; + return setters; + }, }; return setters; } @@ -578,6 +585,10 @@ export class Federation { callbacks.lastCursor = cursor; return setters; }, + authorize(predicate: AuthorizePredicate) { + callbacks.authorizePredicate = predicate; + return setters; + }, }; return setters; } @@ -623,6 +634,10 @@ export class Federation { callbacks.lastCursor = cursor; return setters; }, + authorize(predicate: AuthorizePredicate) { + callbacks.authorizePredicate = predicate; + return setters; + }, }; return setters; } @@ -801,11 +816,13 @@ export class Federation { { onNotFound, onNotAcceptable, + onUnauthorized, contextData, }: FederationFetchOptions, ): Promise { onNotFound ??= notFound; onNotAcceptable ??= notAcceptable; + onUnauthorized ??= unauthorized; const url = new URL(request.url); const route = this.#router.route(url.pathname); if (route == null) { @@ -832,6 +849,8 @@ export class Federation { handle: route.values.handle, context, actorDispatcher: this.#actorCallbacks?.dispatcher, + authorizePredicate: this.#actorCallbacks?.authorizePredicate, + onUnauthorized, onNotFound, onNotAcceptable, }); @@ -840,6 +859,7 @@ export class Federation { handle: route.values.handle, context, collectionCallbacks: this.#outboxCallbacks, + onUnauthorized, onNotFound, onNotAcceptable, }); @@ -867,6 +887,7 @@ export class Federation { handle: route.values.handle, context, collectionCallbacks: this.#followingCallbacks, + onUnauthorized, onNotFound, onNotAcceptable, }); @@ -875,6 +896,7 @@ export class Federation { handle: route.values.handle, context, collectionCallbacks: this.#followersCallbacks, + onUnauthorized, onNotFound, onNotAcceptable, }); @@ -913,11 +935,21 @@ export interface FederationFetchOptions { * @returns The response to the request. */ onNotAcceptable?: (request: Request) => Response | Promise; + + /** + * A callback to handle a request when the request is unauthorized. + * If not provided, a 401 response is returned. + * @param request The request object. + * @returns The response to the request. + * @since 0.7.0 + */ + onUnauthorized?: (request: Request) => Response | Promise; } interface ActorCallbacks { dispatcher?: ActorDispatcher; keyPairDispatcher?: ActorKeyPairDispatcher; + authorizePredicate?: AuthorizePredicate; } /** @@ -942,23 +974,58 @@ export interface ActorCallbackSetters { setKeyPairDispatcher( dispatcher: ActorKeyPairDispatcher, ): ActorCallbackSetters; + + /** + * Specifies the conditions under which requests are authorized. + * @param predicate A callback that returns whether a request is authorized. + * @returns The setters object so that settings can be chained. + * @since 0.7.0 + */ + authorize( + predicate: AuthorizePredicate, + ): ActorCallbackSetters; } /** * Additional settings for a collection dispatcher. */ export interface CollectionCallbackSetters { + /** + * Sets the counter for the collection. + * @param counter A callback that returns the number of items in the collection. + * @returns The setters object so that settings can be chained. + */ setCounter( counter: CollectionCounter, ): CollectionCallbackSetters; + /** + * Sets the first cursor for the collection. + * @param cursor The cursor for the first item in the collection. + * @returns The setters object so that settings can be chained. + */ setFirstCursor( cursor: CollectionCursor, ): CollectionCallbackSetters; + /** + * Sets the last cursor for the collection. + * @param cursor The cursor for the last item in the collection. + * @returns The setters object so that settings can be chained. + */ setLastCursor( cursor: CollectionCursor, ): CollectionCallbackSetters; + + /** + * Specifies the conditions under which requests are authorized. + * @param predicate A callback that returns whether a request is authorized. + * @returns The setters object so that settings can be chained. + * @since 0.7.0 + */ + authorize( + predicate: AuthorizePredicate, + ): CollectionCallbackSetters; } /** @@ -995,5 +1062,19 @@ function notFound(_request: Request): Response { } function notAcceptable(_request: Request): Response { - return new Response("Not Acceptable", { status: 406 }); + return new Response("Not Acceptable", { + status: 406, + headers: { + Vary: "Accept, Signature", + }, + }); +} + +function unauthorized(_request: Request): Response { + return new Response("Unauthorized", { + status: 401, + headers: { + Vary: "Accept, Signature", + }, + }); }