From 35845f3d24bc1dc934f07c59defefcc5c32c6571 Mon Sep 17 00:00:00 2001 From: Kacper Wojciechowski <39823706+jog1t@users.noreply.github.com> Date: Tue, 28 Jan 2025 01:56:17 +0100 Subject: [PATCH] feat: add separate protocol for actor inspect --- .../actors/actors-actor-details.tsx | 6 +- .../actors/worker/actor-repl.worker.ts | 43 ++-- .../actors/worker/actor-worker-container.ts | 19 +- .../actors/worker/actor-worker-context.tsx | 7 +- .../actors/worker/actor-worker-schema.ts | 2 +- .../project/queries/actors/query-options.ts | 5 +- .../packages/components/src/ui/button.tsx | 6 +- sdks/actor/runtime/src/actor.ts | 241 ++++-------------- sdks/actor/runtime/src/connection.ts | 3 +- sdks/actor/runtime/src/errors.ts | 6 + sdks/actor/runtime/src/event.ts | 178 +++++++++++++ sdks/actor/runtime/src/inspect.ts | 194 ++++++++++++++ sdks/actor/runtime/src/log.ts | 6 + 13 files changed, 478 insertions(+), 238 deletions(-) create mode 100644 sdks/actor/runtime/src/event.ts create mode 100644 sdks/actor/runtime/src/inspect.ts diff --git a/frontend/apps/hub/src/domains/project/components/actors/actors-actor-details.tsx b/frontend/apps/hub/src/domains/project/components/actors/actors-actor-details.tsx index 4875325fb2..8f9be0cc8e 100644 --- a/frontend/apps/hub/src/domains/project/components/actors/actors-actor-details.tsx +++ b/frontend/apps/hub/src/domains/project/components/actors/actors-actor-details.tsx @@ -41,7 +41,11 @@ export function ActorsActorDetails({ const { data } = useSuspenseQuery(actorQueryOptions(props)); return ( - +
{ ); }; -interface InspectableActor { - internal_setState(rpc: this, state: Record): Promise; - internal_inspect(): Promise; -} - -let init: - | null - | ({ handle: ActorHandle } & InspectRpcResponse); +let init: null | ({ handle: ActorHandleRaw } & InspectRpcResponse); addEventListener("message", async (event) => { const { success, data } = MessageSchema.safeParse(event.data); @@ -130,11 +123,13 @@ addEventListener("message", async (event) => { } try { - const cl = new Client(data.managerUrl); - const handle = await Promise.race([ - cl.getWithId(data.actorId, {}), - wait(5000).then(() => undefined), - ]); + const handle = new ActorHandleRaw( + `${data.endpoint}/__inspect`, + undefined, + "cbor", + ); + + handle.connect(); if (!handle) { respond({ @@ -146,7 +141,7 @@ addEventListener("message", async (event) => { } const inspect = await Promise.race([ - handle.internal_inspect(), + handle.rpc<[], InspectRpcResponse>("inspect"), wait(5000).then(() => undefined), ]); @@ -208,21 +203,17 @@ addEventListener("message", async (event) => { data: formatted, }); - const rpcs = Object.fromEntries( - actor.rpcs.map( - (rpc) => - [ - rpc, - actor.handle[rpc as keyof typeof actor.handle], - ] as const, - ), + const exposedActor = Object.fromEntries( + init?.rpcs.map((rpc) => [ + rpc, + actor.handle.rpc.bind(actor.handle, rpc), + ]) ?? [], ); const evaluated = await evaluateCode(data.data, { console: createConsole(data.id), wait, - actor: actor.handle, - ...rpcs, + actor: exposedActor, }); return respond({ type: "result", @@ -250,7 +241,7 @@ addEventListener("message", async (event) => { try { const state = JSON.parse(data.data); - await actor.handle.internal_setState(state); + await actor.handle.rpc("setState", state); return respond({ type: "state-change", data: { diff --git a/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-container.ts b/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-container.ts index 9628122fdc..5090ebb722 100644 --- a/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-container.ts +++ b/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-container.ts @@ -1,7 +1,4 @@ -import { - actorManagerUrlQueryOptions, - actorQueryOptions, -} from "@/domains/project/queries"; +import { actorQueryOptions } from "@/domains/project/queries"; import { queryClient } from "@/queries/global"; import ActorWorker from "./actor-repl.worker?worker"; import { @@ -66,11 +63,13 @@ export class ActorWorkerContainer { projectNameId, environmentNameId, actorId, + endpoint, signal, }: { projectNameId: string; environmentNameId: string; actorId: string; + endpoint: string; signal: AbortSignal; }) { this.terminate(); @@ -79,16 +78,6 @@ export class ActorWorkerContainer { this.#state.status = { type: "pending" }; this.#update(); try { - // To check if the actor is supported, first we need to get the actor manager URL - // if there's no manager URL, we can't support the actor - const managerUrl = await queryClient.fetchQuery( - actorManagerUrlQueryOptions({ - projectNameId, - environmentNameId, - }), - ); - signal.throwIfAborted(); - // If we have the manager URL, next we need to check actor's runtime const { actor } = await queryClient.fetchQuery( actorQueryOptions({ @@ -118,7 +107,7 @@ export class ActorWorkerContainer { const worker = new ActorWorker({ name: `actor-${actorId}` }); signal.throwIfAborted(); // now worker needs to check if the actor is supported - this.#setupWorker(worker, { actorId, managerUrl }); + this.#setupWorker(worker, { actorId, endpoint }); signal.throwIfAborted(); return worker; } catch (e) { diff --git a/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-context.tsx b/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-context.tsx index 113a5c35aa..26bdfb2b7d 100644 --- a/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-context.tsx +++ b/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-context.tsx @@ -24,6 +24,7 @@ interface ActorWorkerContextProviderProps { actorId: string; projectNameId: string; environmentNameId: string; + endpoint?: string; enabled?: boolean; children: ReactNode; } @@ -32,6 +33,7 @@ export const ActorWorkerContextProvider = ({ children, actorId, enabled, + endpoint, projectNameId, environmentNameId, }: ActorWorkerContextProviderProps) => { @@ -43,11 +45,12 @@ export const ActorWorkerContextProvider = ({ useEffect(() => { const ctrl = new AbortController(); - if (enabled) { + if (enabled && endpoint) { container.init({ projectNameId, environmentNameId, actorId, + endpoint, signal: ctrl.signal, }); } @@ -56,7 +59,7 @@ export const ActorWorkerContextProvider = ({ ctrl.abort(); container.terminate(); }; - }, [actorId, projectNameId, environmentNameId, enabled]); + }, [actorId, projectNameId, environmentNameId, endpoint, enabled]); return ( diff --git a/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-schema.ts b/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-schema.ts index 090b0417da..0404a734c1 100644 --- a/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-schema.ts +++ b/frontend/apps/hub/src/domains/project/components/actors/worker/actor-worker-schema.ts @@ -14,7 +14,7 @@ const CodeMessageSchema = z.object({ }); const InitMessageSchema = z.object({ type: z.literal("init"), - managerUrl: z.string(), + endpoint: z.string(), actorId: z.string(), }); diff --git a/frontend/apps/hub/src/domains/project/queries/actors/query-options.ts b/frontend/apps/hub/src/domains/project/queries/actors/query-options.ts index 735a81c692..0f120186c6 100644 --- a/frontend/apps/hub/src/domains/project/queries/actors/query-options.ts +++ b/frontend/apps/hub/src/domains/project/queries/actors/query-options.ts @@ -151,6 +151,7 @@ export const actorQueryOptions = ({ (arg) => arg !== "", ), }, + endpoint: createActorEndpoint(data.actor.network), }), }); }; @@ -442,7 +443,9 @@ export const actorRegionQueryOptions = ({ }); }; -const createActorEndpoint = (network: Rivet.actor.Network) => { +const createActorEndpoint = ( + network: Rivet.actor.Network, +): string | undefined => { const http = Object.values(network.ports).find( (port) => port.protocol === "http" || port.protocol === "https", ); diff --git a/frontend/packages/components/src/ui/button.tsx b/frontend/packages/components/src/ui/button.tsx index 46273aff3a..75fba3a336 100644 --- a/frontend/packages/components/src/ui/button.tsx +++ b/frontend/packages/components/src/ui/button.tsx @@ -75,6 +75,8 @@ const Button = React.forwardRef( ) => { const C = asChild ? Slot : "button"; + const isIcon = size?.includes("icon"); + return ( ( ) : startIcon ? ( React.cloneElement(startIcon, startIcon.props) ) : null} - {!size?.includes("icon") && isLoading ? null : ( - {children} - )} + {isIcon && isLoading ? null : {children}} {endIcon ? React.cloneElement(endIcon, endIcon.props) : null} ); diff --git a/sdks/actor/runtime/src/actor.ts b/sdks/actor/runtime/src/actor.ts index 768cd07494..6173e26189 100644 --- a/sdks/actor/runtime/src/actor.ts +++ b/sdks/actor/runtime/src/actor.ts @@ -1,11 +1,9 @@ import { type Logger, setupLogging } from "@rivet-gg/actor-common/log"; import { listObjectMethods } from "@rivet-gg/actor-common/reflect"; -import { assertUnreachable, safeStringify } from "@rivet-gg/actor-common/utils"; import { isJsonSerializable } from "@rivet-gg/actor-common/utils"; import type { ActorContext, Metadata } from "@rivet-gg/actor-core"; import { ProtocolFormatSchema } from "@rivet-gg/actor-protocol/ws"; import type * as wsToClient from "@rivet-gg/actor-protocol/ws/to_client"; -import * as wsToServer from "@rivet-gg/actor-protocol/ws/to_server"; import { Hono, type Context as HonoContext } from "hono"; import { upgradeWebSocket } from "hono/deno"; import type { WSEvents } from "hono/ws"; @@ -14,14 +12,15 @@ import { type ActorConfig, mergeActorConfig } from "./config"; import { Connection, type ConnectionId, - type IncomingWebSocketMessage, type OutgoingWebSocketMessage, } from "./connection"; import * as errors from "./errors"; +import { handleMessageEvent } from "./event"; +import { ActorInspection } from "./inspect"; import type { Kv } from "./kv"; import { instanceLogger, logger } from "./log"; -import { Rpc } from "./rpc"; -import { Lock, deadline, throttle } from "./utils"; +import type { Rpc } from "./rpc"; +import { Lock, deadline } from "./utils"; const KEYS = { SCHEDULE: { @@ -92,6 +91,8 @@ export type ExtractActorConnState = A extends Actor< ? ConnState : never; +export type ExtractActorState = A extends Actor ? State : never; + /** * Abstract class representing a Rivet Actor. Extend this class to implement logic for your actor. * @@ -142,21 +143,16 @@ export abstract class Actor< #lastSaveTime = 0; #pendingSaveTimeout?: number | NodeJS.Timeout; - #notifyStateInspectThrottle = throttle(async () => { - const inspectionResult = this.internal_inspect(); - // TODO: Notify only inspector, not all clients - this._broadcast("_state-changed", inspectionResult.state); - }, 500); - - #notifyConnectionsInspectThrottle = throttle(async () => { - const inspectionResult = this.internal_inspect(); - // TODO: Notify only inspector, not all clients - this._broadcast("_connections-changed", inspectionResult.connections); - }, 500); - - #notifyEventsInspectThrottle = throttle(async (name: string) => { - this._broadcast("_event-emitted", { name }); - }, 100); + #inspection = new ActorInspection(this, { + state: () => ({ enabled: this.#stateEnabled, state: this.#stateProxy }), + connections: () => this.#connections.values(), + rpcs: () => this.#rpcNames, + setState: (state) => { + this.#validateStateEnabled(); + this.#setStateWithoutChange(state); + }, + onRpcCall: (ctx, rpc, args) => this.#executeRpc(ctx, rpc, args), + }); /** * This constructor should never be used directly. @@ -306,7 +302,7 @@ export abstract class Actor< throw new errors.InvalidStateType({ path }); } this.#stateChanged = true; - this.#notifyStateInspectThrottle(); + this.#inspection.notifyStateChanged(); // Call onStateChange if it exists if (this._onStateChange && this.#ready) { @@ -383,6 +379,24 @@ export abstract class Actor< //app.post("/rpc/:name", this.#pandleRpc.bind(this)); app.get("/connect", upgradeWebSocket(this.#handleWebSocket.bind(this))); + app.get( + "/__inspect/connect", + // cors({ + // // TODO: Fetch configuration from config, manager, or env? + // origin: (origin, c) => { + // return [ + // "http://localhost:5080", + // "https://hub.rivet.gg", + // ].includes(origin) || + // origin.endsWith(".rivet-hub-7jb.pages.dev") + // ? origin + // : "https://hub.rivet.gg"; + // }, + // }), + upgradeWebSocket((c) => + this.#inspection.handleWebsocketConnection(c), + ), + ); app.all("*", (c) => { return c.text("Not Found", 404); @@ -451,7 +465,7 @@ export abstract class Actor< // MARK: Events #addSubscription(eventName: string, connection: Connection) { connection.subscriptions.add(eventName); - this.#notifyConnectionsInspectThrottle(); + this.#inspection.notifyConnectionsChanged(); let subscribers = this.#eventSubscriptions.get(eventName); if (!subscribers) { @@ -463,7 +477,7 @@ export abstract class Actor< #removeSubscription(eventName: string, connection: Connection) { connection.subscriptions.delete(eventName); - this.#notifyConnectionsInspectThrottle(); + this.#inspection.notifyConnectionsChanged(); const subscribers = this.#eventSubscriptions.get(eventName); if (subscribers) { @@ -580,130 +594,23 @@ export abstract class Actor< return; } - let rpcRequestId: number | undefined; - try { - const value = - evt.data.valueOf() as IncomingWebSocketMessage; - - // Validate value length - let length: number; - if (typeof value === "string") { - length = value.length; - } else if (value instanceof Blob) { - length = value.size; - } else if ( - value instanceof ArrayBuffer || - value instanceof SharedArrayBuffer - ) { - length = value.byteLength; - } else { - assertUnreachable(value); - } - if (length > this.#config.protocol.maxIncomingMessageSize) { - throw new errors.MessageTooLong(); - } - - // Parse & validate message - const { - data: message, - success, - error, - } = wsToServer.ToServerSchema.safeParse( - await conn._parse(value), - ); - if (!success) { - throw new errors.MalformedMessage(error); - } - - if ("rr" in message.body) { - // RPC request - - const { - i: id, - n: name, - a: args = [], - } = message.body.rr; - - rpcRequestId = id; - - const ctx = new Rpc(conn); - const output = await this.#executeRpc(ctx, name, args); - - conn._sendWebSocketMessage( - conn._serialize({ - body: { - ro: { - i: id, - o: output, - }, - }, - } satisfies wsToClient.ToClient), - ); - } else if ("sr" in message.body) { - // Subscription request - - const { e: eventName, s: subscribe } = message.body.sr; - - if (subscribe) { - this.#addSubscription(eventName, conn); - } else { - this.#removeSubscription(eventName, conn); - } - } else { - assertUnreachable(message.body); - } - } catch (error) { - // Build response error information. Only return errors if flagged as public in order to prevent leaking internal behavior. - let code: string; - let message: string; - let metadata: unknown = undefined; - if (error instanceof errors.ActorError && error.public) { - logger().info("connection public error", { - rpc: rpcRequestId, - error, - }); - - code = error.code; - message = String(error); - metadata = error.metadata; - } else { + await handleMessageEvent(evt, conn, this.#config, { + onExecuteRpc: async (ctx, name, args) => { + return await this.#executeRpc(ctx, name, args); + }, + onSubscribe: async (eventName, conn) => { + this.#addSubscription(eventName, conn); + }, + onUnsubscribe: async (eventName, conn) => { + this.#removeSubscription(eventName, conn); + }, + onError: (error) => { logger().warn("connection internal error", { - rpc: rpcRequestId, + rpc: error.rpcRequestId, error, }); - - code = errors.INTERNAL_ERROR_CODE; - message = errors.INTERNAL_ERROR_DESCRIPTION; - } - - // Build response - if (rpcRequestId !== undefined) { - conn._sendWebSocketMessage( - conn._serialize({ - body: { - re: { - i: rpcRequestId, - c: code, - m: message, - md: metadata, - }, - }, - } satisfies wsToClient.ToClient), - ); - } else { - conn._sendWebSocketMessage( - conn._serialize({ - body: { - er: { - c: code, - m: message, - md: metadata, - }, - }, - } satisfies wsToClient.ToClient), - ); - } - } + }, + }); }, onClose: () => { if (!conn) { @@ -788,17 +695,6 @@ export abstract class Actor< ); } - /** - * Safely transforms the actor state into a string for debugging purposes. - */ - #inspectState(): string { - try { - return safeStringify(this.#stateRaw, 128 * 1024 * 1024); - } catch (error) { - return JSON.stringify({ _error: new errors.StateTooLarge() }); - } - } - // MARK: Lifecycle hooks /** * Hook called when the actor is first created. This method should return the initial state of the actor. The state can be access with `this._state`. @@ -872,7 +768,7 @@ export abstract class Actor< * @see {@link https://rivet.gg/docs/lifecycle|Lifecycle Documentation} */ protected _onConnect?(connection: Connection): void | Promise { - this.#notifyConnectionsInspectThrottle(); + this.#inspection.notifyConnectionsChanged(); } /** @@ -884,7 +780,7 @@ export abstract class Actor< protected _onDisconnect?( connection: Connection, ): void | Promise { - this.#notifyConnectionsInspectThrottle(); + this.#inspection.notifyConnectionsChanged(); } // MARK: Exposed methods @@ -1044,37 +940,6 @@ export abstract class Actor< } } - /** - * Public RPC method that inspects the actor's state and connections. - * @internal - * @returns The actor's state and connections. - */ - internal_inspect(): wsToClient.InspectRpcResponse { - return { - // Filter out internal 'inspect' RPC - rpcs: this.#rpcNames.filter( - (name) => !name.startsWith("internal_"), - ), - state: { - enabled: this.#stateEnabled, - native: this.#inspectState(), - }, - connections: [...this.#connections.values()].map((con) => - con._inspect(), - ), - }; - } - - /** - * Very insecure, but useful for debugging. This method allows you to set the actor's state directly. - * @internal - */ - internal_setState(_: Rpc, value: State) { - // FIXME: This should be only available to selected clients - this.#validateStateEnabled(); - this.#setStateWithoutChange(value); - } - /** * Shuts down the actor, closing all connections and stopping the server. * diff --git a/sdks/actor/runtime/src/connection.ts b/sdks/actor/runtime/src/connection.ts index 6aec5e3c20..2797537851 100644 --- a/sdks/actor/runtime/src/connection.ts +++ b/sdks/actor/runtime/src/connection.ts @@ -4,6 +4,7 @@ import * as cbor from "cbor-x"; import type { WSContext } from "hono/ws"; import type { AnyActor, ExtractActorConnState } from "./actor"; import * as errors from "./errors"; +import { INSPECT_SYMBOL } from "./inspect"; import { logger } from "./log"; import { assertUnreachable } from "./utils"; @@ -196,7 +197,7 @@ export class Connection { * Inspects the connection for debugging purposes. * @internal */ - public _inspect() { + [INSPECT_SYMBOL]() { return { id: this.id.toString(), subscriptions: [...this.subscriptions.values()], diff --git a/sdks/actor/runtime/src/errors.ts b/sdks/actor/runtime/src/errors.ts index 1d69c01ce8..3cfbb3be79 100644 --- a/sdks/actor/runtime/src/errors.ts +++ b/sdks/actor/runtime/src/errors.ts @@ -150,6 +150,12 @@ export class StateTooLarge extends ActorError { } } +export class Unsupported extends ActorError { + constructor(feature: string) { + super("unsupported", `Unsupported feature: ${feature}`); + } +} + /** * Options for the UserError class. */ diff --git a/sdks/actor/runtime/src/event.ts b/sdks/actor/runtime/src/event.ts new file mode 100644 index 0000000000..c1da51b868 --- /dev/null +++ b/sdks/actor/runtime/src/event.ts @@ -0,0 +1,178 @@ +import type * as wsToClient from "@rivet-gg/actor-protocol/ws/to_client"; +import * as wsToServer from "@rivet-gg/actor-protocol/ws/to_server"; +import type { WSMessageReceive } from "hono/ws"; +import type { AnyActor } from "./actor"; +import type { Connection, IncomingWebSocketMessage } from "./connection"; +import * as errors from "./errors"; +import { Rpc } from "./rpc"; +import { assertUnreachable } from "./utils"; + +interface MessageEventConfig { + protocol: { maxIncomingMessageSize: number }; +} + +export async function validateMessageEvent( + evt: MessageEvent, + connection: Connection, + config: MessageEventConfig, +) { + const value = evt.data.valueOf() as IncomingWebSocketMessage; + + // Validate value length + let length: number; + if (typeof value === "string") { + length = value.length; + } else if (value instanceof Blob) { + length = value.size; + } else if ( + value instanceof ArrayBuffer || + value instanceof SharedArrayBuffer + ) { + length = value.byteLength; + } else { + assertUnreachable(value); + } + if (length > config.protocol.maxIncomingMessageSize) { + throw new errors.MessageTooLong(); + } + + // Parse & validate message + const { + data: message, + success, + error, + } = wsToServer.ToServerSchema.safeParse(await connection._parse(value)); + + if (!success) { + throw new errors.MalformedMessage(error); + } + + return message; +} + +export async function handleMessageEvent( + event: MessageEvent, + conn: Connection, + config: MessageEventConfig, + handlers: { + onExecuteRpc?: ( + ctx: Rpc, + name: string, + args: unknown[], + ) => Promise; + onSubscribe?: (eventName: string, conn: Connection) => Promise; + onUnsubscribe?: ( + eventName: string, + conn: Connection, + ) => Promise; + onError: (error: { + code: string; + message: string; + metadata: unknown; + rpcRequestId?: number; + internal: boolean; + }) => void; + }, +) { + let rpcRequestId: number | undefined; + const message = await validateMessageEvent(event, conn, config); + + try { + if ("rr" in message.body) { + // RPC request + + if (handlers.onExecuteRpc === undefined) { + throw new errors.Unsupported("RPC"); + } + + const { i: id, n: name, a: args = [] } = message.body.rr; + + rpcRequestId = id; + + const ctx = new Rpc(conn); + const output = await handlers.onExecuteRpc?.(ctx, name, args); + + conn._sendWebSocketMessage( + conn._serialize({ + body: { + ro: { + i: id, + o: output, + }, + }, + } satisfies wsToClient.ToClient), + ); + } else if ("sr" in message.body) { + // Subscription request + + if ( + handlers.onSubscribe === undefined || + handlers.onUnsubscribe === undefined + ) { + throw new errors.Unsupported("Subscriptions"); + } + + const { e: eventName, s: subscribe } = message.body.sr; + + if (subscribe) { + await handlers.onSubscribe(eventName, conn); + return; + } + + await handlers.onUnsubscribe(eventName, conn); + } else { + assertUnreachable(message.body); + } + } catch (error) { + // Build response error information. Only return errors if flagged as public in order to prevent leaking internal behavior. + let code: string; + let message: string; + let metadata: unknown = undefined; + let internal = false; + if (error instanceof errors.ActorError && error.public) { + code = error.code; + message = String(error); + metadata = error.metadata; + } else { + code = errors.INTERNAL_ERROR_CODE; + message = errors.INTERNAL_ERROR_DESCRIPTION; + internal = true; + } + + // Build response + if (rpcRequestId !== undefined) { + conn._sendWebSocketMessage( + conn._serialize({ + body: { + re: { + i: rpcRequestId, + c: code, + m: message, + md: metadata, + }, + }, + } satisfies wsToClient.ToClient), + ); + } else { + conn._sendWebSocketMessage( + conn._serialize({ + body: { + er: { + c: code, + m: message, + md: metadata, + }, + }, + } satisfies wsToClient.ToClient), + ); + } + + handlers.onError({ + code, + message, + metadata, + rpcRequestId, + internal, + }); + } +} diff --git a/sdks/actor/runtime/src/inspect.ts b/sdks/actor/runtime/src/inspect.ts new file mode 100644 index 0000000000..38e5d8afbc --- /dev/null +++ b/sdks/actor/runtime/src/inspect.ts @@ -0,0 +1,194 @@ +import { safeStringify } from "@rivet-gg/actor-common/utils"; +import type * as wsToClient from "@rivet-gg/actor-protocol/ws/to_client"; +import type { Context } from "hono"; +import type { WSEvents } from "hono/ws"; +import type { AnyActor, ExtractActorState } from "./actor"; +import { mergeActorConfig } from "./config"; +import { Connection } from "./connection"; +import * as errors from "./errors"; +import { handleMessageEvent } from "./event"; +import { inspectLogger } from "./log"; +import type { Rpc } from "./rpc"; +import { throttle } from "./utils"; + +export const INSPECT_SYMBOL = Symbol("inspect"); + +export type ConnectionId = number; + +type SkipFirst = T extends [infer _, ...infer Rest] + ? Rest + : never; + +function connectionMap() { + let id: ConnectionId = 0; + const map = new Map>(); + return { + create: ( + ...params: SkipFirst>> + ) => { + const conId = id++; + const connection = new Connection(conId, ...params); + map.set(conId, connection); + return connection; + }, + delete: (conId: ConnectionId) => { + map.delete(conId); + }, + get: (conId: ConnectionId) => { + return map.get(conId); + }, + [Symbol.iterator]: () => map.values(), + get size() { + return map.size; + }, + }; +} + +interface InspectionAccessProxy { + connections: () => Iterable>; + state: () => { enabled: boolean; state: unknown }; + rpcs: () => string[]; + setState: (state: ExtractActorState) => Promise | void; + onRpcCall: (ctx: Rpc, rpc: string, args: unknown[]) => void; +} + +/** + * Thin compatibility layer for handling inspection access to an actor. + * @internal + */ +export class ActorInspection { + readonly #actor: A; + readonly #connections = connectionMap(); + readonly #proxy: InspectionAccessProxy; + + readonly #config = mergeActorConfig({}); + + readonly #logger = inspectLogger(); + + constructor(actor: A, proxy: InspectionAccessProxy) { + this.#actor = actor; + this.#proxy = proxy; + } + + readonly notifyStateChanged = throttle(async () => { + const inspectionResult = this.inspect(); + this.#broadcast("_state-changed", inspectionResult.state); + }, 500); + + readonly notifyConnectionsChanged = throttle(async () => { + const inspectionResult = this.inspect(); + this.#broadcast("_connections-changed", inspectionResult.connections); + }, 500); + + handleWebsocketConnection(c: Context): WSEvents { + // TODO: Compare hub version with protocol version + const protocolVersion = c.req.query("version"); + if (protocolVersion !== "1") { + this.#logger.warn("invalid protocol version", { + protocolVersion, + }); + throw new errors.InvalidProtocolVersion(protocolVersion); + } + + let connection: Connection | undefined; + return { + onOpen: (evt, ws) => { + connection = this.#connections.create( + ws, + "cbor", + undefined, + false, + ); + }, + onMessage: async (evt, ws) => { + if (!connection) { + this.#logger.warn("`connection` does not exist"); + return; + } + + await handleMessageEvent(evt, connection, this.#config, { + onExecuteRpc: async (ctx, name, args) => { + return await this.#executeRpc(ctx, name, args); + }, + onSubscribe: async () => { + // we do not support granular subscriptions + }, + onUnsubscribe: async () => { + // we do not support granular subscriptions + }, + onError: (error) => { + this.#logger.warn("connection error", { + rpc: error.rpcRequestId, + error, + }); + }, + }); + }, + onClose: () => { + if (!connection) { + this.#logger.warn("`connection` does not exist"); + return; + } + + this.#connections.delete(connection.id); + }, + onError: (error) => { + this.#logger.warn("inspect websocket error", { error }); + }, + }; + } + + #broadcast(event: string, ...args: unknown[]) { + if (this.#connections.size === 0) { + return; + } + + for (const connection of this.#connections) { + connection.send(event, ...args); + } + } + + /** + * Safely transforms the actor state into a string for debugging purposes. + */ + #inspectState(): string { + try { + return safeStringify(this.#proxy.state().state, 128 * 1024 * 1024); + } catch (error) { + return JSON.stringify({ _error: new errors.StateTooLarge() }); + } + } + + /** + * Public RPC method that inspects the actor's state and connections. + * @internal + * @returns The actor's state and connections. + */ + inspect(): wsToClient.InspectRpcResponse { + return { + // Filter out internal 'inspect' RPC + rpcs: this.#proxy.rpcs(), + state: { + enabled: this.#proxy.state().enabled, + native: this.#proxy.state().enabled ? this.#inspectState() : "", + }, + connections: [...this.#proxy.connections()].map((connection) => + connection[INSPECT_SYMBOL](), + ), + }; + } + + async #executeRpc(ctx: Rpc, name: string, args: unknown[]) { + if (name === "inspect") { + return this.inspect(); + } + + if (name === "setState") { + const state = args[0] as Record; + await this.#proxy.setState(state as ExtractActorState); + return; + } + + return this.#proxy.onRpcCall(ctx, name, args); + } +} diff --git a/sdks/actor/runtime/src/log.ts b/sdks/actor/runtime/src/log.ts index 70f964decb..91b470dc4a 100644 --- a/sdks/actor/runtime/src/log.ts +++ b/sdks/actor/runtime/src/log.ts @@ -6,6 +6,8 @@ export const RUNTIME_LOGGER_NAME = "actor-runtime"; /** Logger used for logs from the actor instance itself. */ export const ACTOR_LOGGER_NAME = "actor"; +export const INSPECT_LOGGER_NAME = "actor-inspect"; + export function logger() { return getLogger(RUNTIME_LOGGER_NAME); } @@ -13,3 +15,7 @@ export function logger() { export function instanceLogger() { return getLogger(ACTOR_LOGGER_NAME); } + +export function inspectLogger() { + return getLogger(INSPECT_LOGGER_NAME); +}