diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 19396a4aaa..4d7af2b238 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -36,7 +36,7 @@ jobs: - name: Setup bazel uses: jwlawson/actions-setup-bazel@v1 with: - bazel-version: '3.7.2' + bazel-version: "3.7.2" - uses: maxim-lobanov/setup-xcode@v1 if: runner.os == 'macOS' with: @@ -148,7 +148,7 @@ jobs: - name: Setup Rust run: rustup show - name: check @hotg-ai/rune - run: yarn install && yarn build && yarn test + run: yarn install && yarn ci working-directory: bindings/web/rune - name: check @hotg-ai/rune-tfjs-v3 run: yarn install && yarn build && yarn test @@ -159,4 +159,3 @@ jobs: - name: check @hotg-ai/rune-tflite run: yarn install && yarn build working-directory: bindings/web/tflite - diff --git a/.gitignore b/.gitignore index 3a2972c887..40223a01bd 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ examples/.DS_Store tarpaulin-report.html .ipynb_checkpoints/ build/ + +.parcel-cache/ diff --git a/bindings/web/rune/LICENSE_APACHE.md b/bindings/web/rune/LICENSE_APACHE.md index 1b5ec8b78e..038d25d682 100644 --- a/bindings/web/rune/LICENSE_APACHE.md +++ b/bindings/web/rune/LICENSE_APACHE.md @@ -92,33 +92,33 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION meet the following conditions: (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and + Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and + stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions diff --git a/bindings/web/rune/README.md b/bindings/web/rune/README.md index 278f85e91a..90f0e285b7 100644 --- a/bindings/web/rune/README.md +++ b/bindings/web/rune/README.md @@ -4,8 +4,8 @@ A package that lets you run Runes in the browser. ## Getting Started -The easiest way to get started is by following [*Lesson 4: Integrating With The -Browser*][lesson-4] from our tutorial series. +The easiest way to get started is by following [_Lesson 4: Integrating With The +Browser_][lesson-4] from our tutorial series. This will walk you through creating a React application which initializes the Rune runtime and executes it every time a button is pressed. @@ -25,16 +25,16 @@ import path (i.e. you import from `@hotg-ai/rune/builtins` instead of `@hotg-ai/rune/dist/builtins`). As a precaution, the `package.json` in this folder sets `"private": true` to -make sure you don't accidentally run `yarn publish` +make sure you don't accidentally run `yarn publish` ## License This project is licensed under either of - * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE.md) or - http://www.apache.org/licenses/LICENSE-2.0) - * MIT license ([LICENSE-MIT](LICENSE-MIT.md) or - http://opensource.org/licenses/MIT) +- Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE.md) or + http://www.apache.org/licenses/LICENSE-2.0) +- MIT license ([LICENSE-MIT](LICENSE-MIT.md) or + http://opensource.org/licenses/MIT) at your option. diff --git a/bindings/web/rune/package.json b/bindings/web/rune/package.json index 23a74a5d11..74530bb1ab 100644 --- a/bindings/web/rune/package.json +++ b/bindings/web/rune/package.json @@ -1,27 +1,43 @@ { "name": "@hotg-ai/rune", - "version": "0.11.8", + "version": "0.12.0-rc.1", "description": "Execute Runes inside a JavaScript environment.", "repository": "https://github.com/hotg-ai/rune", "homepage": "https://hotg.dev/", "author": "The Rune Developers ", "license": "MIT OR Apache-2.0", - "main": "index.js", - "types": "index.d.ts", - "private": true, + "source": "src/index.ts", + "module": "dist/index.js", + "types": "dist/index.d.ts", "scripts": { - "build": "tsc", - "watch": "tsc --watch", + "build": "parcel build", + "watch": "parcel watch", "test": "jest", + "ci": "tsc --noEmit && parcel build && yarn test", "fmt": "prettier --write .", - "release": "tsc && cd dist && cp ../*.md . && sed 's/\"private\": true,/\"private\": false,/g' ../package.json > package.json && yarn publish", "generate-runefile-types": "json2ts ../../../crates/compiler/runefile-schema.json --output src/Runefile.ts" }, - "dependencies": {}, + "files": [ + "dist/*", + "*.md", + "*.json", + "jest.config.ts" + ], + "dependencies": { + "@hotg-ai/rune": "^0.12.0-rc.0", + "@hotg-ai/rune-wit-files": "^0.3.1", + "js-yaml": "^4.1.0", + "jszip": "^3.9.1", + "pino": "^7.11.0" + }, "devDependencies": { + "@parcel/packager-ts": "2.6.0", + "@parcel/transformer-typescript-types": "^2.6.0", "@types/jest": "^27.0.0", + "@types/js-yaml": "^4.0.5", "jest": "^27.0.6", "json-schema-to-typescript": "^10.1.5", + "parcel": "^2.5.0", "prettier": "^2.5.1", "ts-jest": "^27.0.4", "ts-node": "^10.4.0", diff --git a/bindings/web/rune/src/Rune.ts b/bindings/web/rune/src/Rune.ts new file mode 100644 index 0000000000..d13bf119a8 --- /dev/null +++ b/bindings/web/rune/src/Rune.ts @@ -0,0 +1,40 @@ +import { Logger, pino } from "pino"; +import type { ModelHandler, Runtime } from "."; +import { RuneLoader } from "./RuneLoader"; + +/** + * A builder object that lets you configure how a Rune is loaded. + */ +export class Rune { + private modelHandlers: Record = {}; + private logger: Logger = pino({ level: "silent", enabled: false }); + + /** + * Set the logger that will be used during the loading process and by the + * Rune runtime. + */ + public withLogger(logger: Logger): this { + this.logger = logger; + return this; + } + + /** + * Register a model handler based on the "model-format" argument attached to + * a model node. + */ + public withModelHandler(modelType: string, handler: ModelHandler): this { + this.modelHandlers[modelType] = handler; + return this; + } + + /** + * Load the Rune, instantiating a Runtime that can be used to interact with + * it. + * + * @param rune + */ + public async load(rune: Uint8Array): Promise { + const loader = new RuneLoader(this.modelHandlers, this.logger); + return await loader.load(rune); + } +} diff --git a/bindings/web/rune/src/RuneLoader.ts b/bindings/web/rune/src/RuneLoader.ts new file mode 100644 index 0000000000..1880523815 --- /dev/null +++ b/bindings/web/rune/src/RuneLoader.ts @@ -0,0 +1,192 @@ +import JSZip from "jszip"; +import yaml from "js-yaml"; +import { Logger, pino } from "pino"; +import type { ModelHandler, Node, Runtime, Tensor } from "."; +import { + CapabilityStage, + DocumentV1, + ModelStage, + OutStage, + ProcBlockStage, +} from "./Runefile"; +import { ProcBlock } from "./proc_blocks"; +import { create } from "./Runtime"; +import { + isCapabilityStage, + isModelStage, + isOutStage, + isProcBlockStage, + isRunefile, + stageArguments, +} from "./utils"; + +export class RuneLoader { + logger: Logger; + + constructor( + private modelHandlers: Record, + private rootLogger: Logger + ) { + this.logger = rootLogger.child({ name: "RuneLoader" }); + } + + async load(rune: Uint8Array): Promise { + this.logger.info({ bytes: rune.byteLength }, "Loading the Rune"); + + const zip = new JSZip(); + await zip.loadAsync(rune); + const runefile = await this.parseRunefile(zip); + + const nodes = splitByStageType(runefile); + const procBlocks = await this.instantiateProcBlocks(nodes, zip); + const models = await this.loadModels(nodes.model, zip, this.modelHandlers); + + return await create(runefile, procBlocks, models, this.rootLogger); + } + + async parseRunefile(zip: JSZip): Promise { + const f = zip.file("Runefile.yml"); + if (!f) { + throw new Error("No Runefile.yml found"); + } + const src = await f.async("string"); + const runefile = yaml.load(src); + + if (!isRunefile(runefile)) { + throw new Error("Invalid Runefile"); + } + + this.logger.debug({ length: src.length }, "Parsed the Runefile"); + + return runefile; + } + + async instantiateProcBlocks( + stages: Stages, + zip: JSZip + ): Promise> { + const start = Date.now(); + + const entries = stagesBackedByProcBlocks(stages).map( + async ({ name, path }) => { + this.logger.debug({ procBlock: name, path }, "Reading proc-block"); + + const file = zip.file(path); + + if (!file) { + throw new Error(`The Rune doesn't contain "${path}"`); + } + + const data = await file.async("arraybuffer"); + const procBlock = await ProcBlock.load( + data, + this.rootLogger.child({ procBlock: name }) + ); + return [name, procBlock] as const; + } + ); + + const procBlocks = Object.fromEntries(await Promise.all(entries)); + + this.logger.debug( + { + count: Object.keys(procBlocks).length, + durationMs: Date.now() - start, + }, + "Finished instantiating all proc-blocks" + ); + + return procBlocks; + } + + async loadModels( + stages: Record, + zip: JSZip, + modelHandlers: Record + ): Promise> { + const start = Date.now(); + + const promises = Object.entries(stages).map(async ([name, stage]) => { + const format = stage.args?.["model-format"] || "tensorflow-lite"; + const filename = stage.model; + this.logger.debug({ model: name, format, filename }, "Loading model"); + + const file = zip.file(filename); + + if (!file) { + throw new Error(`The Rune doesn't contain "${filename}"`); + } + + if (!(format in modelHandlers)) { + throw new Error( + `No handler was registered for the "${format}" model on the "${name}" node` + ); + } + + const handler = modelHandlers[format]; + + const data = await file.async("arraybuffer"); + const model = await handler(data, stageArguments(stage), this.rootLogger); + + this.logger.debug( + { model: name, length: data.byteLength }, + "Loaded model" + ); + + return [name, model]; + }); + + const models = Object.fromEntries(await Promise.all(promises)); + + this.logger.debug( + { + count: Object.keys(models).length, + durationMs: Date.now() - start, + }, + "Finished instantiating all models" + ); + + return models; + } +} + +type Stages = { + capability: Record; + procBlock: Record; + model: Record; + out: Record; +}; + +function splitByStageType(runefile: DocumentV1): Stages { + const nodes: Stages = { capability: {}, procBlock: {}, model: {}, out: {} }; + + for (const [name, stage] of Object.entries(runefile.pipeline)) { + if (isProcBlockStage(stage)) { + nodes.procBlock[name] = stage; + } else if (isModelStage(stage)) { + nodes.model[name] = stage; + } else if (isCapabilityStage(stage)) { + nodes.capability[name] = stage; + } else if (isOutStage(stage)) { + nodes.out[name] = stage; + } + } + + return nodes; +} + +function stagesBackedByProcBlocks(stages: Stages) { + const procBlocks: Array<{ name: string; path: string }> = []; + + for (const [name, stage] of Object.entries(stages.procBlock)) { + const path = stage["proc-block"]; + procBlocks.push({ name, path }); + } + + for (const [name, stage] of Object.entries(stages.capability)) { + const path = stage.capability; + procBlocks.push({ name, path }); + } + + return procBlocks; +} diff --git a/bindings/web/rune/src/Runtime.test.ts b/bindings/web/rune/src/Runtime.test.ts index ce4fab193d..a7a66a7e8c 100644 --- a/bindings/web/rune/src/Runtime.test.ts +++ b/bindings/web/rune/src/Runtime.test.ts @@ -1,97 +1,114 @@ -import child_process from "child_process"; -import path from "path"; -import fs from "fs"; -import { Runtime, Capability, Output } from "./Runtime"; - -const decoder = new TextDecoder("utf8"); +import yaml from "js-yaml"; +import { TensorDescriptor, Tensors } from "./proc_blocks"; +import { DocumentV1 } from "./Runefile"; +import { Node } from "."; +import { Runtime, create } from "./Runtime"; +import { ElementType, Tensor } from "."; +import { floatTensor } from "./utils"; +import { testLogger } from "./__test__"; describe("Runtime", () => { - const noopRune = buildExample("noop"); - - it("can load the noop Rune", async () => { - const imports = { - createCapability: () => new RawCapability(), - createOutput: () => new SpyOutput([]), - createModel: () => { throw new Error(); }, - log: (msg: any) => { }, - }; - - const runtime = await Runtime.load(noopRune, imports); - - expect(runtime).not.toBeNull(); - }); - - it("can run the noop Rune", async () => { - const calls: Uint8Array[] = []; - const imports = { - createCapability: () => new RawCapability([ - 1, 0, 0, 0, - 2, 0, 0, 0, - 3, 0, 0, 0, - 4, 0, 0, 0, - ]), - createOutput: () => new SpyOutput(calls), - createModel: () => { throw new Error(); }, - log: (msg: any) => { }, - }; - const runtime = await Runtime.load(noopRune, imports); - - runtime.call(); - - expect(calls).toHaveLength(1); - const output = decoder.decode(calls[0]); - expect(JSON.parse(output)).toEqual({ - channel: 1, - dimensions: [4,], - elements: [1, 2, 3, 4,], - type_name: "i32", - }); + let logger = testLogger(); + + const src = ` + version: 1 + image: runicos/base + pipeline: + rand: + capability: RAW + outputs: + - type: F32 + dimensions: + - 1 + - 1 + args: + length: "4" + mod360: + proc-block: proc_blocks/mod360 + inputs: + - rand + outputs: + - type: F32 + dimensions: + - 1 + - 1 + args: + modulus: "360" + sine: + model: models/sine + inputs: + - mod360 + outputs: + - type: F32 + dimensions: + - 1 + - 1 + serial: + out: serial + inputs: + - sine + resources: {}`; + const runefile = yaml.load(src) as DocumentV1; + + const f32_1x1 = { + elementType: ElementType.F32, + dimensions: { tag: "fixed", val: Uint32Array.from([1]) }, + } as const; + + const rand = dummyProcBlock([], [{ name: "output", ...f32_1x1 }], { + output: floatTensor([1]), + }); + const mod360 = dummyProcBlock( + [{ name: "input", ...f32_1x1 }], + [{ name: "output", ...f32_1x1 }], + { output: floatTensor([2]) } + ); + + const sine = dummyNode( + [{ name: "input", ...f32_1x1 }], + [{ name: "output", ...f32_1x1 }], + { output: floatTensor([3]) } + ); + + it("can run the sine Rune", async () => { + const procBlocks = { rand, mod360 }; + const models = { sine }; + + const runtime: Runtime = create(runefile, procBlocks, models, logger); + + runtime.setInput("rand", floatTensor([0])); + + await runtime.infer(); + + const outputs = runtime.outputs; + expect(outputs).toMatchObject({ + serial: [floatTensor([3])], }); + }); }); -class RawCapability implements Capability { - data: Uint8Array = new Uint8Array(); - - constructor(data?: number[]) { - if (data) { - this.data = Uint8Array.from(data); - } - } - - setParameter(name: string, value: number): void { - throw new Error("Method not implemented."); - } - generate(dest: Uint8Array): void { - dest.set(this.data); - } -} - -class SpyOutput implements Output { - received: Uint8Array[]; - constructor(received: Uint8Array[]) { - this.received = received; - } - - consume(data: Uint8Array): void { - this.received.push(data); - } +function dummyProcBlock( + inputs: TensorDescriptor[], + outputs: TensorDescriptor[], + results: Record +) { + return { + graph: (): Tensors => { + return { inputs, outputs }; + }, + evaluate: () => results, + }; } -function buildExample(name: string): ArrayBuffer { - const gitOutput = child_process.execSync("git rev-parse --show-toplevel"); - const repoRoot = decoder.decode(gitOutput).trim(); - - const exampleDir = path.join(repoRoot, "examples", name); - const runefile = path.join(exampleDir, "Runefile.yml"); - - child_process.execSync(`cargo run --bin rune --quiet -- build ${runefile} --quiet --unstable --rune-repo-dir ${repoRoot}`, { - cwd: repoRoot, - env: { - RUST_LOG: "warning", - ...process.env - }, - }); - const rune = path.join(exampleDir, name + ".rune"); - - return fs.readFileSync(rune); +function dummyNode( + inputs: TensorDescriptor[], + outputs: TensorDescriptor[], + results: Record +): Node { + return { + graph: async (): Promise => { + return { inputs, outputs }; + }, + infer: () => Promise.resolve(results), + }; } diff --git a/bindings/web/rune/src/Runtime.ts b/bindings/web/rune/src/Runtime.ts index 53fc913dd6..fd05a608d4 100644 --- a/bindings/web/rune/src/Runtime.ts +++ b/bindings/web/rune/src/Runtime.ts @@ -1,373 +1,293 @@ -import Shape from "./Shape"; - -/** - * Something which consumes outputs generated by the Rune. - */ -export interface Output { - consume(data: Uint8Array): void; +import { Node } from "."; +import { runtime_v1 } from "@hotg-ai/rune-wit-files"; +import { TensorDescriptor, Tensors } from "./proc_blocks"; +import { DocumentV1, Stage } from "./Runefile"; +import { Tensor } from "."; +import { + isCapabilityStage, + isOutStage, + stageArguments, + stageInputs, +} from "./utils"; +import { Logger } from "pino"; + +type TensorId = number; +type NodeId = string; + +interface ProcBlockLike { + graph(args: Record): Tensors; + evaluate( + inputs: Record, + args: Record + ): Record; } -/** - * Inputs provided by the application. - */ -export interface Capability { - generate(dest: Uint8Array): void; - setParameter(name: string, value: number): void; +export async function create( + doc: DocumentV1, + procBlocks: Record, + models: Record, + logger: Logger +): Promise { + const pb = procBlockNodes(procBlocks); + const nodes: Record = { ...models, ...pb }; + + const { dependencies, evaluationOrder } = await getTensors( + doc.pipeline, + nodes + ); + return new Runtime(doc, nodes, dependencies, evaluationOrder, logger); } -/** - * Functions required by the Rune runtime. - */ -export interface Imports { - createOutput(type: number): Output; - createCapability(type: number): Capability; - createModel(mimetype: string, model: ArrayBuffer): Promise; - log(message: string | StructuredLogMessage): void; +type NodeDependencies = { + inputs: Record; + outputs: Record; +}; + +type Stuff = { + evaluationOrder: NodeId[]; + dependencies: Record; +}; + +function count(): () => number { + let n = 0; + return () => n++; } -/** - * Something which can run inference on a model. - */ -export interface Model { - transform( - inputArray: Uint8Array[], - inputDimensions: Shape[], - outputArray: Uint8Array[], - outputDimensions: Shape[], - ): void; +async function allGraphs( + pipeline: Record, + nodes: Record +): Promise> { + const promises = Object.entries(pipeline).map(async ([name, stage]) => { + const args = stageArguments(stage); + const node = nodes[name]; + const tensors = await node.graph(args); + return [name, tensors] as const; + }); + + return Object.fromEntries(await Promise.all(promises)); } -type TensorDescriptor = { - dimensions: string, -}; +async function getTensors( + pipeline: Record, + nodes: Record +): Promise { + let nextTensorId = count(); + const visited: NodeId[] = []; + const dependencies: Record = {}; + + const tensorConstraints = await allGraphs(pipeline, nodes); + + const inputs = Object.entries(pipeline) + .filter(([_, stage]) => isCapabilityStage(stage)) + .map(([name, _]) => name); + const toVisit = Object.entries(pipeline) + .filter(([_, stage]) => isOutStage(stage)) + .map(([name, _]) => name); + + // assume each capability node has 1 input + for (const name of inputs) { + const id = nextTensorId(); + dependencies[name] = { + inputs: { [`${name}.0`]: id }, + outputs: { [`${name}.0`]: id }, + }; + } -type ModelInfo = { - id: number, - modelSize: number, - inputs?: TensorDescriptor[], - outputs?: TensorDescriptor[], -}; + let node; -/** - * Public interface exposed by the WebAssembly module. - */ -interface Exports extends WebAssembly.Exports { - memory: WebAssembly.Memory; - _manifest(): void; - _call(capability_type: number, input_type: number, capability_index: number): void; + while ((node = toVisit.pop())) { + visited.push(node); + const stage = pipeline[node]; + + if (isCapabilityStage(stage) || isOutStage(stage)) { + continue; + } + + const outputs = tensorConstraints[node].outputs.map( + (desc) => [desc.name, nextTensorId()] as const + ); + dependencies[node] = { + inputs: {}, + outputs: Object.fromEntries(outputs), + }; + + stageInputs(stage) + .filter(({ node }) => !visited.includes(node)) + .forEach(({ node }) => toVisit.push(node)); + } + + for (const [stageName, stage] of Object.entries(pipeline)) { + const inputs = stageInputs(stage); + + inputs.forEach(({ node, index }, i) => { + const { name: previousTensorName } = + tensorConstraints[node].outputs[index]; + const { name: currentTensorName } = + tensorConstraints[stageName].inputs[i]; + dependencies[stageName].inputs[currentTensorName] = + dependencies[node].outputs[previousTensorName]; + }); + } + + visited.reverse(); + + return { + dependencies, + evaluationOrder: visited, + }; } export class Runtime { - instance: WebAssembly.Instance; + /** + * The tensors associated with each node. + */ + private tensors: Record = {}; + private logger: Logger; + + constructor( + private doc: DocumentV1, + private nodes: Record, + private dependencies: Record, + private evaluationOrder: NodeId[], + logger: Logger + ) { + this.logger = logger.child({ name: "Runtime" }); + } + + public async infer(): Promise { + this.logger.debug("Starting inference"); + const start = Date.now(); + + for (const name of this.evaluationOrder) { + this.evaluateNode(name); + } + + const durationMs = Date.now() - start; + this.logger.debug({ durationMs }, "Inference completed successfully"); + } - constructor(instance: WebAssembly.Instance) { - this.instance = instance; + public get inputs(): string[] { + const inputs: string[] = []; + + for (const [name, stage] of Object.entries(this.doc.pipeline)) { + // TODO: check for proc-blocks with no input tensors in the Runefile + if (isCapabilityStage(stage)) { + inputs.push(name); + } } - static async load(wasm: ArrayBuffer, imports: Imports) { - let memory: WebAssembly.Memory; + return inputs; + } - const { hostFunctions, finaliseModels } = importsToHostFunctions( - imports, - () => memory, - ); - const { instance } = await WebAssembly.instantiate(wasm, hostFunctions); - - const exports = instance.exports; - if (!isRuneExports(exports)) { - throw new Error("Invalid Rune exports"); - } - memory = exports.memory; - exports._manifest(); - - // now we've asked for all the models to be loaded, let's wait until - // they are done before continuing - await finaliseModels(); - return new Runtime(instance); + public setNodeInput(node: string, name: string, tensor: Tensor) { + if (!(node in this.dependencies)) { + throw new Error(); } - manifest() { - return this.exports._manifest(); + const { inputs } = this.dependencies[node]; + + if (!(name in inputs)) { + throw new Error(); } - call() { - this.exports._call(0, 0, 0); + const id = inputs[name]; + this.tensors[id] = tensor; + } + + private async evaluateNode(name: string) { + if (name in this.tensors) { + // already been evaluated + return; + } + + if (!(name in this.nodes)) { + throw new Error(`No "${name}" node registered`); + } + if (!(name in this.doc.pipeline)) { + throw new Error(`The Runefile doesn't contain a "${name}" node`); } - get exports() { - // Note: checked inside Runtime.load() and exports will never change. - const { exports } = this.instance; + const stage = this.doc.pipeline[name]; - if (isRuneExports(exports)) { - return exports; - } else { - throw Error(); - } + if (isOutStage(stage)) { + // output stages don't do anything. Note: we will be deleting output + // stages altogether. + return; } -} -type Dict = Partial>; + this.logger.debug({ node: name }, "Evaluating a node"); + const start = Date.now(); -/** - * Generate a bunch of host functions backed by the supplied @param imports. - */ -function importsToHostFunctions( - imports: Imports, - getMemory: () => WebAssembly.Memory, -) { - const memory = () => { - const m = getMemory(); - if (!m) - throw new Error("WebAssembly memory wasn't initialized"); - - return new Uint8Array(m.buffer); - }; + const node = this.nodes[name]; + const args = stageArguments(this.doc.pipeline[name]); - const ids = counter(); - const outputs: Dict = {}; - const capabilities: Dict = {}; - const pendingModels: Promise<[number, Model]>[] = []; - const models: Record = {}; - const modelsDescription: Record = {}; - const utf8 = new TextDecoder(); - const decoder = new TextDecoder("utf8"); - - // Annoyingly, this needs to be an object literal instead of a class. - const env = { - _debug(msg: number, len: number) { - const raw = memory().subarray(msg, msg + len); - const decoded = utf8.decode(raw); - const parsed = tryParseJSON(decoded); - - function tryParseJSON(input: string): any | undefined { - try { - return JSON.parse(input); - } catch { - return; - } - } - - if (isStructuredLogMessage(parsed)) { - imports.log(parsed); - - if (parsed.level == "ERROR") { - // Translate all errors inside the Rune into exceptions, - // aborting execution. - throw new Error(parsed.message); - } - } else { - imports.log(decoded); - } - }, - - request_output(type: number) { - const output = imports.createOutput(type); - const id = ids(); - - outputs[id] = output; - return id; - }, - - consume_output(id: number, buffer: number, len: number) { - const output = outputs[id]; - if (output) { - const data = memory().subarray(buffer, buffer + len); - output.consume(data); - } - else { - throw new Error("Invalid output"); - } - }, - - request_capability(type: number) { - const capability = imports.createCapability(type); - const id = ids(); - - capabilities[id] = capability; - return id; - }, - - request_capability_set_param(id: number, - keyPtr: number, - keyLength: number, - valuePtr: number, - valueLength: number, - valueType: number) { - const keyBytes = memory().subarray(keyPtr, keyPtr + keyLength); - const key = decoder.decode(keyBytes); - const bytes = memory().subarray(valuePtr, valuePtr + valueLength).slice(0); - const value = decodeValue(valueType, bytes); - - const capability = capabilities[id]; - - if (!capability) { - throw new Error(`Tried to set "${key}" to ${value} but capability ${id} doesn't exist`); - } - - capability.setParameter(key, value); - }, - - request_provider_response(buffer: number, len: number, id: number) { - const cap = capabilities[id]; - if (!cap) { - throw new Error("Invalid capability"); - } - const dest = memory().subarray(buffer, buffer + len); - - cap.generate(dest); - }, - - rune_model_load(mimetype: number, mimetype_len: number, model: number, model_len: number, input_descriptors: number, input_len: number, output_descriptors: number, output_len: number) { - const mime = decoder.decode(memory().subarray(mimetype, mimetype + mimetype_len)); - const model_data = memory().subarray(model, model + model_len); - - //inputs - let o = memory().subarray(input_descriptors, input_descriptors + 8 * input_len); - let inputs = []; - for (let i = 0; i < input_len; i++) { - const inputs_pointer = new Uint32Array(new Uint8Array([o[i * 8], o[i * 8 + 1], o[i * 8 + 2], o[i * 8 + 3]]).buffer)[0]; - const inputs_length = new Uint32Array(new Uint8Array([o[i * 8 + 4], o[i * 8 + 5], o[i * 8 + 6], o[i * 8 + 7]]).buffer)[0]; - const inputs_string = decoder.decode(memory().subarray(inputs_pointer, inputs_pointer + inputs_length)); - inputs.push({ "dimensions": inputs_string }); - } - //outputs - o = memory().subarray(output_descriptors, output_descriptors + 8 * output_len); - let outputs = []; - for (let i = 0; i < output_len; i++) { - const outputs_pointer = new Uint32Array(new Uint8Array([o[i * 8], o[i * 8 + 1], o[i * 8 + 2], o[i * 8 + 3]]).buffer)[0]; - const outputs_length = new Uint32Array(new Uint8Array([o[i * 8 + 4], o[i * 8 + 5], o[i * 8 + 6], o[i * 8 + 7]]).buffer)[0]; - const outputs_string = decoder.decode(memory().subarray(outputs_pointer, outputs_pointer + outputs_length)); - outputs.push({ "dimensions": outputs_string }); - } - - const pending = imports.createModel(mime, model_data); - const id = ids(); - - pendingModels.push(pending.then(model => [id, model])); - modelsDescription[id] = { id, inputs, outputs, "modelSize": model_len }; - return id; - }, - - async rune_model_infer(id: number, inputs: number, outputs: number) { - const model = models[id]; - let modelsDes = modelsDescription[id]; - - let inputArray = []; - let inputDimensions = []; - - for (let i = 0; i < modelsDes!.inputs!.length; i++) { - let dimensions = Shape.parse(modelsDes!.inputs![i].dimensions); - - let o = memory().subarray(inputs + i * 4, inputs + 4 + i * 4); - const pointer = new Uint32Array(new Uint8Array([o[0], o[1], o[2], o[3]]).buffer)[0]; - inputArray.push(memory().subarray(pointer, pointer + dimensions.byteSize)); - inputDimensions.push(dimensions); - } - - let outputArray = []; - let outputDimensions = []; - for (let i = 0; i < modelsDes!.outputs!.length; i++) { - let dimensions = Shape.parse(modelsDes!.outputs![i].dimensions); - let o = memory().subarray(outputs + i * 4, outputs + 4 + i * 4); - const pointer = new Uint32Array(new Uint8Array([o[0], o[1], o[2], o[3]]).buffer)[0]; - outputArray.push(memory().subarray(pointer, pointer + dimensions.byteSize)); - outputDimensions.push(dimensions); - } - model.transform(inputArray, inputDimensions, outputArray, outputDimensions); - return id; - }, - - tfm_model_invoke(id: number, inputPtr: number, inputLen: number, outputPtr: number, outputLen: number) { - deprecated("tfm_model_invoke()", "0.5"); - }, - tfm_preload_model(data: number, len: number, numInputs: number, numOutputs: number) { - deprecated("tfm_preload_model()", "0.5"); - }, - }; + const inputs = this.nodeInputs(name); - async function synchroniseModelLoading() { - const loadedModels = await Promise.all(pendingModels); - pendingModels.length = 0; - loadedModels.forEach(([id, model]) => { - models[id] = model; - }); - } - return { - hostFunctions: { env }, - finaliseModels: synchroniseModelLoading, - }; -} + const outputs = await node.infer(inputs, args); -function counter() { - let value = 0; - return () => { value++; return value - 1; }; -} + this.setNodeOutputs(name, outputs); -function isRuneExports(obj: any): obj is Exports { - return (obj && - obj.memory instanceof WebAssembly.Memory && - obj._call instanceof Function && - obj._manifest instanceof Function); -} + const durationMs = Date.now() - start; + this.logger.debug({ durationMs, node: name }, "Node evaluated"); + this.logger.trace({ inputs, outputs, args, node: name }); + } -export function isStructuredLogMessage(obj?: any): obj is StructuredLogMessage { - return obj - && typeof obj.level == 'string' - && typeof obj.message == 'string' - && typeof obj.target == 'string' - && typeof obj.module_path == 'string' - && typeof obj.file == 'string' - && typeof obj.line == 'number'; -} + private setNodeOutputs( + name: string, + outputs: Record + ) { + for (const [tensorName, id] of Object.entries( + this.dependencies[name].outputs + )) { + this.tensors[id] = outputs[tensorName]; + } + } -export type StructuredLogMessage = { - level: string, - message: string, - target: string, - module_path: string, - file: string, - line: number, -}; + private nodeInputs(node: string): Record { + const tensors: Record = {}; -interface TypedArray extends ArrayBuffer { - readonly buffer: ArrayBuffer; -} + for (const [name, id] of Object.entries(this.dependencies[node].inputs)) { + if (!(id in this.tensors)) { + throw new Error( + `The "${node}" node requires tensor ${id}, but it hasn't been set` + ); + } -//this function can convert any TypedArray to any other kind of TypedArray : -function convertTypedArray(src: TypedArray, constructor: any): T { - // Instantiate a buffer (zeroed out) and copy the bytes from "src" into it. - const buffer = new constructor(src.byteLength); - buffer.set(src.buffer); - return buffer[0] as T; + tensors[name] = this.tensors[id]; + } + + return tensors; + } } +function procBlockNodes( + procBlocks: Record +): Record { + const nodes: Record = {}; -function deprecated(feature: string, version: string) { - throw new Error(`This runtime no longer supports Runes using "${feature}". Please rebuild with Rune ${version}`); -} + for (const [name, procBlock] of Object.entries(procBlocks)) { + nodes[name] = new ProcBlockNode(procBlock); + } -function decodeValue(valueType: number, raw: Uint8Array): number { - const { buffer, byteOffset, byteLength } = raw; - const bytes = buffer.slice(byteOffset, byteOffset + byteLength); - - switch (valueType) { - case 1: - const i32s = new Int32Array(bytes); - return i32s[0]; - case 2: - const f32s = new Float32Array(bytes); - return f32s[0]; - case 5: - return raw[0]; - case 6: - const i16s = new Int16Array(bytes); - return i16s[0]; - case 7: - const i8s = new Int8Array(bytes); - return i8s[0]; - - default: - throw new Error(`Unknown value type, ${valueType}, with binary representation, ${raw}`); - } + return nodes; } +/** + * An adapter class that makes each ProcBlock method asynchronous. + */ +class ProcBlockNode implements Node { + constructor(private procBlock: ProcBlockLike) {} + + graph(args: Record): Promise { + const tensors = this.procBlock.graph(args); + return Promise.resolve(tensors); + } + + infer( + inputs: Record, + args: Record + ): Promise> { + const outputs = this.procBlock.evaluate(inputs, args); + return Promise.resolve(outputs); + } +} diff --git a/bindings/web/rune/src/Shape.test.ts b/bindings/web/rune/src/Shape.test.ts deleted file mode 100644 index fca966378f..0000000000 --- a/bindings/web/rune/src/Shape.test.ts +++ /dev/null @@ -1,20 +0,0 @@ -import Shape from "./Shape"; - -describe("Shape", () => { - it("can parse u8[1, 2,3]", () => { - const text = "u8[1, 2,3]"; - - const got = Shape.parse(text); - - expect(got).toEqual(new Shape("u8", [1, 2, 3])); - }); - - const knownShapes = ["u8[1]", "f32[2, 4, 6, 8]"] - - test.each(knownShapes)(`can round-trip %p`, input => { - const parsed = Shape.parse(input); - const stringified = parsed.toString(); - - expect(stringified).toEqual(input); - }); -}); diff --git a/bindings/web/rune/src/Shape.ts b/bindings/web/rune/src/Shape.ts deleted file mode 100644 index e4e296e29b..0000000000 --- a/bindings/web/rune/src/Shape.ts +++ /dev/null @@ -1,90 +0,0 @@ - -/** - * A description of a tensor. - */ -export default class Shape { - static ByteSize = { - "f64": 8, - "i64": 8, - "u64": 8, - "f32": 4, - "i32": 4, - "u32": 4, - "u16": 2, - "i16": 2, - "u8": 1, - "i8": 1 - } as const; - - /** - * The element type. - */ - readonly type: string; - /** - * The tensor's dimensions. - */ - readonly dimensions: readonly number[]; - - constructor(type: string, values: number[]) { - this.type = type; - this.dimensions = [...values]; - } - - /** - * Parse a string like "u8[1, 2, 3]" into a Shape. - */ - static parse(text: string): Shape { - const pattern = /^([\w\d]+)\[(\d+(?:,\s*\d+)*)\]$/; - const match = pattern.exec(text.replace(" ", "")); - - if (!match) { - throw new Error(); - } - - const [_, typeName, dims] = match; - - checkElementType(typeName, text); - - return new Shape(typeName, dims.split(",").map(d => parseInt(d.trim()))); - } - - /** - * The number of dimensions this tensor has. - */ - get rank(): number { - return this.dimensions.length; - } - - /** - * The number of elements in this tensor. - */ - get tensorSize(): number { - return this.dimensions.reduce((product, dim) => product * dim, 1); - } - - /** - * The number of bytes used to store this tensor's elements. - */ - get byteSize(): number { - const sizes: Record = Shape.ByteSize; - const elementSize = sizes[this.type] || 1; - return this.tensorSize * elementSize; - } - - toString(): string { - const { type, dimensions } = this; - const dims = dimensions.join(", "); - return `${type}[${dims}]`; - } -} - - -function checkElementType(typeName: string, input: string) { - const knownElements = Object.keys(Shape.ByteSize); - - if (typeName in Shape.ByteSize) { - return; - } - - console.warn(`The "${typeName}" in "${input}" isn't one of the known element types (${knownElements})`); -} diff --git a/bindings/web/rune/src/Tensor.test.ts b/bindings/web/rune/src/Tensor.test.ts deleted file mode 100644 index 892c858d63..0000000000 --- a/bindings/web/rune/src/Tensor.test.ts +++ /dev/null @@ -1,43 +0,0 @@ -import { Shape, Tensor } from "."; - -describe("Tensor", () => { - it("can be round tripped as a Uint8Array", () => { - const shape = new Shape("u8", [2, 3]); - const tensor = new Tensor(shape, new Uint8Array([1, 2, 3, 4, 5, 6])); - - const typed = tensor.asTypedArray("u8"); - - expect(Array.from(typed)).toEqual([1, 2, 3, 4, 5, 6]); - }); - - it("can be viewed as a Float32Array", () => { - const raw = new Uint8Array([0, 0, 64, 64]); - const shape = new Shape("f32", [1]); - - const tensor = new Tensor(shape, raw); - const typed = tensor.asTypedArray("f32"); - - expect(Array.from(typed)).toEqual([3.0]); - }); - - it("can be a slice from a larger buffer", () => { - const numbers = [1, 2, 3, 4, 5, 6, 7, 8]; - const buffer = new Float32Array(numbers); - const section = buffer.subarray(3, 6); - const shape = new Shape("f32", [3]); - - const tensor = new Tensor(shape, new Uint8Array(section.buffer, section.byteOffset, section.byteLength)); - const typed = tensor.asTypedArray("f32"); - - expect(Array.from(typed)).toEqual(numbers.slice(3, 6)); - }); - - it("can be constructed from a typed array", () => { - const values = [1, 2, 3, 4, -5, -6]; - const raw = new Int16Array(values); - - const tensor = Tensor.fromTypedArray("i16", [6], raw); - - expect(Array.from(tensor.asTypedArray("i16"))).toEqual(values); - }); -}); diff --git a/bindings/web/rune/src/Tensor.ts b/bindings/web/rune/src/Tensor.ts deleted file mode 100644 index 86cb53fa55..0000000000 --- a/bindings/web/rune/src/Tensor.ts +++ /dev/null @@ -1,151 +0,0 @@ -import Shape from "./Shape"; - -// Some versions of Safari doesn't support BigUint64Array and friends, and -// it's not possible to polyfill these types because bigint is a builtin type. -// -// This workaround lets us use them when possible and throws an exception at -// runtime when they aren't. -const BigUint64ArrayShim = global.BigUint64Array ?? class { constructor() { throw new Error("BigUint64Array is not supported on this device"); } }; -const BigInt64ArrayShim = global.BigInt64Array ?? class { constructor() { throw new Error("BigInt64Array is not supported on this device"); } }; - -const typedArrayConstructors = { - "f64": Float64Array, - "i64": BigInt64ArrayShim, - "u64": BigUint64ArrayShim, - "f32": Float32Array, - "i32": Int32Array, - "u32": Uint32Array, - "u16": Uint16Array, - "i16": Int16Array, - "u8": Uint8ClampedArray, - "i8": Int8Array, -} as const; - -type TypedArrayConstructors = typeof typedArrayConstructors; - -export type TypedArrays = { - [Key in keyof TypedArrayConstructors]: InstanceType; -} - -/** - * An opaque tensor. - */ -export default class Tensor { - /** - * The raw bytes containing the tensor data. - */ - public readonly elements: Uint8Array; - /** - * The tensor's shape (element type and dimensions). - */ - public readonly shape: Shape; - - constructor(shape: Shape, elements: Uint8Array) { - this.shape = shape; - this.elements = elements; - } - - /** - * Construct a new Tensor from a typed array containing its flattened - * elements in row-major order. - * - * @param elementType The type of the element - * @param dimensions The tensor's dimensions - * @param elements The elements - * @returns - */ - public static fromTypedArray( - elementType: S, - dimensions: readonly number[], - elements: TypedArrays[S], - ): Tensor { - const { buffer, byteLength, byteOffset } = elements; - const shape = new Shape(elementType, [...dimensions]); - return new Tensor(shape, new Uint8Array(buffer, byteOffset, byteLength)); - } - - /** - * View this tensor's data as an array of 64-bit floats. - * - * This will fail if this isn't a f64 tensor. - */ - public asTypedArray(elementType: "f64"): Float64Array; - /** - * View this tensor's data as an array of 64-bit signed integers. - * - * This will fail if this isn't a i64 tensor. It may also fail on - * versions of Safari because they don't support BigInt64Array. - */ - public asTypedArray(elementType: "i64"): BigInt64Array; - /** - * View this tensor's data as an array of 64-bit unsigned integers. - * - * This will fail if this isn't a u64 tensor. It may also fail on - * versions of Safari because they don't support BigUint64Array. - */ - public asTypedArray(elementType: "u64"): BigUint64Array; - /** - * View this tensor's data as an array of 32-bit floats. - * - * This will fail if this isn't a f32 tensor. - */ - public asTypedArray(elementType: "f32"): Float32Array; - /** - * View this tensor's data as an array of 32-bit signed integers. - * - * This will fail if this isn't a i32 tensor. - */ - public asTypedArray(elementType: "i32"): Int32Array; - /** - * View this tensor's data as an array of 32-bit unsigned integers. - * - * This will fail if this isn't a u32 tensor. - */ - public asTypedArray(elementType: "u32"): Uint32Array; - /** - * View this tensor's data as an array of 16-bit signed integers. - * - * This will fail if this isn't a i16 tensor. - */ - public asTypedArray(elementType: "i16"): Int16Array; - /** - * View this tensor's data as an array of 16-bit unsigned integers. - * - * This will fail if this isn't a u16 tensor. - */ - public asTypedArray(elementType: "u16"): Uint16Array; - /** - * View this tensor's data as an array of 8-bit signed integers. - * - * This will fail if this isn't a i8 tensor. - */ - public asTypedArray(elementType: "i8"): Int8Array; - /** - * View this tensor's data as an array of 8-bit unsigned integers. - * - * This will fail if this isn't a u8 tensor. - */ - public asTypedArray(elementType: "u8"): Uint8ClampedArray; - - public asTypedArray(elementType: keyof typeof typedArrayConstructors): ArrayBuffer { - if (this.shape.type != elementType) { - throw new Error(`Attempting to interpret a ${this.shape.toString()} as a ${elementType} tensor`); - } - - const { buffer, byteOffset, byteLength } = this.elements; - const length = byteLength / Shape.ByteSize[this.shape.type]; - const constructor = typedArrayConstructors[elementType]; - - return new constructor(buffer, byteOffset, length); - } - - public get elementType(): string { - return this.shape.type; - } - - public get dimensions(): readonly number[] { - return this.shape.dimensions; - } -} - -const x = Tensor.fromTypedArray diff --git a/bindings/web/rune/src/__fixtures__/sine.zip b/bindings/web/rune/src/__fixtures__/sine.zip new file mode 100644 index 0000000000..775334e60d Binary files /dev/null and b/bindings/web/rune/src/__fixtures__/sine.zip differ diff --git a/bindings/web/rune/src/__test__/index.ts b/bindings/web/rune/src/__test__/index.ts new file mode 100644 index 0000000000..d2faf9f5c4 --- /dev/null +++ b/bindings/web/rune/src/__test__/index.ts @@ -0,0 +1,175 @@ +import pino, { Logger } from "pino"; +import { ElementType, Tensor } from ".."; +import { isTensor } from "../utils"; + +function stringArray(buffer: ArrayBuffer, byteOffset: number, length: number) { + const reader = new DataView(buffer, byteOffset, length); + const decoder = new TextDecoder(); + const strings: string[] = []; + + let offset = 0; + while (offset < reader.byteLength) { + const length = reader.getUint32(offset, true); + const utf8 = new Uint8Array(buffer, byteOffset + offset, length); + strings.push(decoder.decode(utf8)); + offset += 4 + length; + } + + return strings; +} + +function typedArray(constructor: { + new ( + buffer: ArrayBufferLike, + byteOffset: number, + length: number + ): ArrayLike; + readonly BYTES_PER_ELEMENT: number; +}): (b: ArrayBuffer, off: number, len: number) => T[] { + return (b, off, len) => + Array.from(new constructor(b, off, len / constructor.BYTES_PER_ELEMENT)); +} + +type NumericTensor = { + elementType: "u8" | "i8" | "u16" | "i16" | "u32" | "i32" | "f32" | "f64"; + dimensions: number[]; + elements: number[]; +}; + +type StringTensor = { + elementType: "utf8"; + dimensions: number[]; + elements: string[]; +}; + +type BigIntTensor = { + elementType: "u64" | "i64"; + dimensions: number[]; + elements: bigint[]; +}; + +type FormattedTensor = NumericTensor | StringTensor | BigIntTensor; + +export function formatTensor(tensor: Tensor): FormattedTensor { + const dimensions = Array.from(tensor.dimensions); + + const { buffer, byteOffset, byteLength } = tensor.buffer; + + switch (tensor.elementType) { + case ElementType.U8: + return { + elementType: "u8", + dimensions, + elements: typedArray(Uint8Array)(buffer, byteOffset, byteLength), + }; + case ElementType.I8: + return { + elementType: "i8", + dimensions, + elements: typedArray(Int8Array)(buffer, byteOffset, byteLength), + }; + case ElementType.U16: + return { + elementType: "u16", + dimensions, + elements: typedArray(Uint16Array)(buffer, byteOffset, byteLength), + }; + case ElementType.I16: + return { + elementType: "i16", + dimensions, + elements: typedArray(Int16Array)(buffer, byteOffset, byteLength), + }; + case ElementType.U32: + return { + elementType: "u32", + dimensions, + elements: typedArray(Uint32Array)(buffer, byteOffset, byteLength), + }; + case ElementType.I32: + return { + elementType: "i32", + dimensions, + elements: typedArray(Int32Array)(buffer, byteOffset, byteLength), + }; + case ElementType.F32: + return { + elementType: "f32", + dimensions, + elements: typedArray(Float32Array)(buffer, byteOffset, byteLength), + }; + case ElementType.F64: + return { + elementType: "f64", + dimensions, + elements: typedArray(Float64Array)(buffer, byteOffset, byteLength), + }; + case ElementType.U64: + return { + elementType: "u64", + dimensions, + elements: typedArray(BigUint64Array)(buffer, byteOffset, byteLength), + }; + case ElementType.I64: + return { + elementType: "i64", + dimensions, + elements: typedArray(BigInt64Array)(buffer, byteOffset, byteLength), + }; + case ElementType.Utf8: + return { + elementType: "utf8", + dimensions, + elements: stringArray(buffer, byteOffset, byteLength), + }; + } +} + +export function testLogger(): Logger { + const bindings = () => { + const { currentTestName } = expect.getState(); + return { test: currentTestName }; + }; + + const humanReadableTensors = (object: any): any => { + if (isTensor(object)) { + return formatTensor(object); + } + + if (typeof object == "object") { + const formatted: any = {}; + + for (const key in object) { + formatted[key] = humanReadableTensors(object[key]); + } + + return formatted; + } + + return object; + }; + + const logger = pino({ + level: "trace", + nestedKey: "payload", + timestamp: false, + formatters: { + bindings, + log: humanReadableTensors, + level: (label) => ({ level: label }), + }, + }); + + beforeEach(() => { + const { currentTestName } = expect.getState(); + logger.info({ test: currentTestName }, "Starting Test"); + }); + + afterEach(() => { + const { currentTestName } = expect.getState(); + logger.info({ test: currentTestName }, "Completed Test"); + logger.flush(); + }); + + return logger; +} diff --git a/bindings/web/rune/src/builtin/RandomCapability.ts b/bindings/web/rune/src/builtin/RandomCapability.ts deleted file mode 100644 index 1fbef661b3..0000000000 --- a/bindings/web/rune/src/builtin/RandomCapability.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { Capability } from "../Runtime"; - -export class RandomCapability implements Capability { - setParameter(name: string, value: number): void { - // Note: we don't have any configurable settings - } - - generate(dest: Uint8Array): void { - window.crypto.getRandomValues(dest); - } -} diff --git a/bindings/web/rune/src/builtin/WebcamCapability.ts b/bindings/web/rune/src/builtin/WebcamCapability.ts deleted file mode 100644 index e592cb5f54..0000000000 --- a/bindings/web/rune/src/builtin/WebcamCapability.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { Capability } from "../Runtime"; - -type Properties = { - width: number, - height: number, -}; - -export class WebcamCapability implements Capability { - lastImage: any; - properties: Properties = { - width: 320, - height: 320, - }; - - generate(dest: Uint8Array): void { - // TODO: Figure out how to read from the webcam. - throw new Error("Method not implemented."); - } - - setParameter(name: string, value: number): void { - const properties: Record = this.properties; - properties[name] = value; - } -} diff --git a/bindings/web/rune/src/builtin/index.ts b/bindings/web/rune/src/builtin/index.ts deleted file mode 100644 index 81408585f9..0000000000 --- a/bindings/web/rune/src/builtin/index.ts +++ /dev/null @@ -1,12 +0,0 @@ -export { RandomCapability } from "./RandomCapability"; -export { WebcamCapability } from "./WebcamCapability"; - -/** - * Mimetypes for the model formats known by rune. - */ -export const mimetypes = { - tflite: "application/tflite-model", - tensorflow: "application/tf-model", - tfjs: "application/tfjs-model", - onnx: "application/onnx-model", -} as const; diff --git a/bindings/web/rune/src/facade.ts b/bindings/web/rune/src/facade.ts deleted file mode 100644 index c4c3b070ee..0000000000 --- a/bindings/web/rune/src/facade.ts +++ /dev/null @@ -1,313 +0,0 @@ -import { Capabilities, CapabilityType, Outputs, Shape } from "."; -import { Capability, Imports, Model, Output, Runtime, StructuredLogMessage } from "./Runtime"; -import Tensor from "./Tensor"; - -type ModelConstructor = (model: ArrayBuffer) => Promise; -type Logger = (message: string | StructuredLogMessage) => void; - -export type InputDescription = { - type: CapabilityType, - args: Partial>, -}; - -/** - * A function that returns the desired input, either as a tensor or the raw - * byte buffer. - */ -export type ReadInput = (input: InputDescription) => Tensor; - -/** - * A function which can be used to evaluate a Rune. - */ -export type Evaluate = (r: ReadInput) => Result; - -/** - * A builder object which can be used to initialize the Rune runtime. - */ -export class Builder { - private modelHandlers: Partial> = {}; - private log: Logger = () => { }; - - /** - * Set a handler that will be called every time the Rune logs a message. - */ - public onDebug(handler: Logger): this { - this.log = handler; - return this; - } - - /** - * Add support for a new type of model. - * @param mimetype The "mimetype" that specifies which type of model being - * handled. - * @param constructor A constructor which will load the model. - * @returns - */ - public withModelHandler(mimetype: string, constructor: ModelConstructor): this { - this.modelHandlers[mimetype] = constructor; - - return this; - } - - public async build(rune: ArrayBuffer | string): Promise { - if (typeof rune == "string") { - const response = await fetch(rune); - rune = await response.arrayBuffer(); - } - const { modelHandlers, log } = this; - - const imports = new ImportsObject(modelHandlers, log); - let runtime: Runtime | undefined = await Runtime.load(rune, imports); - - return readInputs => { - if (!runtime) { - throw new Error("A previous call to this Rune has failed, leaving it in an invalid state"); - } - - imports.setInputs(readInputs); - - try { - runtime.call(); - } catch (e) { - // We encountered an error while invoking the Rune, typically by - // throwing an exception from one of our host functions. JS - // exceptions abort execution without unwinding the - // WebAssembly/Rust stack so we need to assume the runtime is - // FUBAR. - runtime = undefined; - throw e; - } - - let outputs = [...imports.outputs]; - imports.outputs.length = 0; - - return { outputs }; - }; - } -} - -export type Result = { - outputs: OutputValue[], -}; - -/** - * A tensor value generated by the SERIAL output. - */ -export type OutputValue = { - /** - * An integer specifying which SERIAL output this is attached to. - */ - channel: number, - /** - * The tensor's dimensions. - */ - dimensions: number[], - /** - * The elements in this tensor, flattened into a single array in row-major - * order. - */ - elements: string[] | number[], - /** - * The Rust name for this tensor's element type. - */ - type_name: string, -} - -class ImportsObject implements Imports { - private decoder = new TextDecoder("utf8"); - outputs: Array = []; - private modelHandlers: Partial>; - private logger: Logger; - private capabilities: LazyCapability[] = []; - - constructor( - modelHandlers: Partial>, - logger: Logger, - ) { - this.modelHandlers = modelHandlers; - this.logger = logger; - } - - setInputs(readInput: ReadInput) { - const inputs = this.capabilities.map(c => c.description()).map(readInput); - - for (let i = 0; i < this.capabilities.length; i++) { - this.capabilities[i].value = inputs[i]; - } - } - - createOutput(type: number): Output { - const { decoder, outputs } = this; - switch (type) { - case Outputs.tensor: - return tensorOutput(decoder, outputs); - case Outputs.serial: - return serialOutput(decoder, outputs); - - default: - throw new Error(`Unsupported output type: ${type}`); - } - } - - createCapability(type: number): Capability { - const pair = Object.entries(Capabilities).find(pair => pair[1] == type); - if (!pair) { - throw new Error(`Unable to handle capability number ${type}`); - } - - const capabilityType = pair[0]; - const cap = new LazyCapability(capabilityType as CapabilityType); - this.capabilities.push(cap); - - return cap; - } - - createModel(mimetype: string, model: ArrayBuffer): Promise { - const handler = this.modelHandlers[mimetype]; - - if (!handler) { - throw new Error(`No handler registered for "${mimetype}" models`); - } - - return handler(model); - } - - log(message: string | StructuredLogMessage): void { - this.logger(message); - } -} - -function tensorOutput(decoder: TextDecoder, outputs: Array): Output { - return { - consume: ({ buffer, byteLength, byteOffset }: Uint8Array) => { - const shapeLength = new Uint32Array(buffer, byteOffset, 1)[0]; - const shapeBytes = new Uint8Array(buffer, byteOffset + 4, shapeLength); - const shape = Shape.parse(decoder.decode(shapeBytes)); - const { type, dimensions } = shape; - const elements = new Uint8Array(buffer, byteOffset + 4 + shapeLength, byteLength - 4 - shapeLength); - const tensor = new Tensor(shape, elements); - outputs.push({ - channel: -1, - dimensions: [...dimensions], - type_name: type, - elements: tensorAsNumberArray(tensor), - }) - } - } -} - -function tensorAsNumberArray(tensor: Tensor): number[] { - const { elementType } = tensor; - - switch (elementType) { - case "f32": - const floats = tensor.asTypedArray(elementType); - return Array.from(floats); - case "u8": - const u8s = tensor.asTypedArray(elementType); - return Array.from(u8s); - case "u16": - const u16s = tensor.asTypedArray(elementType); - return Array.from(u16s); - case "u32": - const u32s = tensor.asTypedArray(elementType); - return Array.from(u32s); - case "i8": - const i8s = tensor.asTypedArray(elementType); - return Array.from(i8s); - case "i16": - const i16s = tensor.asTypedArray(elementType); - return Array.from(i16s); - case "i32": - const i32s = tensor.asTypedArray(elementType); - return Array.from(i32s); - - default: - throw new Error( - `Unable to convert a ${tensor.shape.toString()} to a list of numbers` - ); - } -} - -function serialOutput(decoder: TextDecoder, outputs: Array): Output { - // We want the end user to receive all outputs as a return value, but - // Runes are designed using a callback-based API (it's better for - // performance). This will create an output which will stash all - // generated values away in a list so they can be returned at the end. - - return { - consume(data: Uint8Array) { - const json = decoder.decode(data); - const deserialized = JSON.parse(json); - - if (isOutputValue(deserialized)) { - outputs.push(deserialized); - } else if (Array.isArray(deserialized) && deserialized.every(isOutputValue)) { - outputs.push(...deserialized); - } else { - throw new SerialDeserializeError(json, deserialized); - } - } - } -} - -function isNumberArray(value: any): value is number[] { - return Array.isArray(value) && value.every(v => typeof v === "number"); -} - -function isStringArray(value: any): value is string[] { - return Array.isArray(value) && value.every(v => typeof v === "string"); -} - -function isOutputValue(value?: any): value is OutputValue { - if (!value) { - return false; - } - - const { channel, dimensions, elements, type_name } = value; - - return typeof channel === "number" && - isNumberArray(dimensions) && - (isNumberArray(elements) || isStringArray(elements)) && - typeof type_name === "string"; -} - -class LazyCapability implements Capability { - type: CapabilityType; - value?: Tensor; - args: Record = {}; - - constructor(type: CapabilityType) { - this.type = type; - } - - description(): InputDescription { - return { - type: this.type, - args: this.args, - }; - } - - generate(dest: Uint8Array): void { - if (!this.value) { - throw new Error(); - } - - dest.set(this.value.elements); - } - - setParameter(name: string, value: number): void { - this.args[name] = value; - } -} - -class SerialDeserializeError extends Error { - readonly json: string; - readonly deserialized?: any; - - constructor(json: string, deserialized: any | undefined) { - super("Unable to deserialize the SERIAL output"); - this.json = json; - this.deserialized = deserialized; - } -} diff --git a/bindings/web/rune/src/index.test.ts b/bindings/web/rune/src/index.test.ts new file mode 100644 index 0000000000..f62729a4aa --- /dev/null +++ b/bindings/web/rune/src/index.test.ts @@ -0,0 +1,89 @@ +import fs from "fs"; +import path from "path"; +import { Node, ElementType, Tensor, Rune } from "."; +import { Tensors } from "./proc_blocks"; +import { floatTensor } from "./utils"; +import { testLogger } from "./__test__"; + +describe("Integration Tests", () => { + let logger = testLogger(); + + const sine = new Uint8Array( + fs.readFileSync(path.join(__dirname, "__fixtures__", "sine.zip")) + ); + + it("can load the sine Rune", async () => { + const loader = new Rune(); + + const runtime = await loader + .withModelHandler("tensorflow-lite", async () => new DummySineModel()) + .withLogger(logger) + .load(sine); + + runtime.setInput("rand", floatTensor([1])); + + await runtime.infer(); + + expect(runtime.outputs).toEqual({ + serial: [floatTensor([Math.sin(1)])], + }); + }); +}); + +/** + * A "model" that executes sine() against each element in the tensor it is + * given. + */ +class DummySineModel implements Node { + async graph(): Promise { + const tensor = { + elementType: ElementType.F32, + dimensions: { + tag: "fixed", + val: Uint32Array.from([1, 1]), + }, + } as const; + + return { + inputs: [{ ...tensor, name: "input" }], + outputs: [{ ...tensor, name: "output" }], + }; + } + + async infer( + inputs: Record, + args: Record + ): Promise> { + const { + input: { + buffer: { buffer, byteLength, byteOffset }, + dimensions, + elementType, + }, + } = inputs; + + if (elementType != ElementType.F32) { + throw new Error("Invalid element type"); + } + + const floats = new Float32Array( + buffer, + byteOffset, + byteLength / Float32Array.BYTES_PER_ELEMENT + ); + + const result = floats.map(Math.sin); + + return { + output: { + elementType: ElementType.F32, + dimensions, + buffer: new Uint8Array( + result.buffer, + result.byteOffset, + result.byteLength + ), + }, + }; + } +} diff --git a/bindings/web/rune/src/index.ts b/bindings/web/rune/src/index.ts index 023fa3eb7e..a0abd7194c 100644 --- a/bindings/web/rune/src/index.ts +++ b/bindings/web/rune/src/index.ts @@ -1,45 +1,89 @@ -export { InputDescription, OutputValue, ReadInput, Result, Builder, Evaluate } from "./facade"; -export { default as Shape } from "./Shape"; -export { default as Tensor } from "./Tensor"; +import { runtime_v1 } from "@hotg-ai/rune-wit-files"; +import { Logger } from "pino"; +import { TensorDescriptor, Tensors } from "./proc_blocks"; -import { Builder } from "./facade"; +export { Rune } from "./Rune"; +export * from "./proc_blocks"; -/** - * A map of capability names to their identifies. - */ -export const Capabilities = { - "rand": 1, - "sound": 2, - "accel": 3, - "image": 4, - "raw": 5, - "float-image": 6, -} as const; +export type Tensor = runtime_v1.Tensor; +export const ElementType = runtime_v1.ElementType; +export type ElementType = runtime_v1.ElementType; +export type Dimensions = runtime_v1.Dimensions; /** - * A map of output names to their identifies. + * A callback that can be used to load models. + * + * The callback is given the model's bytes, arguments that were associated with + * this particular node, and a logger that should be used for logging. */ -export const Outputs = { - "serial": 1, - "tensor": 5, -} as const; +export type ModelHandler = ( + model: ArrayBuffer, + args: Record, + logger: Logger +) => Promise; /** - * The name of all known capabilities. + * A node in the Rune pipeline. */ -export type CapabilityType = keyof typeof Capabilities; +export interface Node { + /** + * Given the provided set of arguments, what are do node's input and output + * tensors look like? + */ + graph(args: Record): Promise; -/** - * The name of all known outputs. - */ -export type OutputType = keyof typeof Outputs; + /** + * Evaluate this node. + * + * @param inputs The node's input tensors. + * @param args Arguments that may alter this node's behaviour. + */ + infer( + inputs: Record, + args: Record + ): Promise>; +} + +type NamedTensor = { readonly id: number } & Readonly; + +type NodeInfo = { + readonly inputs: readonly NamedTensor[]; + readonly outputs: readonly NamedTensor[]; + readonly args: Readonly>; +}; + +type Pipeline = { + readonly inputNodes: readonly string[]; + readonly outputNodes: readonly string[]; + readonly nodes: Readonly>; +}; /** - * Use a high level builder API to initialize the Rune runtime. - * - * Check out the "Runtime" module if you need tighter control over the runtime - * or want to avoid unnecessary indirection/copies. + * An instantiated Rune. */ -export function builder(): Builder { - return new Builder(); +export interface Runtime { + readonly pipeline: Pipeline; + + /** + * Run the Rune's pipeline. + */ + infer(): Promise; + + /** + * Set the tensor to be used for a particular node's input by index. + */ + setInputTensor(node: string, index: number, tensor: Tensor): void; + /** + * Set the tensor to be used for a particular node's input by name. + */ + setInputTensor(node: string, name: string, tensor: Tensor): void; + + /** + * Get a node's output tensor by index. + */ + getOutputTensor(node: string, index: number): Tensor|undefined; + /** + * Get a node's output tensor by name. + */ + getOutputTensor(node: string, name: string): Tensor|undefined; } diff --git a/bindings/web/rune/src/proc_blocks/HostFunctions.ts b/bindings/web/rune/src/proc_blocks/HostFunctions.ts new file mode 100644 index 0000000000..d81635268b --- /dev/null +++ b/bindings/web/rune/src/proc_blocks/HostFunctions.ts @@ -0,0 +1,284 @@ +import { runtime_v1 } from "@hotg-ai/rune-wit-files"; +import { Logger, Level, levels } from "pino"; +import type { + ArgumentHint, + ArgumentMetadata, + Metadata, + SupportedShapes, + TensorHint, + TensorMetadata, + TensorDescriptor, +} from "."; + +const logLevels: Record = { + [runtime_v1.LogLevel.Trace]: "trace", + [runtime_v1.LogLevel.Debug]: "debug", + [runtime_v1.LogLevel.Info]: "info", + [runtime_v1.LogLevel.Warn]: "warn", + [runtime_v1.LogLevel.Error]: "error", + [runtime_v1.LogLevel.Fatal]: "fatal", +}; + +export class HostFunctions implements runtime_v1.RuntimeV1 { + metadata?: Metadata; + graph?: GraphContext; + kernel?: KernelContext; + + constructor(private logger: Logger) {} + + metadataNew(name: string, version: string): runtime_v1.Metadata { + return new MetadataBuilder({ + name, + version, + arguments: [], + inputs: [], + outputs: [], + tags: [], + }); + } + + argumentMetadataNew(name: string): runtime_v1.ArgumentMetadata { + return new ArgumentMetadataBuilder({ name, hints: [] }); + } + + tensorMetadataNew(name: string): runtime_v1.TensorMetadata { + return new TensorMetadataBuilder({ name, hints: [] }); + } + + interpretAsImage(): TensorHint { + return { type: "media-hint", media: "image" }; + } + + interpretAsAudio(): TensorHint { + return { type: "media-hint", media: "audio" }; + } + + supportedShapes( + supportedElementTypes: runtime_v1.ElementType[], + dimensions: runtime_v1.Dimensions + ): SupportedShapes { + return { + type: "supported-shapes", + supportedElementTypes, + dimensions, + }; + } + + interpretAsNumberInRange(min: string, max: string): ArgumentHint { + return { type: "number-in-range", min, max }; + } + + interpretAsStringInEnum(stringEnum: string[]): ArgumentHint { + return { type: "string-enum", possibleValues: stringEnum }; + } + + nonNegativeNumber(): ArgumentHint { + return { type: "non-negative-number" }; + } + + supportedArgumentType(hint: runtime_v1.ArgumentType): ArgumentHint { + return { type: "supported-argument-type", argumentType: hint }; + } + + registerNode(metadata: runtime_v1.Metadata): void { + if (metadata instanceof MetadataBuilder) { + this.metadata = metadata.meta; + } + } + + graphContextForNode(nodeId: string): runtime_v1.GraphContext | null { + return this.graph || null; + } + + kernelContextForNode(nodeId: string): runtime_v1.KernelContext | null { + return this.kernel || null; + } + + isEnabled(meta: runtime_v1.LogMetadata): boolean { + const requestedLevel = logLevels[meta.level]; + const threshold = this.logger.levelVal; + return levels.values[requestedLevel] > threshold; + } + + log( + metadata: runtime_v1.LogMetadata, + message: string, + data: runtime_v1.LogValueMap + ): void { + const payload = data.map(([key, value]) => { + return value.tag == "null" ? [key, null] : [key, value.val]; + }); + + const level = logLevels[metadata.level]; + const log = this.logger[level]; + + log({ metadata, payload: Object.fromEntries(payload) }, message); + } + + modelLoad( + modelFormat: string, + model: Uint8Array, + args: [string, string][] + ): runtime_v1.Result { + throw new Error("Method not implemented."); + } +} + +class MetadataBuilder implements runtime_v1.Metadata { + constructor(public meta: Metadata) {} + + setDescription(description: string): void { + this.meta.description = description; + } + + setRepository(url: string): void { + if (url) { + this.meta.repository = url; + } + } + + setHomepage(url: string): void { + if (url) { + this.meta.homepage = url; + } + } + + addTag(tag: string): void { + this.meta.tags.push(tag); + } + + addArgument(arg: runtime_v1.ArgumentMetadata): void { + if (arg instanceof ArgumentMetadataBuilder) { + this.meta.arguments.push(arg.meta); + } + } + + addInput(metadata: runtime_v1.TensorMetadata): void { + if (metadata instanceof TensorMetadataBuilder) { + this.meta.inputs.push(metadata.meta); + } + } + + addOutput(metadata: runtime_v1.TensorMetadata): void { + if (metadata instanceof TensorMetadataBuilder) { + this.meta.outputs.push(metadata.meta); + } + } +} + +class ArgumentMetadataBuilder implements runtime_v1.ArgumentMetadata { + constructor(public meta: ArgumentMetadata) {} + + setDescription(description: string): void { + this.meta.description = description; + } + + setDefaultValue(defaultValue: string): void { + this.meta.defaultValue = defaultValue; + } + + addHint(hint: runtime_v1.ArgumentHint): void { + if (isArgumentHint(hint)) { + this.meta.hints.push(hint); + } + } +} + +class TensorMetadataBuilder implements runtime_v1.TensorMetadata { + constructor(public meta: TensorMetadata) {} + + setDescription(description: string): void { + this.meta.description = description; + } + + addHint(hint: runtime_v1.TensorHint): void { + if (isTensorHint(hint)) { + this.meta.hints.push(hint); + } + } +} + +function isArgumentHint(value?: any): value is ArgumentHint { + const types: Array = [ + "non-negative-number", + "number-in-range", + "string-enum", + "supported-argument-type", + ]; + + return types.includes(value?.type); +} + +function isTensorHint(value?: any): value is TensorHint { + const types: Array = ["media-hint", "supported-shapes"]; + + return types.includes(value?.type); +} + +export class GraphContext implements runtime_v1.GraphContext { + inputs: TensorDescriptor[] = []; + outputs: TensorDescriptor[] = []; + + constructor(private args: Record) {} + + getArgument(name: string): string | null { + if (name in this.args) { + return this.args[name]; + } else { + return null; + } + } + + addInputTensor( + name: string, + elementType: runtime_v1.ElementType, + dimensions: runtime_v1.Dimensions + ): void { + this.inputs.push({ name, elementType, dimensions }); + } + + addOutputTensor( + name: string, + elementType: runtime_v1.ElementType, + dimensions: runtime_v1.Dimensions + ): void { + this.outputs.push({ name, elementType, dimensions }); + } +} + +export class KernelContext implements runtime_v1.KernelContext { + public outputs: Record = {}; + + constructor( + private args: Record, + private inputs: Record + ) {} + + getArgument(name: string): string | null { + if (name in this.args) { + return this.args[name]; + } else { + return null; + } + } + + getInputTensor(name: string): runtime_v1.Tensor | null { + if (name in this.inputs) { + return this.inputs[name]; + } else { + return null; + } + } + + setOutputTensor(name: string, tensor: runtime_v1.Tensor): void { + this.outputs[name] = tensor; + } + + getGlobalInput(name: string): runtime_v1.Tensor | null { + throw new Error("Method not implemented."); + } + + setGlobalOutput(name: string, tensor: runtime_v1.Tensor): void { + throw new Error("Method not implemented."); + } +} diff --git a/bindings/web/rune/src/proc_blocks/ProcBlock.ts b/bindings/web/rune/src/proc_blocks/ProcBlock.ts new file mode 100644 index 0000000000..5ff9fdb940 --- /dev/null +++ b/bindings/web/rune/src/proc_blocks/ProcBlock.ts @@ -0,0 +1,182 @@ +import { proc_block_v1, runtime_v1 } from "@hotg-ai/rune-wit-files"; +import { Logger } from "pino"; +import type { Metadata, Tensors } from "."; +import { GraphContext, HostFunctions, KernelContext } from "./HostFunctions"; + +type ProcBlockBinary = Parameters[0]; + +/** + * An executable proc-block. + */ +export class ProcBlock { + private constructor( + private hostFunctions: HostFunctions, + private instance: proc_block_v1.ProcBlockV1, + private logger: Logger + ) {} + + /** + * Load a ProcBlock from a WebAssembly module. + * + * @param wasm Something that can be used to instantiate a WebAssembly module. + * @param rootLogger A logger that this ProcBlock can use. + * @returns + */ + static async load( + wasm: ProcBlockBinary, + rootLogger: Logger + ): Promise { + // Note: We want the host functions logger to have a different "name" field + // to the ProcBlock object. + const hostFunctionsLogger = rootLogger.child({ name: "HostFunctions" }); + const logger = rootLogger.child({ name: "ProcBlock" }); + + logger.info("Loading the proc-block"); + const start = Date.now(); + + const procBlock = new proc_block_v1.ProcBlockV1(); + const imports: any = {}; + + const hostFunctions = new HostFunctions(hostFunctionsLogger); + runtime_v1.addRuntimeV1ToImports( + imports, + hostFunctions, + (name) => procBlock.instance.exports[name] + ); + + await procBlock.instantiate(wasm, imports); + + const durationMs = Date.now() - start; + rootLogger.debug({ durationMs }, "Finished loading the proc-block"); + + return new ProcBlock(hostFunctions, procBlock, logger); + } + + /** + * Extract metadata from the proc-block. + */ + metadata(): Metadata { + this.hostFunctions.metadata = undefined; + this.instance.registerMetadata(); + + if (!this.hostFunctions.metadata) { + throw new Error("The proc-block didn't register any metadata"); + } + return this.hostFunctions.metadata; + } + + /** + * Given the provided set of arguments, what would this proc-block's input + * and output tensors be? + */ + graph(args: Record): Tensors { + this.logger.debug({ args }, "Calling the graph function"); + + const ctx = new GraphContext(args); + this.hostFunctions.graph = ctx; + const result = this.instance.graph(""); + + if (result.tag == "err") { + handleGraphError(result.val); + } + + const { inputs, outputs } = ctx; + return { inputs, outputs }; + } + + /** + * Evaluate this proc-block. + * + * @param args Key-value arguments that control the proc-block's behaviour. + * @param inputs Input tensors. + * @returns + */ + evaluate( + inputs: Record, + args: Record + ): Record { + this.logger.debug( + { args, inputs: Object.keys(inputs) }, + "Evaluating a proc-block" + ); + + const ctx = new KernelContext(args, inputs); + this.hostFunctions.kernel = ctx; + + const result = this.instance.kernel(""); + + if (result.tag == "err") { + this.logger.error( + { error: result.val }, + "Evaluating the proc-block failed" + ); + handleKernelError(result.val); + } + + return ctx.outputs; + } +} + +function handleGraphError(err: proc_block_v1.GraphError): never { + switch (err.tag) { + case "invalid-argument": + const { name, reason } = err.val; + handleInvalidArgument(name, reason); + + case "missing-context": + throw new Error("The proc-block couldn't access the context object"); + + case "other": + throw new Error(err.val); + } +} + +function handleKernelError(err: proc_block_v1.KernelError): never { + switch (err.tag) { + case "invalid-input": + const { name, reason } = err.val; + handleInvalidInput(name, reason); + + default: + handleGraphError(err); + } +} + +function handleInvalidInput( + name: string, + reason: proc_block_v1.BadInputReason +): never { + switch (reason.tag) { + case "invalid-value": + throw new Error( + `The "${name}" input had an invalid value: ${reason.val}` + ); + + case "unsupported-shape": + throw new Error(`The "${name}" input had a the wrong shape`); + + case "not-found": + throw new Error(`The "${name}" input wasn't set`); + + case "other": + throw new Error(`The "${name}" input was invalid: ${reason.val}`); + } +} + +function handleInvalidArgument( + name: string, + reason: proc_block_v1.BadArgumentReason +): never { + switch (reason.tag) { + case "invalid-value": + throw new Error( + `The "${name}" argument had an invalid value: ${reason.val}` + ); + + case "not-found": + throw new Error(`The "${name}" argument wasn't set`); + + case "other": + throw new Error(`The "${name}" argument was invalid: ${reason.val}`); + } +} diff --git a/bindings/web/rune/src/proc_blocks/index.ts b/bindings/web/rune/src/proc_blocks/index.ts new file mode 100644 index 0000000000..91fd0c7713 --- /dev/null +++ b/bindings/web/rune/src/proc_blocks/index.ts @@ -0,0 +1,177 @@ +export { ProcBlock } from "./ProcBlock"; + +import { runtime_v1 } from "@hotg-ai/rune-wit-files"; + +/** + * Proc-block metadata. + */ +export type Metadata = { + /** + * The proc-block's human-friendly name. + */ + name: string; + /** + * A semver-compliant version number. + */ + version: string; + /** + * A long-form description of what this proc-block does, formatted as markdown. + */ + description?: string; + /** + * A link to the proc-block's source code. + */ + repository?: string; + /** + * A link to some web page associated with the proc-block. + */ + homepage?: string; + /** + * Arbitrary tags that can be used for filtering and searching. + */ + tags: string[]; + /** + * Arguments this proc-block accepts. + */ + arguments: ArgumentMetadata[]; + /** + * The tensors this proc-block will expect as inputs. + */ + inputs: TensorMetadata[]; + /** + * The tensors this proc-block will produce as outputs. + */ + outputs: TensorMetadata[]; +}; + +/** + * Information about a tensor's name and constraints about its general shape. + */ +export type TensorDescriptor = { + /** + * The name associated with this tensor. + */ + name: string; + /** + * The type of elements this tensor will contain. + */ + elementType: runtime_v1.ElementType; + /** + * Constraints on the tensor's dimensions (a 2D tensor with fixed dimensions, + * a 1D tensor of arbitrary length, a tensor that can have any number of + * dimensions it wants, etc.). + */ + dimensions: runtime_v1.Dimensions; +}; + +/** + * Metadata about a particular tensor. + */ +export type TensorMetadata = { + /** + * The name used by the proc-block when referring to this tensor. + */ + name: string; + /** + * A long-form description of this tensor, formatted as markdown. + */ + description?: string; + hints: TensorHint[]; +}; + +/** + * Metadata around a proc-block argument. + */ +export type ArgumentMetadata = { + /** + * The name the proc-block will expect to find. + */ + name: string; + /** + * A long-form description of what this argument does, formatted as markdown. + */ + description?: string; + /** + * The value used by this if this argument isn't provided. + */ + defaultValue?: string; + /** + * Arbitrary hints that can be used to understand more about the argument. + */ + hints: ArgumentHint[]; +}; + +/** + * The argument has a numeric value within a particular range. + */ +export type NumberInRange = { + type: "number-in-range"; + min: string; + max: string; +}; + +/** + * The argument should have one of the values within a set of possible values. + */ +export type StringEnum = { + type: "string-enum"; + possibleValues: string[]; +}; + +/** + * The argument is a non-negative number. + */ +export type NonNegativeNumber = { + type: "non-negative-number"; +}; + +/** + * The "type" of argument this may take. + * + * You can use this as a suggestion when trying to choose which widget would be + * most appropriate when a user is inputting the argument value (e.g. you might + * want to use a