diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..09bc027 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,43 @@ +name: Publish Package to npmjs +on: + release: + types: [published] +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-node@v3 + with: + node-version: "16.x" + registry-url: "https://registry.npmjs.org" + - uses: pnpm/action-setup@v2 + name: Install pnpm + id: pnpm-install + with: + version: 8 + run_install: false + + - name: Get pnpm store directory + id: pnpm-cache + shell: bash + run: | + echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT + + - uses: actions/cache@v3 + name: Setup pnpm cache + with: + path: ${{ steps.pnpm-cache.outputs.STORE_PATH }} + key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }} + restore-keys: | + ${{ runner.os }}-pnpm-store- + + - name: Install dependencies + run: pnpm install + + - name: Build + run: pnpm build + - name: Publish to npmjs + run: pnpm publish + env: + NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8334b44 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +node_modules +**/*.tsbuildinfo +**/**/.next +**/**/node_modules diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..052f695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Matt Rickard + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..3fe525b --- /dev/null +++ b/README.md @@ -0,0 +1,126 @@ +# @react-llm/headless + +Easy-to-use headless React Hooks to run LLMs in the browser with WebGPU. As simple as `useLLM()`. + +### [**Live Demo**](https://chat.matt-rickard.com) + +![image](assets/demo.webp) + +**Features**: + +* Supports [Vicuna 13B](https://lmsys.org/blog/2023-03-30-vicuna/) +* Use custom system prompts and "user:"/"assistant:" role names +* Completion options like `max tokens` and `stop sequences` +* No data leaves the browser. Accelerated via WebGPU. +* Hooks built to 'Bring your own UI' +* Persistent storage for conversations in browser storage. Hooks for loading and saving conversations. +* Model caching for faster subsequent loads + +## Installation + +```bash +npm install @react-llm/headless +``` + + +## **useLLM** API +### Types +```typescript +// Model Initialization +init: () => void; + +// Model Generation +send: (msg: string, maxTokens: number, stopSequences: string[]) => void; +onMessage: (msg: GenerateTextResponse) => void; +setOnMessage: (cb: (msg: GenerateTextResponse) => void) => void; + +// Model Status +loadingStatus: InitProgressReport; +isGenerating: boolean; +gpuDevice: GPUDeviceInfo; + +// Model Configuration +userRoleName: string; +setUserRoleName: (roleName: string) => void; +assistantRoleName: string; +setAssistantRoleName: (roleName: string) => void; + +// Conversation Management +conversation: Conversation | undefined; +allConversations: Conversation[] | undefined; +createConversation: (title?: string, prompt?: string) => void; +setConversationId: (conversationId: string) => void; +deleteConversation: (conversationId: string) => void; +deleteAllConversations: () => void; +deleteMessages: () => void; +setConversationTitle: (conversationId: string, title: string) => void; +``` + +### Hooks +```typescript +import useLLM from '@react-llm/headless'; + +const MyComponent = () => { + const { + conversation, + allConversations, + loadingStatus, + isGenerating, + createConversation, + setConversationId, + deleteConversation, + deleteAllConversations, + deleteMessages, + setConversationTitle, + onMessage, + setOnMessage, + userRoleName, + setUserRoleName, + assistantRoleName, + setAssistantRoleName, + gpuDevice, + send, + init, + } = useLLM(); + + // Component logic... + + return null; +}; +``` + + +### Packages + +* `@react-llm/headless` - Headless React Hooks for running LLMs in the browser +* `@react-llm/retro-ui` - Retro-themed UI for the hooks + +## How does it work? + +This library is a set of React Hooks that provide a simple interface to run LLMs in the browser. It uses Vicuna 13B. + +* SentencePiece tokenizer (compiled for the browser via Emscripten) +* Vicuna 13B (transformed to Apache TVM format) +* Apache TVM and MLC Relax (compiled for the browser via Emscripten) +* Off-the-main-thread WebWorker to run the model (bundled with the library) + + +The model, tokenizer, and TVM runtime are loaded from a CDN (huggingface). The model is cached in browser storage for faster subsequent loads. + + + + +### Example +See [packages/retro-ui](packages/retro-ui) for the full demo code. This is a simple example of how to use the hooks. To run it, after cloning the repo, + +```bash +cd packages/retro-ui +pnpm install +pnpm dev +``` + + +### License +MIT + +The code under `packages/headless/worker/lib/tvm` is licensed under Apache 2.0. \ No newline at end of file diff --git a/assets/demo.webp b/assets/demo.webp new file mode 100644 index 0000000..22d681d Binary files /dev/null and b/assets/demo.webp differ diff --git a/package.json b/package.json new file mode 100644 index 0000000..6e7dc70 --- /dev/null +++ b/package.json @@ -0,0 +1,25 @@ +{ + "name": "@react-llm/workspace", + "version": "0.0.1", + "type": "module", + "main": "dist/bundle.cjs.js", + "module": "dist/bundle.esm.js", + "author": "Matt Rickard ", + "license": "MIT", + "private": true, + "workspaces": [ + "packages/headless", + "packages/retro-ui" + ], + "scripts": { + "publish": "pnpm publish --access public", + "build": "pnpm recursive run build" + }, + "devDependencies": { + "typescript": "^5.0.4" + }, + "dependencies": { + "react95": "^4.0.0", + "styled-components": "^5.3.10" + } +} diff --git a/packages/headless/.eslintrc.json b/packages/headless/.eslintrc.json new file mode 100644 index 0000000..a4d4588 --- /dev/null +++ b/packages/headless/.eslintrc.json @@ -0,0 +1,5 @@ +{ + "extends": [ + "@typescript-eslint/no-unused-vars" + ] +} \ No newline at end of file diff --git a/packages/headless/.gitignore b/packages/headless/.gitignore new file mode 100644 index 0000000..8f322f0 --- /dev/null +++ b/packages/headless/.gitignore @@ -0,0 +1,35 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env*.local + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts diff --git a/packages/headless/LICENSE b/packages/headless/LICENSE new file mode 100644 index 0000000..052f695 --- /dev/null +++ b/packages/headless/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Matt Rickard + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/packages/headless/dist/index.js b/packages/headless/dist/index.js new file mode 100644 index 0000000..ca7b287 --- /dev/null +++ b/packages/headless/dist/index.js @@ -0,0 +1,1156 @@ +import require$$0, { useDebugValue, useState, useEffect, useRef, useCallback, createContext, useContext } from 'react'; +import { v as v4, _ as __spreadArray, a as __assign, d as detectGPUDevice, w as wrap, p as proxy } from './v4-2119d9d5.js'; + +const createStoreImpl = createState => { + let state; + const listeners = /* @__PURE__ */new Set(); + const setState = (partial, replace) => { + const nextState = typeof partial === "function" ? partial(state) : partial; + if (!Object.is(nextState, state)) { + const previousState = state; + state = (replace != null ? replace : typeof nextState !== "object") ? nextState : Object.assign({}, state, nextState); + listeners.forEach(listener => listener(state, previousState)); + } + }; + const getState = () => state; + const subscribe = listener => { + listeners.add(listener); + return () => listeners.delete(listener); + }; + const destroy = () => { + if ((import.meta.env && import.meta.env.MODE) !== "production") { + console.warn("[DEPRECATED] The `destroy` method will be unsupported in a future version. Instead use unsubscribe function returned by subscribe. Everything will be garbage-collected if store is garbage-collected."); + } + listeners.clear(); + }; + const api = { + setState, + getState, + subscribe, + destroy + }; + state = createState(setState, getState, api); + return api; +}; +const createStore = createState => createState ? createStoreImpl(createState) : createStoreImpl; + +function getDefaultExportFromCjs (x) { + return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x; +} + +var withSelector = {exports: {}}; + +var withSelector_production_min = {}; + +var shim = {exports: {}}; + +var useSyncExternalStoreShim_production_min = {}; + +/** + * @license React + * use-sync-external-store-shim.production.min.js + * + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +var hasRequiredUseSyncExternalStoreShim_production_min; + +function requireUseSyncExternalStoreShim_production_min () { + if (hasRequiredUseSyncExternalStoreShim_production_min) return useSyncExternalStoreShim_production_min; + hasRequiredUseSyncExternalStoreShim_production_min = 1; + + var e = require$$0; + function h(a, b) { + return a === b && (0 !== a || 1 / a === 1 / b) || a !== a && b !== b; + } + var k = "function" === typeof Object.is ? Object.is : h, + l = e.useState, + m = e.useEffect, + n = e.useLayoutEffect, + p = e.useDebugValue; + function q(a, b) { + var d = b(), + f = l({ + inst: { + value: d, + getSnapshot: b + } + }), + c = f[0].inst, + g = f[1]; + n(function () { + c.value = d; + c.getSnapshot = b; + r(c) && g({ + inst: c + }); + }, [a, d, b]); + m(function () { + r(c) && g({ + inst: c + }); + return a(function () { + r(c) && g({ + inst: c + }); + }); + }, [a]); + p(d); + return d; + } + function r(a) { + var b = a.getSnapshot; + a = a.value; + try { + var d = b(); + return !k(a, d); + } catch (f) { + return !0; + } + } + function t(a, b) { + return b(); + } + var u = "undefined" === typeof window || "undefined" === typeof window.document || "undefined" === typeof window.document.createElement ? t : q; + useSyncExternalStoreShim_production_min.useSyncExternalStore = void 0 !== e.useSyncExternalStore ? e.useSyncExternalStore : u; + return useSyncExternalStoreShim_production_min; +} + +var useSyncExternalStoreShim_development = {}; + +/** + * @license React + * use-sync-external-store-shim.development.js + * + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +var hasRequiredUseSyncExternalStoreShim_development; + +function requireUseSyncExternalStoreShim_development () { + if (hasRequiredUseSyncExternalStoreShim_development) return useSyncExternalStoreShim_development; + hasRequiredUseSyncExternalStoreShim_development = 1; + + if (process.env.NODE_ENV !== "production") { + (function () { + + /* global __REACT_DEVTOOLS_GLOBAL_HOOK__ */ + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== 'undefined' && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart === 'function') { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart(new Error()); + } + var React = require$$0; + var ReactSharedInternals = React.__SECRET_INTERNALS_DO_NOT_USE_OR_YOU_WILL_BE_FIRED; + function error(format) { + { + { + for (var _len2 = arguments.length, args = new Array(_len2 > 1 ? _len2 - 1 : 0), _key2 = 1; _key2 < _len2; _key2++) { + args[_key2 - 1] = arguments[_key2]; + } + printWarning('error', format, args); + } + } + } + function printWarning(level, format, args) { + // When changing this logic, you might want to also + // update consoleWithStackDev.www.js as well. + { + var ReactDebugCurrentFrame = ReactSharedInternals.ReactDebugCurrentFrame; + var stack = ReactDebugCurrentFrame.getStackAddendum(); + if (stack !== '') { + format += '%s'; + args = args.concat([stack]); + } // eslint-disable-next-line react-internal/safe-string-coercion + + var argsWithFormat = args.map(function (item) { + return String(item); + }); // Careful: RN currently depends on this prefix + + argsWithFormat.unshift('Warning: ' + format); // We intentionally don't use spread (or .apply) directly because it + // breaks IE9: https://github.com/facebook/react/issues/13610 + // eslint-disable-next-line react-internal/no-production-logging + + Function.prototype.apply.call(console[level], console, argsWithFormat); + } + } + + /** + * inlined Object.is polyfill to avoid requiring consumers ship their own + * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Object/is + */ + function is(x, y) { + return x === y && (x !== 0 || 1 / x === 1 / y) || x !== x && y !== y // eslint-disable-line no-self-compare + ; + } + + var objectIs = typeof Object.is === 'function' ? Object.is : is; + + // dispatch for CommonJS interop named imports. + + var useState = React.useState, + useEffect = React.useEffect, + useLayoutEffect = React.useLayoutEffect, + useDebugValue = React.useDebugValue; + var didWarnOld18Alpha = false; + var didWarnUncachedGetSnapshot = false; // Disclaimer: This shim breaks many of the rules of React, and only works + // because of a very particular set of implementation details and assumptions + // -- change any one of them and it will break. The most important assumption + // is that updates are always synchronous, because concurrent rendering is + // only available in versions of React that also have a built-in + // useSyncExternalStore API. And we only use this shim when the built-in API + // does not exist. + // + // Do not assume that the clever hacks used by this hook also work in general. + // The point of this shim is to replace the need for hacks by other libraries. + + function useSyncExternalStore(subscribe, getSnapshot, + // Note: The shim does not use getServerSnapshot, because pre-18 versions of + // React do not expose a way to check if we're hydrating. So users of the shim + // will need to track that themselves and return the correct value + // from `getSnapshot`. + getServerSnapshot) { + { + if (!didWarnOld18Alpha) { + if (React.startTransition !== undefined) { + didWarnOld18Alpha = true; + error('You are using an outdated, pre-release alpha of React 18 that ' + 'does not support useSyncExternalStore. The ' + 'use-sync-external-store shim will not work correctly. Upgrade ' + 'to a newer pre-release.'); + } + } + } // Read the current snapshot from the store on every render. Again, this + // breaks the rules of React, and only works here because of specific + // implementation details, most importantly that updates are + // always synchronous. + + var value = getSnapshot(); + { + if (!didWarnUncachedGetSnapshot) { + var cachedValue = getSnapshot(); + if (!objectIs(value, cachedValue)) { + error('The result of getSnapshot should be cached to avoid an infinite loop'); + didWarnUncachedGetSnapshot = true; + } + } + } // Because updates are synchronous, we don't queue them. Instead we force a + // re-render whenever the subscribed state changes by updating an some + // arbitrary useState hook. Then, during render, we call getSnapshot to read + // the current value. + // + // Because we don't actually use the state returned by the useState hook, we + // can save a bit of memory by storing other stuff in that slot. + // + // To implement the early bailout, we need to track some things on a mutable + // object. Usually, we would put that in a useRef hook, but we can stash it in + // our useState hook instead. + // + // To force a re-render, we call forceUpdate({inst}). That works because the + // new object always fails an equality check. + + var _useState = useState({ + inst: { + value: value, + getSnapshot: getSnapshot + } + }), + inst = _useState[0].inst, + forceUpdate = _useState[1]; // Track the latest getSnapshot function with a ref. This needs to be updated + // in the layout phase so we can access it during the tearing check that + // happens on subscribe. + + useLayoutEffect(function () { + inst.value = value; + inst.getSnapshot = getSnapshot; // Whenever getSnapshot or subscribe changes, we need to check in the + // commit phase if there was an interleaved mutation. In concurrent mode + // this can happen all the time, but even in synchronous mode, an earlier + // effect may have mutated the store. + + if (checkIfSnapshotChanged(inst)) { + // Force a re-render. + forceUpdate({ + inst: inst + }); + } + }, [subscribe, value, getSnapshot]); + useEffect(function () { + // Check for changes right before subscribing. Subsequent changes will be + // detected in the subscription handler. + if (checkIfSnapshotChanged(inst)) { + // Force a re-render. + forceUpdate({ + inst: inst + }); + } + var handleStoreChange = function () { + // TODO: Because there is no cross-renderer API for batching updates, it's + // up to the consumer of this library to wrap their subscription event + // with unstable_batchedUpdates. Should we try to detect when this isn't + // the case and print a warning in development? + // The store changed. Check if the snapshot changed since the last time we + // read from the store. + if (checkIfSnapshotChanged(inst)) { + // Force a re-render. + forceUpdate({ + inst: inst + }); + } + }; // Subscribe to the store and return a clean-up function. + + return subscribe(handleStoreChange); + }, [subscribe]); + useDebugValue(value); + return value; + } + function checkIfSnapshotChanged(inst) { + var latestGetSnapshot = inst.getSnapshot; + var prevValue = inst.value; + try { + var nextValue = latestGetSnapshot(); + return !objectIs(prevValue, nextValue); + } catch (error) { + return true; + } + } + function useSyncExternalStore$1(subscribe, getSnapshot, getServerSnapshot) { + // Note: The shim does not use getServerSnapshot, because pre-18 versions of + // React do not expose a way to check if we're hydrating. So users of the shim + // will need to track that themselves and return the correct value + // from `getSnapshot`. + return getSnapshot(); + } + var canUseDOM = !!(typeof window !== 'undefined' && typeof window.document !== 'undefined' && typeof window.document.createElement !== 'undefined'); + var isServerEnvironment = !canUseDOM; + var shim = isServerEnvironment ? useSyncExternalStore$1 : useSyncExternalStore; + var useSyncExternalStore$2 = React.useSyncExternalStore !== undefined ? React.useSyncExternalStore : shim; + useSyncExternalStoreShim_development.useSyncExternalStore = useSyncExternalStore$2; + /* global __REACT_DEVTOOLS_GLOBAL_HOOK__ */ + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== 'undefined' && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop === 'function') { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop(new Error()); + } + })(); + } + return useSyncExternalStoreShim_development; +} + +var hasRequiredShim; + +function requireShim () { + if (hasRequiredShim) return shim.exports; + hasRequiredShim = 1; + + if (process.env.NODE_ENV === 'production') { + shim.exports = requireUseSyncExternalStoreShim_production_min(); + } else { + shim.exports = requireUseSyncExternalStoreShim_development(); + } + return shim.exports; +} + +/** + * @license React + * use-sync-external-store-shim/with-selector.production.min.js + * + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +var hasRequiredWithSelector_production_min; + +function requireWithSelector_production_min () { + if (hasRequiredWithSelector_production_min) return withSelector_production_min; + hasRequiredWithSelector_production_min = 1; + + var h = require$$0, + n = requireShim(); + function p(a, b) { + return a === b && (0 !== a || 1 / a === 1 / b) || a !== a && b !== b; + } + var q = "function" === typeof Object.is ? Object.is : p, + r = n.useSyncExternalStore, + t = h.useRef, + u = h.useEffect, + v = h.useMemo, + w = h.useDebugValue; + withSelector_production_min.useSyncExternalStoreWithSelector = function (a, b, e, l, g) { + var c = t(null); + if (null === c.current) { + var f = { + hasValue: !1, + value: null + }; + c.current = f; + } else f = c.current; + c = v(function () { + function a(a) { + if (!c) { + c = !0; + d = a; + a = l(a); + if (void 0 !== g && f.hasValue) { + var b = f.value; + if (g(b, a)) return k = b; + } + return k = a; + } + b = k; + if (q(d, a)) return b; + var e = l(a); + if (void 0 !== g && g(b, e)) return b; + d = a; + return k = e; + } + var c = !1, + d, + k, + m = void 0 === e ? null : e; + return [function () { + return a(b()); + }, null === m ? void 0 : function () { + return a(m()); + }]; + }, [b, e, l, g]); + var d = r(a, c[0], c[1]); + u(function () { + f.hasValue = !0; + f.value = d; + }, [d]); + w(d); + return d; + }; + return withSelector_production_min; +} + +var withSelector_development = {}; + +/** + * @license React + * use-sync-external-store-shim/with-selector.development.js + * + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +var hasRequiredWithSelector_development; + +function requireWithSelector_development () { + if (hasRequiredWithSelector_development) return withSelector_development; + hasRequiredWithSelector_development = 1; + + if (process.env.NODE_ENV !== "production") { + (function () { + + /* global __REACT_DEVTOOLS_GLOBAL_HOOK__ */ + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== 'undefined' && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart === 'function') { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart(new Error()); + } + var React = require$$0; + var shim = requireShim(); + + /** + * inlined Object.is polyfill to avoid requiring consumers ship their own + * https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Object/is + */ + function is(x, y) { + return x === y && (x !== 0 || 1 / x === 1 / y) || x !== x && y !== y // eslint-disable-line no-self-compare + ; + } + + var objectIs = typeof Object.is === 'function' ? Object.is : is; + var useSyncExternalStore = shim.useSyncExternalStore; + + // for CommonJS interop. + + var useRef = React.useRef, + useEffect = React.useEffect, + useMemo = React.useMemo, + useDebugValue = React.useDebugValue; // Same as useSyncExternalStore, but supports selector and isEqual arguments. + + function useSyncExternalStoreWithSelector(subscribe, getSnapshot, getServerSnapshot, selector, isEqual) { + // Use this to track the rendered snapshot. + var instRef = useRef(null); + var inst; + if (instRef.current === null) { + inst = { + hasValue: false, + value: null + }; + instRef.current = inst; + } else { + inst = instRef.current; + } + var _useMemo = useMemo(function () { + // Track the memoized state using closure variables that are local to this + // memoized instance of a getSnapshot function. Intentionally not using a + // useRef hook, because that state would be shared across all concurrent + // copies of the hook/component. + var hasMemo = false; + var memoizedSnapshot; + var memoizedSelection; + var memoizedSelector = function (nextSnapshot) { + if (!hasMemo) { + // The first time the hook is called, there is no memoized result. + hasMemo = true; + memoizedSnapshot = nextSnapshot; + var _nextSelection = selector(nextSnapshot); + if (isEqual !== undefined) { + // Even if the selector has changed, the currently rendered selection + // may be equal to the new selection. We should attempt to reuse the + // current value if possible, to preserve downstream memoizations. + if (inst.hasValue) { + var currentSelection = inst.value; + if (isEqual(currentSelection, _nextSelection)) { + memoizedSelection = currentSelection; + return currentSelection; + } + } + } + memoizedSelection = _nextSelection; + return _nextSelection; + } // We may be able to reuse the previous invocation's result. + + // We may be able to reuse the previous invocation's result. + var prevSnapshot = memoizedSnapshot; + var prevSelection = memoizedSelection; + if (objectIs(prevSnapshot, nextSnapshot)) { + // The snapshot is the same as last time. Reuse the previous selection. + return prevSelection; + } // The snapshot has changed, so we need to compute a new selection. + + // The snapshot has changed, so we need to compute a new selection. + var nextSelection = selector(nextSnapshot); // If a custom isEqual function is provided, use that to check if the data + // has changed. If it hasn't, return the previous selection. That signals + // to React that the selections are conceptually equal, and we can bail + // out of rendering. + + // If a custom isEqual function is provided, use that to check if the data + // has changed. If it hasn't, return the previous selection. That signals + // to React that the selections are conceptually equal, and we can bail + // out of rendering. + if (isEqual !== undefined && isEqual(prevSelection, nextSelection)) { + return prevSelection; + } + memoizedSnapshot = nextSnapshot; + memoizedSelection = nextSelection; + return nextSelection; + }; // Assigning this to a constant so that Flow knows it can't change. + + // Assigning this to a constant so that Flow knows it can't change. + var maybeGetServerSnapshot = getServerSnapshot === undefined ? null : getServerSnapshot; + var getSnapshotWithSelector = function () { + return memoizedSelector(getSnapshot()); + }; + var getServerSnapshotWithSelector = maybeGetServerSnapshot === null ? undefined : function () { + return memoizedSelector(maybeGetServerSnapshot()); + }; + return [getSnapshotWithSelector, getServerSnapshotWithSelector]; + }, [getSnapshot, getServerSnapshot, selector, isEqual]), + getSelection = _useMemo[0], + getServerSelection = _useMemo[1]; + var value = useSyncExternalStore(subscribe, getSelection, getServerSelection); + useEffect(function () { + inst.hasValue = true; + inst.value = value; + }, [value]); + useDebugValue(value); + return value; + } + withSelector_development.useSyncExternalStoreWithSelector = useSyncExternalStoreWithSelector; + /* global __REACT_DEVTOOLS_GLOBAL_HOOK__ */ + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== 'undefined' && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop === 'function') { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop(new Error()); + } + })(); + } + return withSelector_development; +} + +if (process.env.NODE_ENV === 'production') { + withSelector.exports = requireWithSelector_production_min(); +} else { + withSelector.exports = requireWithSelector_development(); +} + +var withSelectorExports = withSelector.exports; +var useSyncExternalStoreExports = /*@__PURE__*/getDefaultExportFromCjs(withSelectorExports); + +const { + useSyncExternalStoreWithSelector +} = useSyncExternalStoreExports; +function useStore$1(api, selector = api.getState, equalityFn) { + const slice = useSyncExternalStoreWithSelector(api.subscribe, api.getState, api.getServerState || api.getState, selector, equalityFn); + useDebugValue(slice); + return slice; +} +const createImpl = createState => { + if ((import.meta.env && import.meta.env.MODE) !== "production" && typeof createState !== "function") { + console.warn("[DEPRECATED] Passing a vanilla store will be unsupported in a future version. Instead use `import { useStore } from 'zustand'`."); + } + const api = typeof createState === "function" ? createStore(createState) : createState; + const useBoundStore = (selector, equalityFn) => useStore$1(api, selector, equalityFn); + Object.assign(useBoundStore, api); + return useBoundStore; +}; +const create = createState => createState ? createImpl(createState) : createImpl; + +function createJSONStorage(getStorage, options) { + let storage; + try { + storage = getStorage(); + } catch (e) { + return; + } + const persistStorage = { + getItem: name => { + var _a; + const parse = str2 => { + if (str2 === null) { + return null; + } + return JSON.parse(str2, options == null ? void 0 : options.reviver); + }; + const str = (_a = storage.getItem(name)) != null ? _a : null; + if (str instanceof Promise) { + return str.then(parse); + } + return parse(str); + }, + setItem: (name, newValue) => storage.setItem(name, JSON.stringify(newValue, options == null ? void 0 : options.replacer)), + removeItem: name => storage.removeItem(name) + }; + return persistStorage; +} +const toThenable = fn => input => { + try { + const result = fn(input); + if (result instanceof Promise) { + return result; + } + return { + then(onFulfilled) { + return toThenable(onFulfilled)(result); + }, + catch(_onRejected) { + return this; + } + }; + } catch (e) { + return { + then(_onFulfilled) { + return this; + }, + catch(onRejected) { + return toThenable(onRejected)(e); + } + }; + } +}; +const oldImpl = (config, baseOptions) => (set, get, api) => { + let options = { + getStorage: () => localStorage, + serialize: JSON.stringify, + deserialize: JSON.parse, + partialize: state => state, + version: 0, + merge: (persistedState, currentState) => ({ + ...currentState, + ...persistedState + }), + ...baseOptions + }; + let hasHydrated = false; + const hydrationListeners = /* @__PURE__ */new Set(); + const finishHydrationListeners = /* @__PURE__ */new Set(); + let storage; + try { + storage = options.getStorage(); + } catch (e) {} + if (!storage) { + return config((...args) => { + console.warn(`[zustand persist middleware] Unable to update item '${options.name}', the given storage is currently unavailable.`); + set(...args); + }, get, api); + } + const thenableSerialize = toThenable(options.serialize); + const setItem = () => { + const state = options.partialize({ + ...get() + }); + let errorInSync; + const thenable = thenableSerialize({ + state, + version: options.version + }).then(serializedValue => storage.setItem(options.name, serializedValue)).catch(e => { + errorInSync = e; + }); + if (errorInSync) { + throw errorInSync; + } + return thenable; + }; + const savedSetState = api.setState; + api.setState = (state, replace) => { + savedSetState(state, replace); + void setItem(); + }; + const configResult = config((...args) => { + set(...args); + void setItem(); + }, get, api); + let stateFromStorage; + const hydrate = () => { + var _a; + if (!storage) return; + hasHydrated = false; + hydrationListeners.forEach(cb => cb(get())); + const postRehydrationCallback = ((_a = options.onRehydrateStorage) == null ? void 0 : _a.call(options, get())) || void 0; + return toThenable(storage.getItem.bind(storage))(options.name).then(storageValue => { + if (storageValue) { + return options.deserialize(storageValue); + } + }).then(deserializedStorageValue => { + if (deserializedStorageValue) { + if (typeof deserializedStorageValue.version === "number" && deserializedStorageValue.version !== options.version) { + if (options.migrate) { + return options.migrate(deserializedStorageValue.state, deserializedStorageValue.version); + } + console.error(`State loaded from storage couldn't be migrated since no migrate function was provided`); + } else { + return deserializedStorageValue.state; + } + } + }).then(migratedState => { + var _a2; + stateFromStorage = options.merge(migratedState, (_a2 = get()) != null ? _a2 : configResult); + set(stateFromStorage, true); + return setItem(); + }).then(() => { + postRehydrationCallback == null ? void 0 : postRehydrationCallback(stateFromStorage, void 0); + hasHydrated = true; + finishHydrationListeners.forEach(cb => cb(stateFromStorage)); + }).catch(e => { + postRehydrationCallback == null ? void 0 : postRehydrationCallback(void 0, e); + }); + }; + api.persist = { + setOptions: newOptions => { + options = { + ...options, + ...newOptions + }; + if (newOptions.getStorage) { + storage = newOptions.getStorage(); + } + }, + clearStorage: () => { + storage == null ? void 0 : storage.removeItem(options.name); + }, + getOptions: () => options, + rehydrate: () => hydrate(), + hasHydrated: () => hasHydrated, + onHydrate: cb => { + hydrationListeners.add(cb); + return () => { + hydrationListeners.delete(cb); + }; + }, + onFinishHydration: cb => { + finishHydrationListeners.add(cb); + return () => { + finishHydrationListeners.delete(cb); + }; + } + }; + hydrate(); + return stateFromStorage || configResult; +}; +const newImpl = (config, baseOptions) => (set, get, api) => { + let options = { + storage: createJSONStorage(() => localStorage), + partialize: state => state, + version: 0, + merge: (persistedState, currentState) => ({ + ...currentState, + ...persistedState + }), + ...baseOptions + }; + let hasHydrated = false; + const hydrationListeners = /* @__PURE__ */new Set(); + const finishHydrationListeners = /* @__PURE__ */new Set(); + let storage = options.storage; + if (!storage) { + return config((...args) => { + console.warn(`[zustand persist middleware] Unable to update item '${options.name}', the given storage is currently unavailable.`); + set(...args); + }, get, api); + } + const setItem = () => { + const state = options.partialize({ + ...get() + }); + return storage.setItem(options.name, { + state, + version: options.version + }); + }; + const savedSetState = api.setState; + api.setState = (state, replace) => { + savedSetState(state, replace); + void setItem(); + }; + const configResult = config((...args) => { + set(...args); + void setItem(); + }, get, api); + let stateFromStorage; + const hydrate = () => { + var _a, _b; + if (!storage) return; + hasHydrated = false; + hydrationListeners.forEach(cb => { + var _a2; + return cb((_a2 = get()) != null ? _a2 : configResult); + }); + const postRehydrationCallback = ((_b = options.onRehydrateStorage) == null ? void 0 : _b.call(options, (_a = get()) != null ? _a : configResult)) || void 0; + return toThenable(storage.getItem.bind(storage))(options.name).then(deserializedStorageValue => { + if (deserializedStorageValue) { + if (typeof deserializedStorageValue.version === "number" && deserializedStorageValue.version !== options.version) { + if (options.migrate) { + return options.migrate(deserializedStorageValue.state, deserializedStorageValue.version); + } + console.error(`State loaded from storage couldn't be migrated since no migrate function was provided`); + } else { + return deserializedStorageValue.state; + } + } + }).then(migratedState => { + var _a2; + stateFromStorage = options.merge(migratedState, (_a2 = get()) != null ? _a2 : configResult); + set(stateFromStorage, true); + return setItem(); + }).then(() => { + postRehydrationCallback == null ? void 0 : postRehydrationCallback(stateFromStorage, void 0); + stateFromStorage = get(); + hasHydrated = true; + finishHydrationListeners.forEach(cb => cb(stateFromStorage)); + }).catch(e => { + postRehydrationCallback == null ? void 0 : postRehydrationCallback(void 0, e); + }); + }; + api.persist = { + setOptions: newOptions => { + options = { + ...options, + ...newOptions + }; + if (newOptions.storage) { + storage = newOptions.storage; + } + }, + clearStorage: () => { + storage == null ? void 0 : storage.removeItem(options.name); + }, + getOptions: () => options, + rehydrate: () => hydrate(), + hasHydrated: () => hasHydrated, + onHydrate: cb => { + hydrationListeners.add(cb); + return () => { + hydrationListeners.delete(cb); + }; + }, + onFinishHydration: cb => { + finishHydrationListeners.add(cb); + return () => { + finishHydrationListeners.delete(cb); + }; + } + }; + if (!options.skipHydration) { + hydrate(); + } + return stateFromStorage || configResult; +}; +const persistImpl = (config, baseOptions) => { + if ("getStorage" in baseOptions || "serialize" in baseOptions || "deserialize" in baseOptions) { + if ((import.meta.env && import.meta.env.MODE) !== "production") { + console.warn("[DEPRECATED] `getStorage`, `serialize` and `deserialize` options are deprecated. Use `storage` option instead."); + } + return oldImpl(config, baseOptions); + } + return newImpl(config, baseOptions); +}; +const persist = persistImpl; + +var defaultSystemPrompt = "A chat between a curious user and a AI chatbot named SmartestChild on AIM who responds with lowercase, frequent emojis, and 2000s internet abbreviations."; +var useConversationStore = create()(persist(function (set, get) { + var initialConversation = { + id: v4(), + title: "Untitled", + updatedAt: new Date().getTime(), + systemPrompt: defaultSystemPrompt, + createdAt: new Date().getTime(), + messages: [], + }; + return { + conversations: [initialConversation], + currentConversationId: initialConversation.id, + createConversation: function (conversation) { + set(function (state) { + return { + currentConversationId: conversation.id, + conversations: __spreadArray(__spreadArray([], state.conversations, true), [conversation], false), + }; + }); + }, + setConversationTitle: function (conversationId, title) { + set(function (state) { + var conversation = state.conversations.find(function (c) { return c.id === conversationId; }); + if (!conversation) { + return state; + } + return { + conversations: __spreadArray(__spreadArray([], state.conversations.filter(function (c) { return c.id !== conversationId; }), true), [ + __assign(__assign({}, conversation), { title: title }), + ], false), + }; + }); + }, + deleteConversation: function (conversationId) { + set(function (state) { + return { + conversations: state.conversations.filter(function (c) { return c.id !== conversationId; }), + }; + }); + }, + setConversationId: function (conversationId) { + var conversationExists = get().conversations.some(function (c) { return c.id === conversationId; }); + if (!conversationExists) { + throw new Error("Invalid conversation id"); + } + set(function (state) { + return __assign(__assign({}, state), { currentConversationId: conversationId }); + }); + }, + deleteAllConversations: function () { + set(function (state) { + return { + conversations: [], + }; + }); + }, + deleteMessages: function (conversationId) { + set(function (state) { + var conversation = state.conversations.find(function (c) { return c.id === conversationId; }); + if (!conversation) { + return state; + } + return { + conversations: __spreadArray(__spreadArray([], state.conversations.filter(function (c) { return c.id !== conversationId; }), true), [ + __assign(__assign({}, conversation), { updatedAt: new Date().getTime(), messages: [] }), + ], false), + }; + }); + }, + getConversation: function (conversationId) { + return get().conversations.find(function (c) { return c.id === conversationId; }); + }, + getAllConversations: function () { + return get().conversations; + }, + addMessage: function (conversationId, message) { + set(function (state) { + var conversation = state.conversations.find(function (c) { return c.id === conversationId; }); + if (!conversation) { + return state; + } + var existingMessage = conversation.messages.find(function (m) { return m.id === message.id; }); + if (existingMessage) { + // Update message + return { + conversations: __spreadArray(__spreadArray([], state.conversations.filter(function (c) { return c.id !== conversationId; }), true), [ + __assign(__assign({}, conversation), { updatedAt: new Date().getTime(), messages: __spreadArray(__spreadArray([], conversation.messages.filter(function (m) { return m.id !== message.id; }), true), [ + message, + ], false) }), + ], false), + }; + } + // Add message + return { + conversations: __spreadArray(__spreadArray([], state.conversations.filter(function (c) { return c.id !== conversationId; }), true), [ + __assign(__assign({}, conversation), { updatedAt: new Date().getTime(), messages: __spreadArray(__spreadArray([], conversation.messages, true), [message], false) }), + ], false), + }; + }); + }, + }; +}, { + name: "chat-store", + getStorage: function () { return sessionStorage; }, +})); + +// https://github.com/pmndrs/zustand/blob/65d2bc0660ab0d542cf9f97a3b004754ffa73f3e/docs/integrations/persisting-store-data.md?plain=1#L471-L488 +var useStore = function (store, callback) { + var result = store(callback); + var _a = useState(), data = _a[0], setData = _a[1]; + useEffect(function () { + setData(result); + }, [result]); + return data; +}; + +var initialProgress = { + type: "init", + progress: 0, + timeElapsed: 0, + currentChunk: 0, + totalChunks: 0, + fetchedBytes: 0, + totalBytes: 0, +}; +var useLLMContext = function () { + var _a = useState(initialProgress), loadingStatus = _a[0], setLoadingStatus = _a[1]; + var _b = useState(false), isGenerating = _b[0], setIsGenerating = _b[1]; + var workerRef = useRef(); + var cStore = useStore(useConversationStore, function (state) { return state; }); + var _c = useState("user"), userRoleName = _c[0], setUserRoleName = _c[1]; + var _d = useState("assistant"), assistantRoleName = _d[0], setAssistantRoleName = _d[1]; + var _e = useState({ + adapter: null, + device: null, + adapterInfo: null, + checked: false, + unsupportedReason: null, + }), gpuDevice = _e[0], setGpuDevice = _e[1]; + useEffect(function () { + if (!gpuDevice || !gpuDevice.checked) { + detectGPUDevice() + .then(function (resp) { + if (resp) { + setGpuDevice({ + unsupportedReason: null, + checked: true, + adapter: resp.adapter, + device: resp.device, + adapterInfo: resp.adapterInfo, + }); + } + else { + setGpuDevice(__assign(__assign({}, gpuDevice), { checked: true, unsupportedReason: "GPU is not supported" })); + } + }) + .catch(function (err) { + setGpuDevice({ + adapter: null, + device: null, + adapterInfo: null, + checked: true, + unsupportedReason: err.message, + }); + }); + } + }, []); + var _f = useState(), onMessage = _f[0], setOnMessage = _f[1]; + var addMessage = useCallback(function (resp) { + if (resp.isFinished) { + setIsGenerating(false); + } + if (onMessage) + onMessage(resp); + cStore === null || cStore === void 0 ? void 0 : cStore.addMessage(cStore === null || cStore === void 0 ? void 0 : cStore.currentConversationId, { + id: resp.requestId, + createdAt: new Date().getTime(), + updatedAt: new Date().getTime(), + role: assistantRoleName, + text: resp.outputText, + }); + }, [cStore, cStore === null || cStore === void 0 ? void 0 : cStore.currentConversationId, onMessage, setOnMessage]); + useEffect(function () { + if (!workerRef.current) { + workerRef.current = wrap(new Worker(new URL("worker-cc79b531.js", import.meta.url))); + } + }, []); + var send = function (text, maxTokens, stopStrings) { + var _a; + if (maxTokens === void 0) { maxTokens = 100; } + if (stopStrings === void 0) { stopStrings = [userRoleName, assistantRoleName]; } + var currentConversation = cStore === null || cStore === void 0 ? void 0 : cStore.getConversation(cStore === null || cStore === void 0 ? void 0 : cStore.currentConversationId); + if (!currentConversation) { + throw new Error("Invalid conversation id"); + } + currentConversation === null || currentConversation === void 0 ? void 0 : currentConversation.messages.push({ + id: v4(), + createdAt: new Date().getTime(), + updatedAt: new Date().getTime(), + role: userRoleName, + text: text, + }); + setIsGenerating(true); + (_a = workerRef === null || workerRef === void 0 ? void 0 : workerRef.current) === null || _a === void 0 ? void 0 : _a.generate({ + conversation: currentConversation, + stopTexts: stopStrings, + maxTokens: maxTokens, + assistantRoleName: assistantRoleName, + }, proxy(addMessage)); + }; + return { + conversation: cStore === null || cStore === void 0 ? void 0 : cStore.getConversation(cStore === null || cStore === void 0 ? void 0 : cStore.currentConversationId), + allConversations: cStore === null || cStore === void 0 ? void 0 : cStore.conversations.sort(function (a, b) { return b.updatedAt - a.updatedAt; }), + createConversation: function (title, prompt) { + var id = v4(); + cStore === null || cStore === void 0 ? void 0 : cStore.createConversation({ + id: id, + title: title !== null && title !== void 0 ? title : "Untitled", + systemPrompt: prompt !== null && prompt !== void 0 ? prompt : defaultSystemPrompt, + messages: [], + createdAt: new Date().getTime(), + updatedAt: new Date().getTime(), + }); + }, + setConversationTitle: function (id, title) { + cStore === null || cStore === void 0 ? void 0 : cStore.setConversationTitle(id, title); + }, + setConversationId: function (id) { + cStore === null || cStore === void 0 ? void 0 : cStore.setConversationId(id); + }, + deleteConversation: function (id) { + cStore === null || cStore === void 0 ? void 0 : cStore.deleteConversation(id); + }, + deleteMessages: function () { return cStore === null || cStore === void 0 ? void 0 : cStore.deleteMessages(cStore === null || cStore === void 0 ? void 0 : cStore.currentConversationId); }, + onMessage: onMessage, + setOnMessage: setOnMessage, + loadingStatus: loadingStatus, + isGenerating: isGenerating, + userRoleName: userRoleName, + setUserRoleName: setUserRoleName, + assistantRoleName: assistantRoleName, + setAssistantRoleName: setAssistantRoleName, + gpuDevice: gpuDevice, + send: send, + init: function () { var _a; return (_a = workerRef === null || workerRef === void 0 ? void 0 : workerRef.current) === null || _a === void 0 ? void 0 : _a.init(proxy(setLoadingStatus)); }, + deleteAllConversations: function () { return cStore === null || cStore === void 0 ? void 0 : cStore.deleteAllConversations(); }, + }; +}; + +var ModelContext = createContext(null); +var ModelProvider = function (_a) { + var children = _a.children; + var LLMValue = useLLMContext(); + return (require$$0.createElement(ModelContext.Provider, { value: LLMValue }, children)); +}; +var useLLM = function () { + var context = useContext(ModelContext); + if (context === null) { + throw new Error("useLLMContext must be used within a LLMProvider"); + } + return context; +}; + +export { ModelProvider, useLLM as default }; diff --git a/packages/headless/dist/types/src/hooks/useConversationStore.d.ts b/packages/headless/dist/types/src/hooks/useConversationStore.d.ts new file mode 100644 index 0000000..f748e87 --- /dev/null +++ b/packages/headless/dist/types/src/hooks/useConversationStore.d.ts @@ -0,0 +1,27 @@ +import { Conversation, Message } from "../types/chat"; +export interface ConversationStore { + conversations: Conversation[]; + currentConversationId: string; + setConversationId: (conversationId: string) => void; + addMessage: (conversationId: string, message: Message) => void; + getConversation: (conversationId: string) => Conversation | undefined; + setConversationTitle: (conversationId: string, title: string) => void; + getAllConversations: () => Conversation[]; + deleteMessages: (conversationId: string) => void; + deleteConversation: (conversationId: string) => void; + createConversation: (conversation: Conversation) => void; + deleteAllConversations: () => void; +} +export declare const defaultSystemPrompt = "A chat between a curious user and a AI chatbot named SmartestChild on AIM who responds with lowercase, frequent emojis, and 2000s internet abbreviations."; +declare const useConversationStore: import("zustand").UseBoundStore, "persist"> & { + persist: { + setOptions: (options: Partial>) => void; + clearStorage: () => void; + rehydrate: () => void | Promise; + hasHydrated: () => boolean; + onHydrate: (fn: (state: ConversationStore) => void) => () => void; + onFinishHydration: (fn: (state: ConversationStore) => void) => () => void; + getOptions: () => Partial>; + }; +}>; +export default useConversationStore; diff --git a/packages/headless/dist/types/src/hooks/useLLM.d.ts b/packages/headless/dist/types/src/hooks/useLLM.d.ts new file mode 100644 index 0000000..0faba03 --- /dev/null +++ b/packages/headless/dist/types/src/hooks/useLLM.d.ts @@ -0,0 +1,36 @@ +/// +import { InitProgressReport } from "@/worker/lib/tvm/runtime"; +import { Conversation } from "../types/chat"; +import { GenerateTextResponse } from "../types/worker_message"; +export type UseLLMParams = { + autoInit?: boolean; +}; +export type GPUDeviceInfo = { + adapter: GPUAdapter | null; + device: GPUDevice | null; + adapterInfo: GPUAdapterInfo | null; + checked: boolean; + unsupportedReason: string | null; +}; +export type UseLLMResponse = { + conversation: Conversation | undefined; + allConversations: Conversation[] | undefined; + loadingStatus: InitProgressReport; + isGenerating: boolean; + createConversation: (title?: string, prompt?: string) => void; + setConversationId: (conversationId: string) => void; + deleteConversation: (conversationId: string) => void; + deleteAllConversations: () => void; + deleteMessages: () => void; + setConversationTitle: (conversationId: string, title: string) => void; + onMessage: (msg: GenerateTextResponse) => void; + setOnMessage: (cb: (msg: GenerateTextResponse) => void) => void; + userRoleName: string; + setUserRoleName: (roleName: string) => void; + assistantRoleName: string; + setAssistantRoleName: (roleName: string) => void; + gpuDevice: GPUDeviceInfo; + send: (text: string, maxToken: number, stopSequences: string[]) => void; + init: () => void; +}; +export declare const useLLMContext: () => UseLLMResponse; diff --git a/packages/headless/dist/types/src/hooks/useStore.d.ts b/packages/headless/dist/types/src/hooks/useStore.d.ts new file mode 100644 index 0000000..bed86fb --- /dev/null +++ b/packages/headless/dist/types/src/hooks/useStore.d.ts @@ -0,0 +1,2 @@ +declare const useStore: (store: (callback: (state: T) => unknown) => unknown, callback: (state: T) => F) => F | undefined; +export default useStore; diff --git a/packages/headless/dist/types/src/index.d.ts b/packages/headless/dist/types/src/index.d.ts new file mode 100644 index 0000000..10af5f2 --- /dev/null +++ b/packages/headless/dist/types/src/index.d.ts @@ -0,0 +1,3 @@ +import { ModelProvider, useLLM } from './providers/ModelProvider'; +export { ModelProvider }; +export default useLLM; diff --git a/packages/headless/dist/types/src/providers/ModelProvider.d.ts b/packages/headless/dist/types/src/providers/ModelProvider.d.ts new file mode 100644 index 0000000..338c69c --- /dev/null +++ b/packages/headless/dist/types/src/providers/ModelProvider.d.ts @@ -0,0 +1,8 @@ +import React from "react"; +import { UseLLMParams, UseLLMResponse } from "../hooks/useLLM"; +export interface ModelProviderProps { + children: React.ReactNode; + props?: UseLLMParams; +} +export declare const ModelProvider: React.FC; +export declare const useLLM: () => UseLLMResponse; diff --git a/packages/headless/dist/types/src/types/chat.d.ts b/packages/headless/dist/types/src/types/chat.d.ts new file mode 100644 index 0000000..acd9a1f --- /dev/null +++ b/packages/headless/dist/types/src/types/chat.d.ts @@ -0,0 +1,15 @@ +export interface Conversation { + id: string; + title: string; + systemPrompt: string; + createdAt: number; + updatedAt: number; + messages: Message[]; +} +export interface Message { + id: string; + role: string; + text: string; + createdAt: number; + updatedAt: number; +} diff --git a/packages/headless/dist/types/src/types/worker_message.d.ts b/packages/headless/dist/types/src/types/worker_message.d.ts new file mode 100644 index 0000000..62df271 --- /dev/null +++ b/packages/headless/dist/types/src/types/worker_message.d.ts @@ -0,0 +1,26 @@ +import * as Comlink from 'comlink'; +import { InitProgressCallback } from "../worker/lib/tvm/runtime"; +import { Conversation } from "./chat"; +export type ModelWorker = { + init(callback: Comlink.ProxyOrClone): void; + generate(request: GenerateTextRequest, callback: Comlink.ProxyOrClone): void; +}; +export type InitCallback = InitProgressCallback; +export type GenerateTextCallback = (data: GenerateTextResponse) => void; +export type GenerateTextRequest = { + conversation: Conversation; + stopTexts: string[]; + maxTokens: number; + assistantRoleName: string; +}; +export type GenerateTextResponse = { + requestId: string; + step: number; + outputText: string; + stats: { + totalDecodingSeconds: number; + totalDecodedTokens: number; + totalEncodedTokens: number; + }; + isFinished: boolean; +}; diff --git a/packages/headless/dist/types/src/worker/lib/tvm/compact.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/compact.d.ts new file mode 100644 index 0000000..866a6af --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/compact.d.ts @@ -0,0 +1,10 @@ +/** NodeJS and Web compact layer */ +/** + * Get performance measurement. + */ +export declare function getPerformance(): Performance; +/** + * Create a new websocket for a given URL + * @param url The url. + */ +export declare function createWebSocket(url: string): WebSocket; diff --git a/packages/headless/dist/types/src/worker/lib/tvm/ctypes.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/ctypes.d.ts new file mode 100644 index 0000000..fac9218 --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/ctypes.d.ts @@ -0,0 +1,180 @@ +/** + * Types for C API. + */ +/** A pointer to points to the raw address space. */ +export type Pointer = number; +/** A pointer offset, need to add a base address to get a valid ptr. */ +export type PtrOffset = number; +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = (mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = (func: Pointer, argValues: Pointer, typeCode: Pointer, nargs: number, retValue: Pointer, retCode: Pointer) => number; +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = (ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = (name: Pointer, f: Pointer, override: number) => number; +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = (shape: Pointer, ndim: number, dtypeCode: number, dtypeBits: number, dtypeLanes: number, deviceType: number, deviceId: number, out: Pointer) => number; +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = (handle: Pointer, data: Pointer, nbytes: number) => number; +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = (handle: Pointer, data: Pointer, nbytes: number) => number; +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = (from: Pointer, to: Pointer, stream: Pointer) => number; +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = (deviceType: number, deviceId: number, stream: Pointer) => number; +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = (argValues: Pointer, argCodes: Pointer, nargs: number, outValue: Pointer, outCode: Pointer) => number; +/** + * int TVMObjectFree(TVMObjectHandle obj); + */ +export type FTVMObjectFree = (obj: Pointer) => number; +/** + * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + */ +export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number; +/** + * int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + */ +export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number; +/** + * int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + */ +export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number; +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = (args: Pointer, typeCodes: Pointer, nargs: number, ret: Pointer, resourceHandle: Pointer) => number; +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = (resource: Pointer, out: Pointer) => number; +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; +/** + * Size of common data types. + */ +export declare const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = 4, + DLDevice = 8 +} +/** + * Argument Type code in TVM FFI. + */ +export declare const enum ArgTypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + DLDevice = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} diff --git a/packages/headless/dist/types/src/worker/lib/tvm/environment.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/environment.d.ts new file mode 100644 index 0000000..5e935f7 --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/environment.d.ts @@ -0,0 +1,26 @@ +import { LibraryProvider } from "./types"; +import * as ctypes from "./ctypes"; +/** + * Environment to impelement most of the JS library functions. + */ +export declare class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array; + private libProvider?; + constructor(importObject?: Record, logger?: (msg: string) => void); + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void; + private environment; +} diff --git a/packages/headless/dist/types/src/worker/lib/tvm/index.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/index.d.ts new file mode 100644 index 0000000..2196fda --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/index.d.ts @@ -0,0 +1,6 @@ +export { RPCServer } from "./rpc_server"; +export { DLDataType, DLDevice, Instance, Module, NDArray, Scalar, TVMArray, instantiate } from "./runtime"; +export type { PackedFunc } from "./runtime"; +export { assert, wasmPath } from "./support"; +export type { Disposable, LibraryProvider } from "./types"; +export { detectGPUDevice } from "./webgpu"; diff --git a/packages/headless/dist/types/src/worker/lib/tvm/memory.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/memory.d.ts new file mode 100644 index 0000000..9c16d6b --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/memory.d.ts @@ -0,0 +1,144 @@ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset } from "./ctypes"; +import { Disposable } from "./types"; +import * as ctypes from "./ctypes"; +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export declare class Memory { + memory: WebAssembly.Memory; + wasm32: boolean; + private buffer; + private viewU8; + private viewU16; + private viewI32; + private viewU32; + private viewF32; + private viewF64; + constructor(memory: WebAssembly.Memory); + loadU8(ptr: Pointer): number; + loadU16(ptr: Pointer): number; + loadU32(ptr: Pointer): number; + loadI32(ptr: Pointer): number; + loadI64(ptr: Pointer): number; + loadF32(ptr: Pointer): number; + loadF64(ptr: Pointer): number; + loadPointer(ptr: Pointer): Pointer; + loadUSize(ptr: Pointer): Pointer; + sizeofPtr(): number; + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array; + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array; + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string; + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void; + /** + * Update memory view after the memory growth. + */ + private updateViews; +} +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export declare class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. */ + tempArgs: Array; + private memory; + private cAllocSpace; + private cFreeSpace; + private buffer; + private viewU8; + private viewI32; + private viewU32; + private viewF64; + private stackTop; + private basePtr; + private addressToSetTargetValue; + constructor(memory: Memory, allocSpace: ctypes.FTVMWasmAllocSpace, freeSpace: ctypes.FTVMWasmFreeSpace); + dispose(): void; + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void; + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes?: number): void; + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset; + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset; + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. + */ + ptrFromOffset(offset: PtrOffset): Pointer; + storePtr(offset: PtrOffset, value: Pointer): void; + storeUSize(offset: PtrOffset, value: Pointer): void; + storeI32(offset: PtrOffset, value: number): void; + storeU32(offset: PtrOffset, value: number): void; + storeI64(offset: PtrOffset, value: number): void; + storeF64(offset: PtrOffset, value: number): void; + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void; + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void; + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void; + /** + * Update internal cache views. + */ + private updateViews; +} diff --git a/packages/headless/dist/types/src/worker/lib/tvm/rpc_server.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/rpc_server.d.ts new file mode 100644 index 0000000..af9928a --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/rpc_server.d.ts @@ -0,0 +1,54 @@ +import * as runtime from "./runtime"; +declare enum RPCServerState { + InitHeader = 0, + InitHeaderKey = 1, + InitServer = 2, + WaitForCallback = 3, + ReceivePacketHeader = 4, + ReceivePacketBody = 5 +} +/** + * A websocket based RPC + */ +export declare class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState; + logger: (msg: string) => void; + getImports: () => Record; + private ndarrayCacheUrl; + private ndarrayCacheDevice; + private initProgressCallback?; + private asyncOnServerLoad?; + private pendingSend; + private name; + private inst?; + private globalObjects; + private serverRecvData?; + private currPacketHeader?; + private currPacketLength; + private remoteKeyLength; + private pendingBytes; + private buffredBytes; + private messageQueue; + constructor(url: string, key: string, getImports: () => Record, logger?: (msg: string) => void, ndarrayCacheUrl?: string, ndarrayCacheDevice?: string, initProgressCallback?: runtime.InitProgressCallback | undefined, asyncOnServerLoad?: ((inst: runtime.Instance) => Promise) | undefined); + private onClose; + private onOpen; + /** Handler for raw message. */ + private onMessage; + /** Process ready events. */ + private processEvents; + /** State machine to handle each request */ + private onDataReady; + private onPacketReady; + /** Event handler during server initialization. */ + private onInitServer; + private log; + private handleInitHeader; + private handleInitHeaderKey; + private checkLittleEndian; + private requestBytes; + private readFromBuffer; +} +export {}; diff --git a/packages/headless/dist/types/src/worker/lib/tvm/runtime.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/runtime.d.ts new file mode 100644 index 0000000..8a73838 --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/runtime.d.ts @@ -0,0 +1,700 @@ +/// +/** + * TVM JS Wasm Runtime library. + */ +import { Pointer, PtrOffset } from "./ctypes"; +import { Environment } from "./environment"; +import { CachedCallStack, Memory } from "./memory"; +import { Disposable } from "./types"; +import { WebGPUContext } from "./webgpu"; +/** + * Type for PackedFunc inthe TVMRuntime. + */ +export type PackedFunc = ((...args: any) => any) & Disposable & { + _tvmPackedCell: PackedFuncCell; +}; +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +declare class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + webGPUContext?: WebGPUContext; + private wasmInstance; + private recycledCallStacks; + constructor(wasmInstance: WebAssembly.Instance, imports: Record); + dispose(): void; + sizeofPtr(): number; + checkCall(code: number): void; + getOrAllocCallStack(): CachedCallStack; + recycleCallStack(callstack: CachedCallStack): void; + private validateInstance; + private checkExports; + private detectWasmMemory; +} +/** + * @internal + * Manages extra runtime context for the runtime. + */ +declare class RuntimeContext implements Disposable { + arrayGetItem: PackedFunc; + arrayGetSize: PackedFunc; + arrayMake: PackedFunc; + getSysLib: PackedFunc; + arrayCacheGet: PackedFunc; + arrayCacheUpdate: PackedFunc; + arrayCacheRemove: PackedFunc; + arrayCacheClear: PackedFunc; + arrayDecodeStorage: PackedFunc; + paramModuleFromCache: PackedFunc; + makeShapeTuple: PackedFunc; + ndarrayCreateView: PackedFunc; + sampleTopPFromLogits: PackedFunc; + private autoDisposeScope; + constructor(getGlobalFunc: (name: string) => PackedFunc); + dispose(): void; + beginScope(): void; + endScope(): void; + /** + * Track object for dispose in current scope. + * + * @param obj The object to be tracked. + * @returns the same object. + * @note This function only needs to be called for raw system C API values. + * The return value of PackedFunc will be automatically tracked. + */ + attachToCurrentScope(obj: T): T; + moveToParentScope(obj: T): T; + detachFromCurrentScope(obj: T): T; +} +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export declare class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + constructor(value: number, dtype: string); +} +/** + * Cell holds the PackedFunc object. + */ +declare class PackedFuncCell implements Disposable { + private handle; + private lib; + constructor(handle: Pointer, lib: FFILibrary); + dispose(): void; + getHandle(requireNotNull?: boolean): Pointer; +} +/** + * Represent a runtime context where a NDArray can reside. + */ +export declare class DLDevice { + /** The device type code of the device. */ + deviceType: number; + /** The device index. */ + deviceId: number; + private lib; + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary); + /** + * Synchronize the device + */ + sync(): Promise; + toString(): string; +} +/** + * The data type code in DLDataType + */ +export declare const enum DLDataTypeCode { + Int = 0, + UInt = 1, + Float = 2, + OpaqueHandle = 3 +} +/** + * Runtime data type of NDArray. + */ +export declare class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + constructor(code: number, bits: number, lanes: number); + toString(): string; + numStorageBytes(): number; +} +/** + * n-dimnesional array. + */ +export declare class NDArray implements Disposable { + /** Internal array handle. */ + private handle; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Device of the array. */ + device: DLDevice; + /** Whether it is a temporary view that can become invalid after the call. */ + isView: boolean; + private byteOffset; + private dltensor; + private dataPtr; + private lib; + private ctx; + private dlDataType; + constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx: RuntimeContext); + /** + * Create a view of the array. + * @param shape The shape of the view. + * @returns The new sliced ndarray. + */ + view(shape: Array): NDArray; + /** + * Get handle of ndarray, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull?: boolean): Pointer; + /** + * Get dataPtr of NDarray + * + * @returns The handle. + */ + getDataPtr(): Pointer; + dispose(): void; + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array | Float32Array): this; + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. + * @returns this + */ + copyFromRawBytes(data: Uint8Array): this; + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + toRawBytes(): Uint8Array; + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array; + private getDLTensorFromArrayHandle; +} +/** + * Runtime Module. + */ +export declare class Module implements Disposable { + private handle; + private lib; + private makePackedFunc; + constructor(handle: Pointer, lib: FFILibrary, makePackedFunc: (ptr: Pointer) => PackedFunc); + dispose(): void; + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull?: boolean): Pointer; + /** + * Get a function in the module. + * @param name The name of the function. + * @param queryImports Whether to also query imports + * @returns The result function. + */ + getFunction(name: string, queryImports?: boolean): PackedFunc; + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + importModule(mod: Module): void; +} +/** + * Generic object base + */ +export declare class TVMObject implements Disposable { + private handle; + private lib; + protected ctx: RuntimeContext; + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext); + dispose(): void; + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull?: boolean): Pointer; + /** get the type index of the object */ + typeIndex(): number; + /** get the type key of the object */ + typeKey(): string; +} +/** Objectconstructor */ +type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) => TVMObject; +/** All possible object types. */ +type TVMObjectBase = TVMObject | NDArray | Module | PackedFunc; +/** Runtime array object. */ +export declare class TVMArray extends TVMObject { + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext); + /** + * @returns the size of the array. + */ + size(): number; + /** + * Get index-th element of the array + * @param index the array index. + * @returns The element. + */ + get(index: number): TVMObjectBase; +} +export declare const enum VMAllocatorKind { + NAIVE_ALLOCATOR = 1, + POOLED_ALLOCATOR = 2 +} +/** + * VirtualMachine Executor. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +export declare class VirtualMachine implements Disposable { + private mod; + /** + * Constructor + * @param mod The underlying module, need to be detached. + * @param device The main device ro run VM on. + */ + constructor(mod: Module, device: DLDevice); + dispose(): void; + /** + * Get a function in the VM module. + * @param name The name of the function. + * @returns The result function. + */ + getFunction(name: string): PackedFunc; + /** + * Get the internal module. + */ + getInternalModule(): Module; +} +export interface NDArrayCacheEntry { + name: string; + shape: Array; + dtype: string; + format: "f32-to-bf16" | "raw"; + byteOffset: number; + nbytes: number; +} +export interface NDArrayShardEntry { + dataPath: string; + format: "raw-shard"; + nbytes: number; + records: Array; +} +export interface InitProgressReport { + type: 'init'; + progress: number; + timeElapsed: number; + currentChunk: number; + totalChunks: number; + fetchedBytes: number; + totalBytes: number; +} +export type InitProgressCallback = (report: InitProgressReport) => void; +/** + * TVM runtime instance. + * + * All objects(NDArray, Module, PackedFunc) returned by TVM runtim function call + * and PackedFunc instance are tracked through a scope mechanism that will get + * auto-released when we call EndScope. + * + * This is necessarily to be able to release the underlying WASM and WebGPU memory that + * are not tracked through JS native garbage collection mechanism. + * + * This does mean that we have to get familar with the following functions: + * - {@link beginScope} + * - {@link endScope} + * - {@link withNewScope} + * - {@link attachToCurrentScope} + * - {@link detachFromCurrentScope} + */ +export declare class Instance implements Disposable { + memory: Memory; + exports: Record; + cacheMetadata: Record; + private lib; + private env; + private objFactory; + private ctx; + private initProgressCallback; + /** + * Internal function(registered by the runtime) + */ + private wasmCreateLibraryModule?; + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + constructor(wasmModule: WebAssembly.Module, importObject?: Record, wasmInstance?: WebAssembly.Instance, env?: Environment); + /** + * Benchmark stable execution of the run function. + * + * @params run The run function + * @params dev The device to sync during each run. + * @number The number of times to compute the average. + * @repeat The number of times to repeat the run. + */ + benchmark(run: () => void, dev: DLDevice, number?: number, repeat?: number): Promise; + dispose(): void; + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string; + /** + * Begin a new scope for tracking object disposal. + */ + beginScope(): void; + /** + * End a scope and release all created TVM objects + * under the current scope. + * + * Exception: one can call {@link moveToParentScope} to move + * a value to parent scope. + */ + endScope(): void; + /** + * Perform action under a new scope. + * + * @param action The action function. + * @returns The result value. + * + * @note For action to return a valid value, + * we will need to call {@link moveToParentScope} + * for the objects that are created in the scope. + */ + withNewScope(action: () => T): T; + /** + * Attach a detached obj to the auto-release pool of the current scope. + * + * @param obj The input obj. + * @note Normally user do not need to call this function explicitly, as + * all library call return values are explicitly attached to + * the current scope. You only need to do so when you call + * {@link detachFromCurrentScope} to create a detached object. + */ + attachToCurrentScope(obj: T): T; + /** + * Move obj's attachment to the parent scope. + * + * This function is useful to make sure objects are still + * alive when exit the current scope. + * + * @param obj The object to be moved. + * @returns The input obj. + */ + moveToParentScope(obj: T): T; + /** + * Detach the object from the current scope + * so it won't be released via auto-release during endscope. + * + * User needs to either explicitly call obj.dispose(), or + * {@link attachToCurrentScope} to re-attach to the current scope. + * + * This function can be used to return values to the parent scope. + * @param obj The object. + */ + detachFromCurrentScope(obj: T): T; + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + systemLib(): Module; + /** + * List all the global function names registered in the runtime. + * @returns The name list. + */ + listGlobalFuncNames(): Array; + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerFunc(name: string, func: PackedFunc | Function, override?: boolean): void; + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @param autoAttachToScope Whether to track it via autoDispose + * @returns The result function. + */ + getGlobalFunc(name: string): PackedFunc; + private getGlobalFuncInternal; + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + isPackedFunc(func: unknown): boolean; + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + toPackedFunc(func: Function): PackedFunc; + private toPackedFuncInternal; + /** + * Setup a virtual machine module with given device. + * + * @param dev DLDevice the device. + * @returns The created virtual machime. + */ + createVirtualMachine(dev: DLDevice): VirtualMachine; + /** + * Register a call back for fetch progress. + * + * @param cb the fetch progress callback. + */ + registerInitProgressCallback(cb: InitProgressCallback): void; + /** + * Get parameters in the form of prefix_i + * + * @param prefix The parameter prefix. + * @param numParams Number of parameters. + * @returns + */ + getParamsFromCache(prefix: string, numParams: number): TVMObject; + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheGet(name: string): NDArray | undefined; + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheRemove(name: string): NDArray | undefined; + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheUpdate(name: string, arr: NDArray, override?: boolean): void; + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheClear(): void; + /** + * Fetch NDArray cache from url. + * + * @param ndarrayCacheUrl The cache url. + * @param device The device to be fetched to. + * @returns The meta data + */ + fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice): Promise; + /** + * Fetch list of NDArray into the NDArrayCache. + * + * @param ndarrayCacheUrl The cache url. + * @param list The list of array data. + * @param device The device to store the data to. + */ + private fetchNDArrayCacheInternal; + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. + */ + toDLDataType(dtype: string | DLDataType): DLDataType; + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. + */ + scalar(value: number, dtype: string): Scalar; + /** + * Create a new {@link DLDevice} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created device. + */ + device(deviceType: number | string, deviceId?: number): DLDevice; + /** + * Create a new cpu {@link DLDevice} + * @param deviceId The device index. + */ + cpu(deviceId?: number): DLDevice; + /** + * Create a new webgpu {@link DLDevice} + * @param deviceId The device index. + */ + webgpu(deviceId?: number): DLDevice; + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + empty(shape: Array | number, dtype?: string | DLDataType, dev?: DLDevice): NDArray; + /** + * Create am uniform {@link NDArray} with given shape. + * + * @param shape The shape of the array. + * @param low The low value. + * @param high The high value. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + uniform(shape: Array, low: number, high: number, dev: DLDevice): NDArray; + /** + * Sample index via top-p sampling. + * + * @param logits The input logits before normalization. + * @param temperature The temperature factor, will take argmax if temperature = 0.0 + * @param top_p The top_p + * @returns The sampled index. + */ + sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number; + /** + * Bind canvas to the current WebGPU context + * @param canvas The canvas. + */ + bindCanvas(canvas: HTMLCanvasElement): void; + /** + * Show image in canvas. + * + * @param dataRGBA Image array in height x width uint32 NDArray RGBA format on GPU. + */ + showImage(dataRGBA: NDArray): void; + /** + * Clear canvas + */ + clearCanvas(): void; + /** + * Create an tuple {@link TVMArray} input array. + * + * The input array can be passed to tvm runtime function + * and needs to b explicitly disposed. + * + * @param inputs The input array + * @returns The result array. + */ + makeTVMArray(inputs: Array): TVMArray; + /** + * Create a shape tuple to pass to runtime. + * @param shape The shape . + * @returns The created shape tuple. + */ + makeShapeTuple(shape: Array): TVMObject; + /** + * Get type index from type key. + * @param typeKey The type key. + * @returns The corresponding type index. + */ + typeKey2Index(typeKey: string): number; + /** + * Register an object constructor. + * @param typeKey The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerObjectConstructor(typeKey: string, func: FObjectConstructor, override?: boolean): void; + /** + * Register an asyncfunction to be global function in the server. + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note The async function will only be used for serving remote calls in the rpc. + */ + registerAsyncServerFunc(name: string, func: Function, override?: boolean): void; + /** + * Asynchrously load webgpu pipelines when possible. + * @param mod The input module. + */ + asyncLoadWebGPUPiplines(mod: Module): Promise; + /** + * Initialize webgpu in the runtime. + * @param device The given GPU device. + */ + initWebGPU(device: GPUDevice): void; + /** Register all object factory */ + private registerObjectFactoryFuncs; + /** Register global packed functions needed by the backend to the env. */ + private registerEnvGlobalPackedFuncs; + private createPackedFuncFromCFunc; + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. + */ + setPackedArguments(stack: CachedCallStack, args: Array, argsValue: PtrOffset, argsCode: PtrOffset): void; + private wrapJSFuncAsPackedCFunc; + private makePackedFunc; + /** + * Creaye return value of the packed func. The value us auto-tracked for dispose. + * @param rvaluePtr The location of rvalue + * @param tcode The type code. + * @param callbackArg Whether it is being used in callbackArg. + * @returns The JS value. + */ + private retValueToJS; +} +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + * + * @param bufferSource The source to be compiled. + * @param importObject The import objects. + * @param logger The system logger. + */ +export declare function instantiate(bufferSource: ArrayBuffer, importObject?: Record, logger?: (msg: string) => void): Promise; +export {}; diff --git a/packages/headless/dist/types/src/worker/lib/tvm/support.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/support.d.ts new file mode 100644 index 0000000..d5d1b3e --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/support.d.ts @@ -0,0 +1,23 @@ +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +export declare function StringToUint8Array(str: string): Uint8Array; +/** + * Convert Uint8array to string. + * @param array The array. + * @returns The corresponding string. + */ +export declare function Uint8ArrayToString(arr: Uint8Array): string; +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +export declare function assert(condition: boolean, msg?: string): asserts condition; +/** + * Get the path to the wasm library in nodejs. + * @return The wasm path. + */ +export declare function wasmPath(): string; diff --git a/packages/headless/dist/types/src/worker/lib/tvm/types.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/types.d.ts new file mode 100644 index 0000000..c8986c5 --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/types.d.ts @@ -0,0 +1,33 @@ +/** Common type definitions. */ +/** + * Library interface provider that can provide + * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime. + * + * It can be viewed as a generalization of imports used in WebAssembly instance creation. + * + * The {@link LibraryProvider.start} callback will be called + * to allow the library provider to initialize related resources during startup time. + * + * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }. + */ +export interface LibraryProvider { + /** The imports that can be passed to WebAssembly instance creation. */ + imports: Record; + /** + * Callback function to notify the provider the created instance. + * @param inst The created instance. + */ + start: (inst: WebAssembly.Instance) => void; +} +/** + * Disposable classes that contains resources (WasmMemory, GPU buffer) + * which needs to be explicitly disposed. + */ +export interface Disposable { + /** + * Dispose the internal resource + * This function can be called multiple times, + * only the first call will take effect. + */ + dispose: () => void; +} diff --git a/packages/headless/dist/types/src/worker/lib/tvm/webgpu.d.ts b/packages/headless/dist/types/src/worker/lib/tvm/webgpu.d.ts new file mode 100644 index 0000000..1a53f4e --- /dev/null +++ b/packages/headless/dist/types/src/worker/lib/tvm/webgpu.d.ts @@ -0,0 +1,124 @@ +/// +import { Memory } from "./memory"; +/** A pointer to points to the raw address space. */ +export type GPUPointer = number; +export interface GPUDeviceDetectOutput { + adapter: GPUAdapter; + adapterInfo: GPUAdapterInfo; + device: GPUDevice; +} +/** + * DetectGPU device in the environment. + */ +export declare function detectGPUDevice(): Promise; +/** + * Function info from the API + */ +export interface FunctionInfo { + name: string; + arg_types: Array; + launch_param_tags: Array; +} +/** + * WebGPU context + * Manages all the webgpu resources here. + */ +export declare class WebGPUContext { + device: GPUDevice; + memory: Memory; + private bufferTable; + private bufferTableFreeId; + private podArgStagingBuffers; + private canvasRenderManager?; + private maxNumPodArgsStagingBuffers; + private peakAllocatedBytes; + private currAllocatedBytes; + private allAllocatedBytes; + private shaderSubmitCounter; + protected debugShaderSubmitLimit: number; + protected debugLogFinish: boolean; + constructor(memory: Memory, device: GPUDevice); + /** + * Dispose context. + */ + dispose(): void; + /** + * Wait for all pending GPU tasks to complete + */ + sync(): Promise; + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string; + /** + * Draw image from data in storage buffer. + * @param ptr The GPU ptr + * @param height The height of the image. + * @param width The width of the image. + */ + drawImageFromBuffer(ptr: GPUPointer, height: number, width: number): void; + /** + * Copy raw bytes into buffer ptr. + * + * @param rawBytes The raw bytes + * @param toPtr The target gpu buffer ptr + * @param toOffset The beginning offset + * @param nbytes Number of bytes + */ + copyRawBytesToBuffer(rawBytes: Uint8Array, toPtr: GPUPointer, toOffset: number, nbytes: number): void; + /** + * Clear canvas + */ + clearCanvas(): void; + /** + * Bind a canvas element to the runtime. + * @param canvas The HTML canvas/ + */ + bindCanvas(canvas: HTMLCanvasElement): void; + /** + * Create a PackedFunc that runs the given shader + * via createComputePipeline + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + createShader(finfo: FunctionInfo, code: string): Function; + /** + * Create a PackedFunc that runs the given shader asynchrously + * via createComputePipelineAsync + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + createShaderAsync(finfo: FunctionInfo, code: string): Promise; + /** + * Get the pod arg staging buffer + * \param nbytes The minimum size. + * \return The allocated buffer + */ + private getPodArgsBuffer; + /** + * Internal impl of createShader for both async and sync mode. + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @param asyncMode Whether use async mode. + * @returns The shader function or promise of shader func. + */ + private createShadeInternal; + /** + * Get the device API according to its name + * @param The name of the API. + * @returns The corresponding device api. + */ + getDeviceAPI(name: string): Function; + private deviceAllocDataSpace; + private deviceFreeDataSpace; + private deviceCopyToGPU; + private deviceCopyFromGPU; + private deviceCopyWithinGPU; + private gpuBufferFromPtr; + private attachToBufferTable; +} diff --git a/packages/headless/dist/types/src/worker/llm.d.ts b/packages/headless/dist/types/src/worker/llm.d.ts new file mode 100644 index 0000000..bc02748 --- /dev/null +++ b/packages/headless/dist/types/src/worker/llm.d.ts @@ -0,0 +1,44 @@ +import { Conversation } from "../types/chat"; +import { GenerateTextCallback, GenerateTextRequest } from "../types/worker_message"; +import { InitProgressCallback } from "./lib/tvm/runtime"; +import { Config } from "./worker"; +export declare class LLMInstance { + config: Config; + tvm: any; + tokenizer: any; + model: any; + spp: any; + processing: boolean; + constructor(config: Config, sentencePieceProcessor: any); + isInitialized(): boolean; + init(cb: InitProgressCallback): Promise; + generate(request: GenerateTextRequest, cb: GenerateTextCallback): Promise; +} +export declare class LLMInstanceScope { + tvm: any; + tokenizer: any; + maxWindowSize: number; + device: any; + vm: any; + encoding: any; + decoding: any; + params: any; + bosTokenId: number; + eosTokenId: number; + fclearKVCaches: any; + kvCache: any; + fcreateCache: any; + logitsOnCPU: any; + kvCacheLength: number; + lastMessageId: string; + constructor(tvm: any, tokenizer: any, maxWindowSize?: number); + init(): Promise; + getTokensFromStart(conversation: Conversation, maxTokens: number): Promise; + getTokens(conversation: Conversation, maxTokens: number): Promise; + generate(request: GenerateTextRequest, cb: GenerateTextCallback): Promise; + dispose(): void; + clearKVCache(): void; + forward(inputs: any, curPos: number): any; + updateLogitsOnCPU(logits: any): void; + sampleTokenFromLogits(logits: any, temperature?: number, top_p?: number): Promise; +} diff --git a/packages/headless/dist/types/src/worker/worker.d.ts b/packages/headless/dist/types/src/worker/worker.d.ts new file mode 100644 index 0000000..2e47936 --- /dev/null +++ b/packages/headless/dist/types/src/worker/worker.d.ts @@ -0,0 +1,19 @@ +declare global { + var importScripts: (...url: string[]) => void; + var sentencepiece: { + sentencePieceProcessor: (url: string) => void; + }; +} +export type Config = { + kvConfig: { + numLayers: number; + shape: number[]; + dtype: string; + }; + wasmUrl: string; + cacheUrl: string; + tokenizerUrl: string; + sentencePieceJsUrl: string; + tvmRuntimeJsUrl: string; + maxWindowSize: number; +}; diff --git a/packages/headless/dist/v4-2119d9d5.js b/packages/headless/dist/v4-2119d9d5.js new file mode 100644 index 0000000..079fc3c --- /dev/null +++ b/packages/headless/dist/v4-2119d9d5.js @@ -0,0 +1,3784 @@ +/****************************************************************************** +Copyright (c) Microsoft Corporation. + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. +***************************************************************************** */ +/* global Reflect, Promise */ + +var extendStatics = function(d, b) { + extendStatics = Object.setPrototypeOf || + ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || + function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; }; + return extendStatics(d, b); +}; + +function __extends(d, b) { + if (typeof b !== "function" && b !== null) + throw new TypeError("Class extends value " + String(b) + " is not a constructor or null"); + extendStatics(d, b); + function __() { this.constructor = d; } + d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); +} + +var __assign = function() { + __assign = Object.assign || function __assign(t) { + for (var s, i = 1, n = arguments.length; i < n; i++) { + s = arguments[i]; + for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p)) t[p] = s[p]; + } + return t; + }; + return __assign.apply(this, arguments); +}; + +function __awaiter(thisArg, _arguments, P, generator) { + function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } + return new (P || (P = Promise))(function (resolve, reject) { + function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } + function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } + function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } + step((generator = generator.apply(thisArg, _arguments || [])).next()); + }); +} + +function __generator(thisArg, body) { + var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; + return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; + function verb(n) { return function (v) { return step([n, v]); }; } + function step(op) { + if (f) throw new TypeError("Generator is already executing."); + while (g && (g = 0, op[0] && (_ = 0)), _) try { + if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; + if (y = 0, t) op = [op[0] & 2, t.value]; + switch (op[0]) { + case 0: case 1: t = op; break; + case 4: _.label++; return { value: op[1], done: false }; + case 5: _.label++; y = op[1]; op = [0]; continue; + case 7: op = _.ops.pop(); _.trys.pop(); continue; + default: + if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } + if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } + if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } + if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } + if (t[2]) _.ops.pop(); + _.trys.pop(); continue; + } + op = body.call(thisArg, _); + } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } + if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; + } +} + +function __spreadArray(to, from, pack) { + if (pack || arguments.length === 2) for (var i = 0, l = from.length, ar; i < l; i++) { + if (ar || !(i in from)) { + if (!ar) ar = Array.prototype.slice.call(from, 0, i); + ar[i] = from[i]; + } + } + return to.concat(ar || Array.prototype.slice.call(from)); +} + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** NodeJS and Web compact layer */ +/** + * Get performance measurement. + */ +function getPerformance() { + return performance; +} + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Size of common data types. + */ +var SizeOf; +(function (SizeOf) { + SizeOf[SizeOf["U8"] = 1] = "U8"; + SizeOf[SizeOf["U16"] = 2] = "U16"; + SizeOf[SizeOf["I32"] = 4] = "I32"; + SizeOf[SizeOf["I64"] = 8] = "I64"; + SizeOf[SizeOf["F32"] = 4] = "F32"; + SizeOf[SizeOf["F64"] = 8] = "F64"; + SizeOf[SizeOf["TVMValue"] = 8] = "TVMValue"; + SizeOf[SizeOf["DLDataType"] = 4] = "DLDataType"; + SizeOf[SizeOf["DLDevice"] = 8] = "DLDevice"; +})(SizeOf || (SizeOf = {})); +/** + * Argument Type code in TVM FFI. + */ +var ArgTypeCode; +(function (ArgTypeCode) { + ArgTypeCode[ArgTypeCode["Int"] = 0] = "Int"; + ArgTypeCode[ArgTypeCode["UInt"] = 1] = "UInt"; + ArgTypeCode[ArgTypeCode["Float"] = 2] = "Float"; + ArgTypeCode[ArgTypeCode["TVMOpaqueHandle"] = 3] = "TVMOpaqueHandle"; + ArgTypeCode[ArgTypeCode["Null"] = 4] = "Null"; + ArgTypeCode[ArgTypeCode["TVMDataType"] = 5] = "TVMDataType"; + ArgTypeCode[ArgTypeCode["DLDevice"] = 6] = "DLDevice"; + ArgTypeCode[ArgTypeCode["TVMDLTensorHandle"] = 7] = "TVMDLTensorHandle"; + ArgTypeCode[ArgTypeCode["TVMObjectHandle"] = 8] = "TVMObjectHandle"; + ArgTypeCode[ArgTypeCode["TVMModuleHandle"] = 9] = "TVMModuleHandle"; + ArgTypeCode[ArgTypeCode["TVMPackedFuncHandle"] = 10] = "TVMPackedFuncHandle"; + ArgTypeCode[ArgTypeCode["TVMStr"] = 11] = "TVMStr"; + ArgTypeCode[ArgTypeCode["TVMBytes"] = 12] = "TVMBytes"; + ArgTypeCode[ArgTypeCode["TVMNDArrayHandle"] = 13] = "TVMNDArrayHandle"; + ArgTypeCode[ArgTypeCode["TVMObjectRValueRefArg"] = 14] = "TVMObjectRValueRefArg"; +})(ArgTypeCode || (ArgTypeCode = {})); + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +function StringToUint8Array(str) { + var arr = new Uint8Array(str.length + 1); + for (var i = 0; i < str.length; ++i) { + arr[i] = str.charCodeAt(i); + } + arr[str.length] = 0; + return arr; +} +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +function assert(condition, msg) { + if (!condition) { + throw new Error("AssertError:" + (msg || "")); + } +} + +/** + * Detect library provider from the importObject. + * + * @param importObject The import object. + */ +function detectLibraryProvider(importObject) { + if (importObject["wasmLibraryProvider"] && + importObject["wasmLibraryProvider"]["start"] && + importObject["wasmLibraryProvider"]["imports"] !== undefined) { + var item_1 = importObject; + // create provider so that we capture imports in the provider. + return { + imports: item_1.wasmLibraryProvider.imports, + start: function (inst) { + item_1.wasmLibraryProvider.start(inst); + }, + }; + } + else if (importObject["imports"] && importObject["start"] !== undefined) { + return importObject; + } + else if (importObject["wasiImport"] && importObject["start"] !== undefined) { + // WASI + return { + imports: { + "wasi_snapshot_preview1": importObject["wasiImport"], + }, + start: function (inst) { + importObject["start"](inst); + } + }; + } + else { + return undefined; + } +} +/** + * Environment to impelement most of the JS library functions. + */ +var Environment = /** @class */ (function () { + function Environment(importObject, logger) { + if (importObject === void 0) { importObject = {}; } + if (logger === void 0) { logger = console.log; } + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + this.packedCFuncTable = [ + undefined, + ]; + /** + * Free table index that can be recycled. + */ + this.packedCFuncTableFreeId = []; + this.logger = logger; + this.libProvider = detectLibraryProvider(importObject); + // get imports from the provider + if (this.libProvider !== undefined) { + this.imports = this.libProvider.imports; + } + else { + this.imports = importObject; + } + // update with more functions + this.imports.env = this.environment(this.imports.env); + } + /** Mark the start of the instance. */ + Environment.prototype.start = function (inst) { + if (this.libProvider !== undefined) { + this.libProvider.start(inst); + } + }; + Environment.prototype.environment = function (initEnv) { + var _this = this; + // default env can be be overriden by libraries. + var defaultEnv = { + "__cxa_thread_atexit": function () { }, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + "emscripten_notify_memory_growth": function (index) { } + }; + var wasmPackedCFunc = function (args, typeCodes, nargs, ret, resourceHandle) { + var cfunc = _this.packedCFuncTable[resourceHandle]; + assert(cfunc !== undefined); + return cfunc(args, typeCodes, nargs, ret, resourceHandle); + }; + var wasmPackedCFuncFinalizer = function (resourceHandle) { + _this.packedCFuncTable[resourceHandle] = undefined; + _this.packedCFuncTableFreeId.push(resourceHandle); + }; + var newEnv = { + TVMWasmPackedCFunc: wasmPackedCFunc, + TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "__console_log": function (msg) { + _this.logger(msg); + } + }; + return Object.assign(defaultEnv, initEnv, newEnv); + }; + return Environment; +}()); + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Classes to manipulate Wasm memories. + */ +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +var Memory = /** @class */ (function () { + function Memory(memory) { + this.wasm32 = true; + this.memory = memory; + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } + Memory.prototype.loadU8 = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU8[ptr >> 0]; + }; + Memory.prototype.loadU16 = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU16[ptr >> 1]; + }; + Memory.prototype.loadU32 = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU32[ptr >> 2]; + }; + Memory.prototype.loadI32 = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewI32[ptr >> 2]; + }; + Memory.prototype.loadI64 = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + var base = ptr >> 2; + // assumes little endian, for now truncate high. + return this.viewI32[base]; + }; + Memory.prototype.loadF32 = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF32[ptr >> 2]; + }; + Memory.prototype.loadF64 = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF64[ptr >> 3]; + }; + Memory.prototype.loadPointer = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } + else { + return this.loadI64(ptr); + } + }; + Memory.prototype.loadUSize = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } + else { + return this.loadI64(ptr); + } + }; + Memory.prototype.sizeofPtr = function () { + return this.wasm32 ? SizeOf.I32 : SizeOf.I64; + }; + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + Memory.prototype.loadRawBytes = function (ptr, numBytes) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + var result = new Uint8Array(numBytes); + result.set(this.viewU8.slice(ptr, ptr + numBytes)); + return result; + }; + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + Memory.prototype.loadTVMBytes = function (ptr) { + var data = this.loadPointer(ptr); + var length = this.loadUSize(ptr + this.sizeofPtr()); + return this.loadRawBytes(data, length); + }; + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + Memory.prototype.loadCString = function (ptr) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + // NOTE: the views are still valid for read. + var ret = []; + var ch = 1; + while (ch != 0) { + ch = this.viewU8[ptr]; + if (ch != 0) { + ret.push(String.fromCharCode(ch)); + } + ++ptr; + } + return ret.join(""); + }; + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + Memory.prototype.storeRawBytes = function (ptr, bytes) { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + this.viewU8.set(bytes, ptr); + }; + /** + * Update memory view after the memory growth. + */ + Memory.prototype.updateViews = function () { + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + }; + return Memory; +}()); +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +var CachedCallStack = /** @class */ (function () { + function CachedCallStack(memory, allocSpace, freeSpace) { + /** List of temporay arguments that can be disposed during reset. */ + this.tempArgs = []; + this.stackTop = 0; + this.basePtr = 0; + this.addressToSetTargetValue = []; + var initCallStackSize = 128; + this.memory = memory; + this.cAllocSpace = allocSpace; + this.cFreeSpace = freeSpace; + this.buffer = new ArrayBuffer(initCallStackSize); + this.basePtr = this.cAllocSpace(initCallStackSize); + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + this.updateViews(); + } + CachedCallStack.prototype.dispose = function () { + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + this.basePtr = 0; + } + }; + /** + * Rest the call stack so that it can be reused again. + */ + CachedCallStack.prototype.reset = function () { + this.stackTop = 0; + assert(this.addressToSetTargetValue.length == 0); + while (this.tempArgs.length != 0) { + this.tempArgs.pop().dispose(); + } + }; + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + CachedCallStack.prototype.commitToWasmMemory = function (nbytes) { + if (nbytes === void 0) { nbytes = this.stackTop; } + // commit all pointer values. + while (this.addressToSetTargetValue.length != 0) { + var _a = this.addressToSetTargetValue.pop(), targetOffset = _a[0], valueOffset = _a[1]; + this.storePtr(targetOffset, this.ptrFromOffset(valueOffset)); + } + this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes)); + }; + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + CachedCallStack.prototype.allocRawBytes = function (nbytes) { + // always aligns to 64bit + nbytes = ((nbytes + 7) >> 3) << 3; + if (this.stackTop + nbytes > this.buffer.byteLength) { + var newSize = Math.max(this.buffer.byteLength * 2, this.stackTop + nbytes); + var oldU8 = this.viewU8; + this.buffer = new ArrayBuffer(newSize); + this.updateViews(); + this.viewU8.set(oldU8); + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + } + this.basePtr = this.cAllocSpace(newSize); + } + var retOffset = this.stackTop; + this.stackTop += nbytes; + return retOffset; + }; + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + CachedCallStack.prototype.allocPtrArray = function (count) { + return this.allocRawBytes(this.memory.sizeofPtr() * count); + }; + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. + */ + CachedCallStack.prototype.ptrFromOffset = function (offset) { + return this.basePtr + offset; + }; + // Store APIs + CachedCallStack.prototype.storePtr = function (offset, value) { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } + else { + this.storeI64(offset, value); + } + }; + CachedCallStack.prototype.storeUSize = function (offset, value) { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } + else { + this.storeI64(offset, value); + } + }; + CachedCallStack.prototype.storeI32 = function (offset, value) { + this.viewI32[offset >> 2] = value; + }; + CachedCallStack.prototype.storeU32 = function (offset, value) { + this.viewU32[offset >> 2] = value; + }; + CachedCallStack.prototype.storeI64 = function (offset, value) { + // For now, just store as 32bit + // NOTE: wasm always uses little endian. + var low = value & 0xffffffff; + var base = offset >> 2; + this.viewI32[base] = low; + this.viewI32[base + 1] = 0; + }; + CachedCallStack.prototype.storeF64 = function (offset, value) { + this.viewF64[offset >> 3] = value; + }; + CachedCallStack.prototype.storeRawBytes = function (offset, bytes) { + this.viewU8.set(bytes, offset); + }; + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + CachedCallStack.prototype.allocThenSetArgString = function (offset, data) { + var strOffset = this.allocRawBytes(data.length + 1); + this.storeRawBytes(strOffset, StringToUint8Array(data)); + this.addressToSetTargetValue.push([offset, strOffset]); + }; + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + CachedCallStack.prototype.allocThenSetArgBytes = function (offset, data) { + // Note: size of size_t equals sizeof ptr. + var headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + var dataOffset = this.allocRawBytes(data.length); + this.storeRawBytes(dataOffset, data); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + this.addressToSetTargetValue.push([offset, headerOffset]); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + }; + /** + * Update internal cache views. + */ + CachedCallStack.prototype.updateViews = function () { + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + }; + return CachedCallStack; +}()); + +/** + * DetectGPU device in the environment. + */ +function detectGPUDevice() { + return __awaiter(this, void 0, void 0, function () { + var adapter, computeMB, requiedMaxBufferSize, requiredMaxStorageBufferBindingSize, requiredMaxComputeWorkgroupStorageSize, adapterInfo, device; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (!(typeof navigator !== "undefined" && navigator.gpu !== undefined)) return [3 /*break*/, 4]; + return [4 /*yield*/, navigator.gpu.requestAdapter({ "powerPreference": "high-performance" })]; + case 1: + adapter = _a.sent(); + if (adapter == null) { + throw Error("Cannot find adapter that matches the request"); + } + computeMB = function (value) { + return Math.ceil(value / (1 << 20)) + "MB"; + }; + requiedMaxBufferSize = 1 << 30; + if (requiedMaxBufferSize > adapter.limits.maxBufferSize) { + throw Error("Cannot initialize runtime because of requested maxBufferSize " + + "exceeds limit. requested=".concat(computeMB(requiedMaxBufferSize), ", ") + + "limit=".concat(computeMB(adapter.limits.maxBufferSize), ". ") + + "This error may be caused by an older version of the browser (e.g. Chrome 112). " + + "You can try to upgrade your browser to Chrome 113 or later."); + } + requiredMaxStorageBufferBindingSize = 1 << 30; + if (requiredMaxStorageBufferBindingSize > adapter.limits.maxStorageBufferBindingSize) { + throw Error("Cannot initialize runtime because of requested maxStorageBufferBindingSize " + + "exceeds limit. requested=".concat(computeMB(requiredMaxStorageBufferBindingSize), ", ") + + "limit=".concat(computeMB(adapter.limits.maxStorageBufferBindingSize), ". ")); + } + requiredMaxComputeWorkgroupStorageSize = 32 << 10; + if (requiredMaxComputeWorkgroupStorageSize > adapter.limits.maxComputeWorkgroupStorageSize) { + throw Error("Cannot initialize runtime because of requested maxComputeWorkgroupStorageSize " + + "exceeds limit. requested=".concat(requiredMaxComputeWorkgroupStorageSize, ", ") + + "limit=".concat(adapter.limits.maxComputeWorkgroupStorageSize, ". ")); + } + return [4 /*yield*/, adapter.requestAdapterInfo()]; + case 2: + adapterInfo = _a.sent(); + return [4 /*yield*/, adapter.requestDevice({ + requiredLimits: { + maxBufferSize: requiedMaxBufferSize, + maxStorageBufferBindingSize: requiredMaxStorageBufferBindingSize, + maxComputeWorkgroupStorageSize: requiredMaxComputeWorkgroupStorageSize, + } + })]; + case 3: + device = _a.sent(); + return [2 /*return*/, { + adapter: adapter, + adapterInfo: adapterInfo, + device: device + }]; + case 4: return [2 /*return*/, undefined]; + } + }); + }); +} +var canvasRenderWGSL = "\n@group(0) @binding(0) var my_sampler : sampler;\n@group(0) @binding(1) var my_texture : texture_2d;\n\nstruct VertexOutput {\n @builtin(position) position : vec4,\n @location(0) uv : vec2,\n}\n\n@vertex\nfn vertex_main(@builtin(vertex_index) vidx : u32) -> VertexOutput {\n const pos = array(\n vec2( 1.0, 1.0),\n vec2( 1.0, -1.0),\n vec2(-1.0, -1.0),\n vec2( 1.0, 1.0),\n vec2(-1.0, -1.0),\n vec2(-1.0, 1.0),\n );\n\n const uv = array(\n vec2(1.0, 0.0),\n vec2(1.0, 1.0),\n vec2(0.0, 1.0),\n vec2(1.0, 0.0),\n vec2(0.0, 1.0),\n vec2(0.0, 0.0),\n );\n\n var output : VertexOutput;\n output.position = vec4(pos[vidx], 0.0, 1.0);\n output.uv = uv[vidx];\n return output;\n}\n\n@fragment\nfn fragment_main(@location(0) uv : vec2) -> @location(0) vec4 {\n return textureSample(my_texture, my_sampler, uv);\n}\n\n@fragment\nfn fragment_clear(@location(0) uv : vec2) -> @location(0) vec4 {\n return vec4(1.0, 1.0, 1.0, 1.0);\n}\n"; +var CanvaRenderManager = /** @class */ (function () { + function CanvaRenderManager(device, canvas) { + this.device = device; + var ctx = canvas.getContext("webgpu"); + if (ctx == null) { + throw Error("Cannot bind WebGPU context"); + } + // @ts-ignore + this.canvasContext = ctx; + this.canvasTextureFormat = navigator.gpu.getPreferredCanvasFormat(); + this.canvasContext.configure({ + device: this.device, + format: this.canvasTextureFormat, + alphaMode: "opaque", + }); + this.renderPipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "vertex_main", + }, + fragment: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "fragment_main", + targets: [{ + format: this.canvasTextureFormat, + }], + }, + primitive: { + topology: "triangle-list", + }, + }); + this.clearPipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "vertex_main", + }, + fragment: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "fragment_clear", + targets: [{ + format: this.canvasTextureFormat, + }], + }, + primitive: { + topology: "triangle-list", + }, + }); + this.renderSampler = device.createSampler({ + magFilter: "linear", + minFilter: "linear", + }); + // staging texture always be in RGBA + this.stagingTexture = device.createTexture({ + size: [canvas.height, canvas.width, 1], + format: "rgba8unorm", + usage: GPUTextureUsage.TEXTURE_BINDING | + GPUTextureUsage.COPY_DST | + GPUTextureUsage.RENDER_ATTACHMENT, + }); + } + CanvaRenderManager.prototype.clear = function () { + var commandEncoder = this.device.createCommandEncoder(); + var passEncoder = commandEncoder.beginRenderPass({ + //@ts-ignore + colorAttachments: [ + { + view: this.canvasContext.getCurrentTexture().createView(), + clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 }, + loadOp: "clear", + storeOp: "store", + }, + ], + }); + passEncoder.setPipeline(this.clearPipeline); + var renderBindingGroup = this.device.createBindGroup({ + layout: this.renderPipeline.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.renderSampler }, + { binding: 1, resource: this.stagingTexture.createView() }, + ], + }); + passEncoder.setBindGroup(0, renderBindingGroup); + passEncoder.draw(6, 1, 0, 0); + passEncoder.end(); + this.device.queue.submit([commandEncoder.finish()]); + }; + CanvaRenderManager.prototype.draw = function (buffer, height, width) { + // resize the staging texture + if (height != this.stagingTexture.height || width != this.stagingTexture.width) { + this.stagingTexture.destroy(); + this.stagingTexture = this.device.createTexture({ + size: [height, width, 1], + format: "rgba8unorm", + usage: GPUTextureUsage.TEXTURE_BINDING | + GPUTextureUsage.COPY_DST | + GPUTextureUsage.RENDER_ATTACHMENT, + }); + } + var commandEncoder = this.device.createCommandEncoder(); + commandEncoder.copyBufferToTexture({ + buffer: buffer, + offset: 0, + bytesPerRow: this.stagingTexture.width * 4 + }, { + texture: this.stagingTexture + }, { + width: this.stagingTexture.width, + height: this.stagingTexture.height + }); + var passEncoder = commandEncoder.beginRenderPass({ + //@ts-ignore + colorAttachments: [ + { + view: this.canvasContext.getCurrentTexture().createView(), + clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 }, + loadOp: "clear", + storeOp: "store", + }, + ], + }); + passEncoder.setPipeline(this.renderPipeline); + var renderBindingGroup = this.device.createBindGroup({ + layout: this.renderPipeline.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.renderSampler }, + { binding: 1, resource: this.stagingTexture.createView() }, + ], + }); + passEncoder.setBindGroup(0, renderBindingGroup); + passEncoder.draw(6, 1, 0, 0); + passEncoder.end(); + this.device.queue.submit([commandEncoder.finish()]); + }; + CanvaRenderManager.prototype.dispose = function () { + this.stagingTexture.destroy(); + }; + return CanvaRenderManager; +}()); +/** + * WebGPU context + * Manages all the webgpu resources here. + */ +var WebGPUContext = /** @class */ (function () { + function WebGPUContext(memory, device) { + // internal data + this.bufferTable = [undefined]; + this.bufferTableFreeId = []; + this.podArgStagingBuffers = []; + this.canvasRenderManager = undefined; + // number of pod arg staging buffers + this.maxNumPodArgsStagingBuffers = 2; + // flags for debugging + // stats of the runtime. + // peak allocation + this.peakAllocatedBytes = 0; + // current allocation + this.currAllocatedBytes = 0; + // all allocation(ignoring free) + this.allAllocatedBytes = 0; + // shader submit counter + this.shaderSubmitCounter = 0; + // limite number of shaders to be submitted, useful for debugging, default to -1 + this.debugShaderSubmitLimit = -1; + // log and sync each step + this.debugLogFinish = false; + this.memory = memory; + this.device = device; + } + /** + * Dispose context. + */ + WebGPUContext.prototype.dispose = function () { + var _a, _b, _c; + (_a = this.canvasRenderManager) === null || _a === void 0 ? void 0 : _a.dispose(); + this.bufferTableFreeId = []; + while (this.bufferTable.length != 0) { + (_b = this.bufferTable.pop()) === null || _b === void 0 ? void 0 : _b.destroy(); + } + while (this.podArgStagingBuffers.length != 0) { + (_c = this.podArgStagingBuffers.pop()) === null || _c === void 0 ? void 0 : _c.destroy(); + } + this.device.destroy(); + }; + /** + * Wait for all pending GPU tasks to complete + */ + WebGPUContext.prototype.sync = function () { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.device.queue.onSubmittedWorkDone()]; + case 1: + _a.sent(); + return [2 /*return*/]; + } + }); + }); + }; + /** + * Obtain the runtime information in readable format. + */ + WebGPUContext.prototype.runtimeStatsText = function () { + var info = "peak-memory=" + Math.ceil(this.peakAllocatedBytes / (1 << 20)) + " MB"; + info += ", all-memory=" + Math.ceil(this.allAllocatedBytes / (1 << 20)) + " MB"; + info += ", shader-submissions=" + this.shaderSubmitCounter; + return info; + }; + /** + * Draw image from data in storage buffer. + * @param ptr The GPU ptr + * @param height The height of the image. + * @param width The width of the image. + */ + WebGPUContext.prototype.drawImageFromBuffer = function (ptr, height, width) { + if (this.canvasRenderManager == undefined) { + throw Error("Do not have a canvas context, call bindCanvas first"); + } + this.canvasRenderManager.draw(this.gpuBufferFromPtr(ptr), height, width); + }; + /** + * Copy raw bytes into buffer ptr. + * + * @param rawBytes The raw bytes + * @param toPtr The target gpu buffer ptr + * @param toOffset The beginning offset + * @param nbytes Number of bytes + */ + WebGPUContext.prototype.copyRawBytesToBuffer = function (rawBytes, toPtr, toOffset, nbytes) { + // Perhaps it would be more useful to use a staging buffer? + this.device.queue.writeBuffer(this.gpuBufferFromPtr(toPtr), toOffset, rawBytes, 0, nbytes); + }; + /** + * Clear canvas + */ + WebGPUContext.prototype.clearCanvas = function () { + var _a; + (_a = this.canvasRenderManager) === null || _a === void 0 ? void 0 : _a.clear(); + }; + /** + * Bind a canvas element to the runtime. + * @param canvas The HTML canvas/ + */ + WebGPUContext.prototype.bindCanvas = function (canvas) { + this.canvasRenderManager = new CanvaRenderManager(this.device, canvas); + }; + /** + * Create a PackedFunc that runs the given shader + * via createComputePipeline + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + WebGPUContext.prototype.createShader = function (finfo, code) { + return this.createShadeInternal(finfo, code, false); + }; + /** + * Create a PackedFunc that runs the given shader asynchrously + * via createComputePipelineAsync + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + WebGPUContext.prototype.createShaderAsync = function (finfo, code) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.createShadeInternal(finfo, code, true)]; + case 1: return [2 /*return*/, _a.sent()]; + } + }); + }); + }; + /** + * Get the pod arg staging buffer + * \param nbytes The minimum size. + * \return The allocated buffer + */ + WebGPUContext.prototype.getPodArgsBuffer = function (nbytes) { + var buffer = undefined; + if (this.podArgStagingBuffers.length >= this.maxNumPodArgsStagingBuffers) { + buffer = this.podArgStagingBuffers.shift(); + } + // minimum of 16 bytes + var allocSize = 16; + if (buffer !== undefined) { + allocSize = buffer.size; + if (buffer.size < nbytes) { + buffer.destroy(); + buffer = undefined; + } + } + while (allocSize < nbytes) { + allocSize *= 2; + } + if (buffer == undefined) { + // create uniform buffer + buffer = this.device.createBuffer({ + size: allocSize, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + } + assert(nbytes <= buffer.size); + return buffer; + }; + /** + * Internal impl of createShader for both async and sync mode. + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @param asyncMode Whether use async mode. + * @returns The shader function or promise of shader func. + */ + WebGPUContext.prototype.createShadeInternal = function (finfo, code, asyncMode) { + var _this = this; + var dispatchToDim = []; + var paramWriteAccess = []; + for (var i = 0; i < finfo.launch_param_tags.length; ++i) { + var tag = finfo.launch_param_tags[i]; + if (tag.startsWith("blockIdx.")) { + var target = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target); + } + else if (tag.startsWith("threadIdx.")) { + var target = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target + 3); + } + else if (tag.startsWith("paramWriteAccess:")) { + paramWriteAccess = JSON.parse(tag.substring(17)); + } + else { + throw new Error("Cannot handle thread_axis " + tag); + } + } + var layoutEntries = []; + var bufferArgIndices = []; + var podArgIndices = []; + for (var i = 0; i < finfo.arg_types.length; ++i) { + var dtype = finfo.arg_types[i]; + if (dtype == "handle") { + layoutEntries.push({ + binding: bufferArgIndices.length, + visibility: GPUShaderStage.COMPUTE, + buffer: { + type: paramWriteAccess[bufferArgIndices.length] ? "storage" : "read-only-storage" + } + }); + bufferArgIndices.push(i); + } + else if (dtype.startsWith("int") || dtype.startsWith("uint") || dtype.startsWith("float")) { + podArgIndices.push(i); + } + else { + throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader"); + } + } + assert(paramWriteAccess.length == bufferArgIndices.length); + // POD arguments are pass in the end + layoutEntries.push({ + binding: bufferArgIndices.length, + visibility: GPUShaderStage.COMPUTE, + buffer: { + type: "uniform" + } + }); + var bindGroupLayout = this.device.createBindGroupLayout({ + entries: layoutEntries + }); + var pipelineLayout = this.device.createPipelineLayout({ + bindGroupLayouts: [bindGroupLayout] + }); + // Function to create the pipeline. + var createShaderFunc = function (pipeline) { + var submitShader = function () { + var args = []; + for (var _i = 0; _i < arguments.length; _i++) { + args[_i] = arguments[_i]; + } + if (_this.debugShaderSubmitLimit != -1 && + _this.shaderSubmitCounter >= _this.debugShaderSubmitLimit) { + _this.shaderSubmitCounter += 1; + return; + } + var commandEncoder = _this.device.createCommandEncoder(); + var compute = commandEncoder.beginComputePass(); + compute.setPipeline(pipeline); + var bindGroupEntries = []; + var numBufferOrPodArgs = bufferArgIndices.length + podArgIndices.length; + assert(args.length == numBufferOrPodArgs + dispatchToDim.length); + var workDim = [1, 1, 1, 1, 1, 1]; + for (var i = 0; i < dispatchToDim.length; ++i) { + workDim[dispatchToDim[i]] = args[numBufferOrPodArgs + i]; + } + // get around 65535 restriction of blockIdx.x + if (workDim[2] != 1) { + throw Error("WebGPU: blockIdx.z is reserved for internal use"); + } + var packDimX = workDim[0]; + // spread thinsg out into blockIdx.z + if (workDim[0] >= (1 << 16)) { + var wl_x = workDim[0]; + var wl_z = workDim[2]; + while (wl_x >= (1 << 16)) { + if (wl_x % 2 == 0) { + wl_x = wl_x / 2; + } + else { + // pad up + wl_x = (wl_x + 1) / 2; + } + wl_z *= 2; + } + workDim[0] = wl_x; + workDim[2] = wl_z; + assert(wl_x * wl_z >= packDimX); + } + for (var i = 0; i < bufferArgIndices.length; ++i) { + bindGroupEntries.push({ + binding: i, + resource: { + buffer: _this.gpuBufferFromPtr(args[bufferArgIndices[i]]) + } + }); + } + // push pod buffer + var sizeOfI32 = 4; + var podArgBuffer = _this.getPodArgsBuffer((podArgIndices.length + 1) * sizeOfI32); + var i32View = new Int32Array(podArgIndices.length + 1); + var u32View = new Uint32Array(i32View.buffer); + var f32View = new Float32Array(i32View.buffer); + for (var i = 0; i < podArgIndices.length; ++i) { + var value = args[podArgIndices[i]]; + var dtype = finfo.arg_types[podArgIndices[i]]; + if (dtype.startsWith("int")) { + i32View[i] = value; + } + else if (dtype.startsWith("uint")) { + u32View[i] = value; + } + else if (dtype.startsWith("float")) { + f32View[i] = value; + } + else { + throw Error("Unknown pod dtype " + dtype); + } + } + // always pass in dim z launching grid size in + u32View[podArgIndices.length] = packDimX; + _this.device.queue.writeBuffer(podArgBuffer, 0, i32View.buffer); + bindGroupEntries.push({ + binding: bufferArgIndices.length, + resource: { + buffer: podArgBuffer, + size: i32View.buffer.byteLength + } + }); + compute.setBindGroup(0, _this.device.createBindGroup({ + layout: bindGroupLayout, + entries: bindGroupEntries + })); + compute.dispatchWorkgroups(workDim[0], workDim[1], workDim[2]); + compute.end(); + var command = commandEncoder.finish(); + _this.device.queue.submit([command]); + if (_this.debugLogFinish) { + _this.shaderSubmitCounter; + _this.device.queue.onSubmittedWorkDone().then(function () { + // console.log("[" + currCounter + "][Debug] finish shader" + finfo.name); + }); + } + _this.shaderSubmitCounter += 1; + }; + return submitShader; + }; + var shaderModule = this.device.createShaderModule({ + code: code, + hints: { + main: { + layout: pipelineLayout + } + } + }); + if (asyncMode) { + return this.device.createComputePipelineAsync({ + layout: pipelineLayout, + compute: { + module: shaderModule, + entryPoint: finfo.name + } + }).then(function (pipeline) { + return createShaderFunc(pipeline); + }); + } + else { + var pipeline = this.device.createComputePipeline({ + layout: pipelineLayout, + compute: { + module: shaderModule, + entryPoint: finfo.name + } + }); + return createShaderFunc(pipeline); + } + }; + /** + * Get the device API according to its name + * @param The name of the API. + * @returns The corresponding device api. + */ + WebGPUContext.prototype.getDeviceAPI = function (name) { + var _this = this; + if (name == "deviceAllocDataSpace") { + return function (nbytes) { + return _this.deviceAllocDataSpace(nbytes); + }; + } + else if (name == "deviceFreeDataSpace") { + return function (ptr) { + return _this.deviceFreeDataSpace(ptr); + }; + } + else if (name == "deviceCopyToGPU") { + return function (from, to, toOffset, nbytes) { + _this.deviceCopyToGPU(from, to, toOffset, nbytes); + }; + } + else if (name == "deviceCopyFromGPU") { + return function (from, fromOffset, to, nbytes) { + _this.deviceCopyFromGPU(from, fromOffset, to, nbytes); + }; + } + else if (name == "deviceCopyWithinGPU") { + return function (from, fromOffset, to, toOffset, nbytes) { + _this.deviceCopyWithinGPU(from, fromOffset, to, toOffset, nbytes); + }; + } + else { + throw new Error("Unknown DeviceAPI function " + name); + } + }; + // DeviceAPI + WebGPUContext.prototype.deviceAllocDataSpace = function (nbytes) { + // allocate 0 bytes buffer as 1 bytes buffer. + if (nbytes == 0) { + nbytes = 1; + } + var buffer = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + this.currAllocatedBytes += nbytes; + this.allAllocatedBytes += nbytes; + if (this.currAllocatedBytes > this.peakAllocatedBytes) { + this.peakAllocatedBytes = this.currAllocatedBytes; + } + var ptr = this.attachToBufferTable(buffer); + return ptr; + }; + WebGPUContext.prototype.deviceFreeDataSpace = function (ptr) { + var idx = ptr; + var buffer = this.bufferTable[idx]; + this.bufferTable[idx] = undefined; + assert(buffer !== undefined); + this.bufferTableFreeId.push(idx); + this.currAllocatedBytes -= buffer.size; + buffer.destroy(); + }; + WebGPUContext.prototype.deviceCopyToGPU = function (from, to, toOffset, nbytes) { + // Perhaps it would be more useful to use a staging buffer? + var rawBytes = this.memory.loadRawBytes(from, nbytes); + this.device.queue.writeBuffer(this.gpuBufferFromPtr(to), toOffset, rawBytes, 0, nbytes); + }; + WebGPUContext.prototype.deviceCopyFromGPU = function (from, fromOffset, to, nbytes) { + var _this = this; + // Perhaps it would be more useful to resuse a staging buffer? + var gpuTemp = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + var copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer(this.gpuBufferFromPtr(from), fromOffset, gpuTemp, 0, nbytes); + var copyCommands = copyEncoder.finish(); + this.device.queue.submit([copyCommands]); + gpuTemp.mapAsync(GPUMapMode.READ).then(function () { + var data = gpuTemp.getMappedRange(); + _this.memory.storeRawBytes(to, new Uint8Array(data)); + gpuTemp.destroy(); + }); + }; + WebGPUContext.prototype.deviceCopyWithinGPU = function (from, fromOffset, to, toOffset, nbytes) { + var copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer(this.gpuBufferFromPtr(from), fromOffset, this.gpuBufferFromPtr(to), toOffset, nbytes); + var copyCommands = copyEncoder.finish(); + this.device.queue.submit([copyCommands]); + }; + WebGPUContext.prototype.gpuBufferFromPtr = function (ptr) { + var buffer = this.bufferTable[ptr]; + assert(buffer !== undefined); + return buffer; + }; + WebGPUContext.prototype.attachToBufferTable = function (buffer) { + if (this.bufferTableFreeId.length != 0) { + var idx = this.bufferTableFreeId.pop(); + this.bufferTable[idx] = buffer; + return idx; + } + else { + var idx = this.bufferTable.length; + this.bufferTable.push(buffer); + return idx; + } + }; + return WebGPUContext; +}()); + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +var FFILibrary = /** @class */ (function () { + function FFILibrary(wasmInstance, imports) { + this.recycledCallStacks = []; + this.wasmInstance = wasmInstance; + this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); + assert(this.wasmInstance.exports !== undefined, "Expect the library module contains exports"); + this.exports = this.wasmInstance.exports; + this.wasm32 = this.memory.wasm32; + this.validateInstance(); + } + FFILibrary.prototype.dispose = function () { + var _a; + while (this.recycledCallStacks.length != 0) { + this.recycledCallStacks.pop().dispose(); + } + (_a = this.webGPUContext) === null || _a === void 0 ? void 0 : _a.dispose(); + }; + FFILibrary.prototype.sizeofPtr = function () { + return this.memory.sizeofPtr(); + }; + FFILibrary.prototype.checkCall = function (code) { + if (code != 0) { + var msgPtr = this.exports + .TVMGetLastError(); + throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + } + }; + FFILibrary.prototype.getOrAllocCallStack = function () { + if (this.recycledCallStacks.length != 0) { + return this.recycledCallStacks.pop(); + } + return new CachedCallStack(this.memory, this.exports.TVMWasmAllocSpace, this.exports.TVMWasmFreeSpace); + }; + FFILibrary.prototype.recycleCallStack = function (callstack) { + callstack.reset(); + this.recycledCallStacks.push(callstack); + }; + FFILibrary.prototype.validateInstance = function () { + this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]); + }; + FFILibrary.prototype.checkExports = function (funcNames) { + var missList = []; + for (var _i = 0, funcNames_1 = funcNames; _i < funcNames_1.length; _i++) { + var name_1 = funcNames_1[_i]; + var f = this.exports[name_1]; + if (!(f instanceof Function)) { + missList.push(name_1); + } + } + if (missList.length != 0) { + throw new Error("Cannot find " + missList + " in exports"); + } + }; + FFILibrary.prototype.detectWasmMemory = function (instance, imports) { + if (instance.exports.memory instanceof WebAssembly.Memory) { + return instance.exports.memory; + } + if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { + return imports.env.memory; + } + throw new Error("Cannt detect wasm memory from imports " + + imports + + " or exports" + + instance.exports); + }; + return FFILibrary; +}()); +/** + * @internal + * Manages extra runtime context for the runtime. + */ +var RuntimeContext = /** @class */ (function () { + function RuntimeContext(getGlobalFunc) { + this.autoDisposeScope = []; + this.arrayGetItem = getGlobalFunc("runtime.ArrayGetItem"); + this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); + this.arrayMake = getGlobalFunc("runtime.Array"); + this.getSysLib = getGlobalFunc("runtime.SystemLib"); + this.arrayCacheGet = getGlobalFunc("vm.builtin.ndarray_cache.get"); + this.arrayCacheRemove = getGlobalFunc("vm.builtin.ndarray_cache.remove"); + this.arrayCacheUpdate = getGlobalFunc("vm.builtin.ndarray_cache.update"); + this.arrayCacheClear = getGlobalFunc("vm.builtin.ndarray_cache.clear"); + this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); + this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); + this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple"); + this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); + this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); + } + RuntimeContext.prototype.dispose = function () { + // call array cache clear to clear all cached items + this.arrayCacheClear.dispose(); + this.arrayGetItem.dispose(); + this.arrayGetSize.dispose(); + this.arrayMake.dispose(); + this.arrayCacheGet.dispose(); + this.arrayCacheRemove.dispose(); + this.arrayCacheUpdate.dispose(); + this.arrayCacheClear.dispose(); + this.arrayDecodeStorage.dispose(); + this.paramModuleFromCache.dispose(); + this.makeShapeTuple.dispose(); + this.ndarrayCreateView.dispose(); + this.sampleTopPFromLogits.dispose(); + }; + RuntimeContext.prototype.beginScope = function () { + this.autoDisposeScope.push([]); + }; + RuntimeContext.prototype.endScope = function () { + if (this.autoDisposeScope.length == 0) { + throw Error("tvm.endScope called when the stack is empty."); + } + // automatically dispose all the tracked values in the current scope. + var currScope = this.autoDisposeScope.pop(); + for (var i = 0; i < currScope.length; ++i) { + var val = currScope[i]; + if (val !== undefined) { + val.dispose(); + } + } + }; + /** + * Track object for dispose in current scope. + * + * @param obj The object to be tracked. + * @returns the same object. + * @note This function only needs to be called for raw system C API values. + * The return value of PackedFunc will be automatically tracked. + */ + RuntimeContext.prototype.attachToCurrentScope = function (obj) { + if (this.autoDisposeScope.length == 0) { + throw Error("Must call beginScope to use functions that returns TVM objects"); + } + var currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; + currScope.push(obj); + return obj; + }; + RuntimeContext.prototype.moveToParentScope = function (obj) { + this.detachFromCurrentScope(obj); + if (this.autoDisposeScope.length < 2) { + throw Error("moveToParentScope: Parent scope do not exist"); + } + var parentScope = this.autoDisposeScope[this.autoDisposeScope.length - 2]; + parentScope.push(obj); + return obj; + }; + RuntimeContext.prototype.detachFromCurrentScope = function (obj) { + var currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; + var occurance = 0; + for (var i = 0; i < currScope.length; ++i) { + if (currScope[i] === obj) { + occurance += 1; + currScope[i] = undefined; + } + } + if (occurance == 0) { + throw Error("Cannot find obj in the current auto conversion pool"); + } + if (occurance > 1) { + throw Error("Value attached to scope multiple times"); + } + return obj; + }; + return RuntimeContext; +}()); +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +var Scalar = /** @class */ (function () { + function Scalar(value, dtype) { + this.value = value; + this.dtype = dtype; + } + return Scalar; +}()); +/** + * Cell holds the PackedFunc object. + */ +var PackedFuncCell = /** @class */ (function () { + function PackedFuncCell(handle, lib) { + this.handle = handle; + this.lib = lib; + } + PackedFuncCell.prototype.dispose = function () { + if (this.handle != 0) { + this.lib.checkCall(this.lib.exports.TVMFuncFree(this.handle)); + this.handle = 0; + } + }; + PackedFuncCell.prototype.getHandle = function (requireNotNull) { + if (requireNotNull === void 0) { requireNotNull = true; } + if (requireNotNull && this.handle == 0) { + throw Error("PackedFunc has already been disposed"); + } + return this.handle; + }; + return PackedFuncCell; +}()); +var DeviceEnumToStr = { + 1: "cpu", + 2: "cuda", + 4: "opencl", + 8: "metal", + 15: "webgpu" +}; +var DeviceStrToEnum = { + cpu: 1, + cuda: 2, + cl: 4, + opencl: 4, + vulkan: 7, + metal: 8, + webgpu: 15 +}; +/** + * Represent a runtime context where a NDArray can reside. + */ +var DLDevice = /** @class */ (function () { + function DLDevice(deviceType, deviceId, lib) { + var tp = typeof deviceType; + if (tp == "string") { + this.deviceType = DeviceStrToEnum[deviceType]; + if (this.deviceType == undefined) { + throw new Error("Cannot recogonize deviceType " + deviceType); + } + } + else if (tp == "number") { + this.deviceType = deviceType; + } + else { + throw new Error("Cannot take type " + tp + " as deviceType"); + } + this.deviceId = deviceId; + this.lib = lib; + } + /** + * Synchronize the device + */ + DLDevice.prototype.sync = function () { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (!(this.deviceType == DeviceStrToEnum.webgpu)) return [3 /*break*/, 2]; + assert(this.lib.webGPUContext !== undefined); + return [4 /*yield*/, this.lib.webGPUContext.sync()]; + case 1: + _a.sent(); + _a.label = 2; + case 2: return [2 /*return*/]; + } + }); + }); + }; + DLDevice.prototype.toString = function () { + return (DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")"); + }; + return DLDevice; +}()); +/** + * The data type code in DLDataType + */ +var DLDataTypeCode; +(function (DLDataTypeCode) { + DLDataTypeCode[DLDataTypeCode["Int"] = 0] = "Int"; + DLDataTypeCode[DLDataTypeCode["UInt"] = 1] = "UInt"; + DLDataTypeCode[DLDataTypeCode["Float"] = 2] = "Float"; + DLDataTypeCode[DLDataTypeCode["OpaqueHandle"] = 3] = "OpaqueHandle"; +})(DLDataTypeCode || (DLDataTypeCode = {})); +var DLDataTypeCodeToStr = { + 0: "int", + 1: "uint", + 2: "float", + 3: "handle", +}; +/** + * Runtime data type of NDArray. + */ +var DLDataType = /** @class */ (function () { + function DLDataType(code, bits, lanes) { + this.code = code; + this.bits = bits; + this.lanes = lanes; + } + DLDataType.prototype.toString = function () { + var ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); + if (this.lanes != 1) { + return ret + "x" + this.lanes.toString(); + } + else { + return ret; + } + }; + DLDataType.prototype.numStorageBytes = function () { + return (this.bits * this.lanes + 7) >> 3; + }; + return DLDataType; +}()); +/** + * n-dimnesional array. + */ +var NDArray = /** @class */ (function () { + function NDArray(handle, isView, lib, ctx) { + this.handle = handle; + this.isView = isView; + this.lib = lib; + this.ctx = ctx; + if (this.isView) { + this.dltensor = handle; + } + else { + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + } + // constant offsets. + var arrayOffsetData = 0; + var arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); + var arrayOffsetDevType = arrayOffsetContext; + var arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; + var arrayOffsetNdim = arrayOffsetContext + SizeOf.DLDevice; + var arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; + var arrayOffsetDtypeCode = arrayOffsetDtype; + var arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; + var arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; + var arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; + var arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); + var arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // dataPtr + this.dataPtr = lib.memory.loadPointer(this.dltensor); + // ndim + this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); + // shape + var cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); + this.shape = []; + for (var i = 0; i < this.ndim; ++i) { + this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); + } + // dtype + var code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); + var bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); + var lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); + this.dlDataType = new DLDataType(code, bits, lanes); + this.dtype = this.dlDataType.toString(); + // device + var deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); + var deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); + this.device = new DLDevice(deviceType, deviceId, lib); + // byte_offset + this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); + } + /** + * Create a view of the array. + * @param shape The shape of the view. + * @returns The new sliced ndarray. + */ + NDArray.prototype.view = function (shape) { + var _a; + var shapeArray = shape.map(function (value) { return new Scalar(value, "int"); }); + return this.ctx.ndarrayCreateView(this, (_a = this.ctx).makeShapeTuple.apply(_a, shapeArray)); + }; + /** + * Get handle of ndarray, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + NDArray.prototype.getHandle = function (requireNotNull) { + if (requireNotNull === void 0) { requireNotNull = true; } + if (requireNotNull && this.handle == 0) { + throw Error("NDArray has already been disposed"); + } + return this.handle; + }; + /** + * Get dataPtr of NDarray + * + * @returns The handle. + */ + NDArray.prototype.getDataPtr = function () { + if (this.handle == 0) { + throw Error("NDArray has already been disposed"); + } + return this.dataPtr; + }; + NDArray.prototype.dispose = function () { + if (this.handle != 0 && !this.isView) { + this.lib.checkCall(this.lib.exports.TVMArrayFree(this.handle)); + this.handle = 0; + } + }; + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + NDArray.prototype.copyFrom = function (data) { + if (data instanceof NDArray) { + this.lib.checkCall(this.lib.exports.TVMArrayCopyFromTo(data.getHandle(), this.getHandle(), 0)); + return this; + } + else { + var size = this.shape.reduce(function (a, b) { + return a * b; + }, 1); + if (data.length != size) { + throw new Error("data size and shape mismatch data.length" + + data.length + + " vs " + + size); + } + var buffer = void 0; + if (this.dtype == "float32") { + buffer = Float32Array.from(data).buffer; + } + else if (this.dtype == "float64") { + buffer = Float64Array.from(data).buffer; + } + else if (this.dtype == "int32") { + buffer = Int32Array.from(data).buffer; + } + else if (this.dtype == "int8") { + buffer = Int8Array.from(data).buffer; + } + else if (this.dtype == "uint8") { + buffer = Uint8Array.from(data).buffer; + } + else { + throw new Error("Unsupported data type " + this.dtype); + } + return this.copyFromRawBytes(new Uint8Array(buffer)); + } + }; + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. + * @returns this + */ + NDArray.prototype.copyFromRawBytes = function (data) { + var _a; + // short cut for gpu copy + if (this.device.deviceType == DeviceStrToEnum.webgpu) { + (_a = this.lib.webGPUContext) === null || _a === void 0 ? void 0 : _a.copyRawBytesToBuffer(data, this.getDataPtr(), 0, data.length); + return this; + } + // CPU copy + var size = this.shape.reduce(function (a, b) { + return a * b; + }, 1); + var nbytes = this.dlDataType.numStorageBytes() * size; + if (nbytes != data.length) { + throw new Error("Expect the data's length equals nbytes=" + nbytes); + } + var stack = this.lib.getOrAllocCallStack(); + var tempOffset = stack.allocRawBytes(nbytes); + var tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.memory.storeRawBytes(tempPtr, data); + this.lib.checkCall(this.lib.exports.TVMArrayCopyFromBytes(this.getHandle(), tempPtr, nbytes)); + this.lib.recycleCallStack(stack); + return this; + }; + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + NDArray.prototype.toRawBytes = function () { + if (this.device.deviceType != DeviceStrToEnum.cpu) { + throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); + } + var size = this.shape.reduce(function (a, b) { + return a * b; + }, 1); + var nbytes = this.dlDataType.numStorageBytes() * size; + var stack = this.lib.getOrAllocCallStack(); + var tempOffset = stack.allocRawBytes(nbytes); + var tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.checkCall(this.lib.exports.TVMArrayCopyToBytes(this.getHandle(), tempPtr, nbytes)); + var ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); + this.lib.recycleCallStack(stack); + return ret; + }; + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + NDArray.prototype.toArray = function () { + var stype = this.dtype; + if (stype == "float32") { + return new Float32Array(this.toRawBytes().buffer); + } + else if (stype == "float64") { + return new Float64Array(this.toRawBytes().buffer); + } + else if (stype == "int32") { + return new Int32Array(this.toRawBytes().buffer); + } + else if (stype == "int8") { + return new Int8Array(this.toRawBytes().buffer); + } + else if (stype == "uint8") { + return new Uint8Array(this.toRawBytes().buffer); + } + else { + throw new Error("Unsupported data type " + this.dtype); + } + }; + NDArray.prototype.getDLTensorFromArrayHandle = function (handle) { + // Note: this depends on the NDArray C ABI. + // keep this function in case of ABI change. + return handle; + }; + return NDArray; +}()); +/** + * Runtime Module. + */ +var Module = /** @class */ (function () { + function Module(handle, lib, makePackedFunc) { + this.handle = handle; + this.lib = lib; + this.makePackedFunc = makePackedFunc; + } + Module.prototype.dispose = function () { + if (this.handle != 0) { + this.lib.checkCall(this.lib.exports.TVMModFree(this.handle)); + this.handle = 0; + } + }; + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + Module.prototype.getHandle = function (requireNotNull) { + if (requireNotNull === void 0) { requireNotNull = true; } + if (requireNotNull && this.handle == 0) { + throw Error("Module has already been disposed"); + } + return this.handle; + }; + /** + * Get a function in the module. + * @param name The name of the function. + * @param queryImports Whether to also query imports + * @returns The result function. + */ + Module.prototype.getFunction = function (name, queryImports) { + if (queryImports === void 0) { queryImports = true; } + if (this.handle == 0) { + throw Error("Module has already been disposed"); + } + var stack = this.lib.getOrAllocCallStack(); + var nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + var outOffset = stack.allocPtrArray(1); + var outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + this.lib.checkCall(this.lib.exports.TVMModGetFunction(this.getHandle(), stack.ptrFromOffset(nameOffset), queryImports ? 1 : 0, outPtr)); + var handle = this.lib.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find function " + name); + } + var ret = this.makePackedFunc(handle); + return ret; + }; + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + Module.prototype.importModule = function (mod) { + this.lib.checkCall(this.lib.exports.TVMModImport(this.getHandle(), mod.getHandle())); + }; + return Module; +}()); +/** + * Generic object base + */ +var TVMObject = /** @class */ (function () { + function TVMObject(handle, lib, ctx) { + this.handle = handle; + this.lib = lib; + this.ctx = ctx; + } + TVMObject.prototype.dispose = function () { + if (this.handle != 0) { + this.lib.checkCall(this.lib.exports.TVMObjectFree(this.handle)); + this.handle = 0; + } + }; + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + TVMObject.prototype.getHandle = function (requireNotNull) { + if (requireNotNull === void 0) { requireNotNull = true; } + if (requireNotNull && this.handle == 0) { + throw Error("Module has already been disposed"); + } + return this.handle; + }; + /** get the type index of the object */ + TVMObject.prototype.typeIndex = function () { + if (this.handle == 0) { + throw Error("The current Object has already been disposed"); + } + var stack = this.lib.getOrAllocCallStack(); + var outOffset = stack.allocPtrArray(1); + var outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall(this.lib.exports.TVMObjectGetTypeIndex(this.getHandle(), outPtr)); + var result = this.lib.memory.loadU32(outPtr); + this.lib.recycleCallStack(stack); + return result; + }; + /** get the type key of the object */ + TVMObject.prototype.typeKey = function () { + var type_index = this.typeIndex(); + var stack = this.lib.getOrAllocCallStack(); + var outOffset = stack.allocPtrArray(1); + var outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall(this.lib.exports.TVMObjectTypeIndex2Key(type_index, outPtr)); + var result = this.lib.memory.loadCString(this.lib.memory.loadPointer(outPtr)); + this.lib.recycleCallStack(stack); + return result; + }; + return TVMObject; +}()); +/** Runtime array object. */ +var TVMArray = /** @class */ (function (_super) { + __extends(TVMArray, _super); + function TVMArray(handle, lib, ctx) { + return _super.call(this, handle, lib, ctx) || this; + } + /** + * @returns the size of the array. + */ + TVMArray.prototype.size = function () { + return this.ctx.arrayGetSize(this); + }; + /** + * Get index-th element of the array + * @param index the array index. + * @returns The element. + */ + TVMArray.prototype.get = function (index) { + return this.ctx.arrayGetItem(this, new Scalar(index, "int32")); + }; + return TVMArray; +}(TVMObject)); +var VMAllocatorKind; +(function (VMAllocatorKind) { + VMAllocatorKind[VMAllocatorKind["NAIVE_ALLOCATOR"] = 1] = "NAIVE_ALLOCATOR"; + VMAllocatorKind[VMAllocatorKind["POOLED_ALLOCATOR"] = 2] = "POOLED_ALLOCATOR"; +})(VMAllocatorKind || (VMAllocatorKind = {})); +/** + * VirtualMachine Executor. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +var VirtualMachine = /** @class */ (function () { + /** + * Constructor + * @param mod The underlying module, need to be detached. + * @param device The main device ro run VM on. + */ + function VirtualMachine(mod, device) { + this.mod = mod; + this.mod.getFunction("vm_initialization")(new Scalar(device.deviceType, "int"), new Scalar(device.deviceId, "int"), new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"), + // explicitly specify host device type + new Scalar(DeviceStrToEnum.cpu, "int"), new Scalar(0, "int"), new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int")); + } + VirtualMachine.prototype.dispose = function () { + this.mod.dispose(); + }; + /** + * Get a function in the VM module. + * @param name The name of the function. + * @returns The result function. + */ + VirtualMachine.prototype.getFunction = function (name) { + return this.mod.getFunction(name); + }; + /** + * Get the internal module. + */ + VirtualMachine.prototype.getInternalModule = function () { + return this.mod; + }; + return VirtualMachine; +}()); +/** Code used as the first argument of the async callback. */ +var AyncCallbackCode; +(function (AyncCallbackCode) { + AyncCallbackCode[AyncCallbackCode["kReturn"] = 4] = "kReturn"; + AyncCallbackCode[AyncCallbackCode["kException"] = 5] = "kException"; +})(AyncCallbackCode || (AyncCallbackCode = {})); +/** + * TVM runtime instance. + * + * All objects(NDArray, Module, PackedFunc) returned by TVM runtim function call + * and PackedFunc instance are tracked through a scope mechanism that will get + * auto-released when we call EndScope. + * + * This is necessarily to be able to release the underlying WASM and WebGPU memory that + * are not tracked through JS native garbage collection mechanism. + * + * This does mean that we have to get familar with the following functions: + * - {@link beginScope} + * - {@link endScope} + * - {@link withNewScope} + * - {@link attachToCurrentScope} + * - {@link detachFromCurrentScope} + */ +var Instance = /** @class */ (function () { + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + function Instance(wasmModule, importObject, wasmInstance, env) { + if (importObject === void 0) { importObject = {}; } + var _this = this; + this.cacheMetadata = {}; + this.initProgressCallback = []; + if (wasmInstance instanceof WebAssembly.Instance) { + assert(env instanceof Environment, "env must be provided when passing in instance"); + } + else { + assert(env === undefined); + env = new Environment(importObject); + wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); + } + env.start(wasmInstance); + this.env = env; + this.lib = new FFILibrary(wasmInstance, env.imports); + this.memory = this.lib.memory; + this.exports = this.lib.exports; + this.objFactory = new Map(); + this.ctx = new RuntimeContext(function (name) { + var autoAttachToScope = false; + // runtime context function do not auto-release. + return _this.getGlobalFuncInternal(name, autoAttachToScope); + }); + this.registerEnvGlobalPackedFuncs(); + this.registerObjectFactoryFuncs(); + } + /** + * Benchmark stable execution of the run function. + * + * @params run The run function + * @params dev The device to sync during each run. + * @number The number of times to compute the average. + * @repeat The number of times to repeat the run. + */ + Instance.prototype.benchmark = function (run, dev, number, repeat) { + if (number === void 0) { number = 10; } + if (repeat === void 0) { repeat = 1; } + return __awaiter(this, void 0, void 0, function () { + var perf, results, k, tstart, i, tend; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + perf = getPerformance(); + results = []; + // run with new scope + this.withNewScope(run); + return [4 /*yield*/, dev.sync()]; + case 1: + _a.sent(); + k = 0; + _a.label = 2; + case 2: + if (!(k < repeat)) return [3 /*break*/, 5]; + tstart = perf.now(); + for (i = 0; i < number; ++i) { + this.withNewScope(run); + } + return [4 /*yield*/, dev.sync()]; + case 3: + _a.sent(); + tend = perf.now(); + results.push((tend - tstart) / number); + _a.label = 4; + case 4: + ++k; + return [3 /*break*/, 2]; + case 5: return [2 /*return*/, results]; + } + }); + }); + }; + Instance.prototype.dispose = function () { + // order matters + // ctx release goes back into lib. + this.ctx.dispose(); + this.lib.dispose(); + }; + /** + * Obtain the runtime information in readable format. + */ + Instance.prototype.runtimeStatsText = function () { + if (this.lib.webGPUContext !== undefined) { + return this.lib.webGPUContext.runtimeStatsText(); + } + else { + return ""; + } + }; + /** + * Begin a new scope for tracking object disposal. + */ + Instance.prototype.beginScope = function () { + this.ctx.beginScope(); + }; + /** + * End a scope and release all created TVM objects + * under the current scope. + * + * Exception: one can call {@link moveToParentScope} to move + * a value to parent scope. + */ + Instance.prototype.endScope = function () { + this.ctx.endScope(); + }; + /** + * Perform action under a new scope. + * + * @param action The action function. + * @returns The result value. + * + * @note For action to return a valid value, + * we will need to call {@link moveToParentScope} + * for the objects that are created in the scope. + */ + Instance.prototype.withNewScope = function (action) { + this.beginScope(); + var val = action(); + this.endScope(); + return val; + }; + /** + * Attach a detached obj to the auto-release pool of the current scope. + * + * @param obj The input obj. + * @note Normally user do not need to call this function explicitly, as + * all library call return values are explicitly attached to + * the current scope. You only need to do so when you call + * {@link detachFromCurrentScope} to create a detached object. + */ + Instance.prototype.attachToCurrentScope = function (obj) { + return this.ctx.attachToCurrentScope(obj); + }; + /** + * Move obj's attachment to the parent scope. + * + * This function is useful to make sure objects are still + * alive when exit the current scope. + * + * @param obj The object to be moved. + * @returns The input obj. + */ + Instance.prototype.moveToParentScope = function (obj) { + return this.ctx.moveToParentScope(obj); + }; + /** + * Detach the object from the current scope + * so it won't be released via auto-release during endscope. + * + * User needs to either explicitly call obj.dispose(), or + * {@link attachToCurrentScope} to re-attach to the current scope. + * + * This function can be used to return values to the parent scope. + * @param obj The object. + */ + Instance.prototype.detachFromCurrentScope = function (obj) { + return this.ctx.detachFromCurrentScope(obj); + }; + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + Instance.prototype.systemLib = function () { + return this.ctx.getSysLib(); + }; + /** + * List all the global function names registered in the runtime. + * @returns The name list. + */ + Instance.prototype.listGlobalFuncNames = function () { + var stack = this.lib.getOrAllocCallStack(); + var outSizeOffset = stack.allocPtrArray(2); + var outSizePtr = stack.ptrFromOffset(outSizeOffset); + var outArrayPtr = stack.ptrFromOffset(outSizeOffset + this.lib.sizeofPtr()); + this.lib.checkCall(this.exports.TVMFuncListGlobalNames(outSizePtr, outArrayPtr)); + var size = this.memory.loadI32(outSizePtr); + var array = this.memory.loadPointer(outArrayPtr); + var names = []; + for (var i = 0; i < size; ++i) { + names.push(this.memory.loadCString(this.memory.loadPointer(array + this.lib.sizeofPtr() * i))); + } + this.lib.recycleCallStack(stack); + return names; + }; + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. + */ + Instance.prototype.registerFunc = function (name, func, override) { + var _this = this; + if (override === void 0) { override = false; } + this.withNewScope(function () { + var autoAttachToScope = true; + // packed func can be released once it is registered + var packedFunc = _this.toPackedFuncInternal(func, autoAttachToScope); + var ioverride = override ? 1 : 0; + var stack = _this.lib.getOrAllocCallStack(); + var nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + stack.commitToWasmMemory(); + _this.lib.checkCall(_this.lib.exports.TVMFuncRegisterGlobal(stack.ptrFromOffset(nameOffset), packedFunc._tvmPackedCell.getHandle(), ioverride)); + _this.lib.recycleCallStack(stack); + }); + }; + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @param autoAttachToScope Whether to track it via autoDispose + * @returns The result function. + */ + Instance.prototype.getGlobalFunc = function (name) { + return this.getGlobalFuncInternal(name, true); + }; + Instance.prototype.getGlobalFuncInternal = function (name, autoAttachToScope) { + if (autoAttachToScope === void 0) { autoAttachToScope = true; } + var stack = this.lib.getOrAllocCallStack(); + var nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + var outOffset = stack.allocPtrArray(1); + var outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + this.lib.checkCall(this.exports.TVMFuncGetGlobal(stack.ptrFromOffset(nameOffset), outPtr)); + var handle = this.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find global function " + name); + } + var ret = this.makePackedFunc(handle); + if (autoAttachToScope) + this.ctx.attachToCurrentScope(ret); + return ret; + }; + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + Instance.prototype.isPackedFunc = function (func) { + // eslint-disable-next-line no-prototype-builtins + return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell"); + }; + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + Instance.prototype.toPackedFunc = function (func) { + return this.toPackedFuncInternal(func, true); + }; + Instance.prototype.toPackedFuncInternal = function (func, autoAttachToScope) { + if (this.isPackedFunc(func)) + return func; + var ret = this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + if (autoAttachToScope) + return this.ctx.attachToCurrentScope(ret); + return ret; + }; + /** + * Setup a virtual machine module with given device. + * + * @param dev DLDevice the device. + * @returns The created virtual machime. + */ + Instance.prototype.createVirtualMachine = function (dev) { + var mod = this.ctx.detachFromCurrentScope(this.systemLib().getFunction("vm_load_executable")()); + return this.ctx.attachToCurrentScope(new VirtualMachine(mod, dev)); + }; + //----------------------------------------------- + // Native NDArray Cache Support + //----------------------------------------------- + /** + * Register a call back for fetch progress. + * + * @param cb the fetch progress callback. + */ + Instance.prototype.registerInitProgressCallback = function (cb) { + this.initProgressCallback.push(cb); + }; + /** + * Get parameters in the form of prefix_i + * + * @param prefix The parameter prefix. + * @param numParams Number of parameters. + * @returns + */ + Instance.prototype.getParamsFromCache = function (prefix, numParams) { + return this.ctx.paramModuleFromCache(prefix, new Scalar(numParams, "int32")).getFunction("get_params")(); + }; + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + Instance.prototype.ndarrayCacheGet = function (name) { + return this.ctx.arrayCacheGet(name); + }; + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + Instance.prototype.ndarrayCacheRemove = function (name) { + return this.ctx.arrayCacheRemove(name); + }; + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + Instance.prototype.ndarrayCacheUpdate = function (name, arr, override) { + if (override === void 0) { override = false; } + this.ctx.arrayCacheUpdate(name, arr, this.scalar(override ? 1 : 0, "int32")); + }; + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + Instance.prototype.ndarrayCacheClear = function () { + this.ctx.arrayCacheClear(); + }; + /** + * Fetch NDArray cache from url. + * + * @param ndarrayCacheUrl The cache url. + * @param device The device to be fetched to. + * @returns The meta data + */ + Instance.prototype.fetchNDArrayCache = function (ndarrayCacheUrl, device) { + return __awaiter(this, void 0, void 0, function () { + var jsonUrl, request, cache, result, list; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + request = new Request(jsonUrl); + return [4 /*yield*/, caches.open("tvmjs")]; + case 1: + cache = _a.sent(); + return [4 /*yield*/, cache.match(request)]; + case 2: + result = _a.sent(); + if (!(result === undefined)) return [3 /*break*/, 5]; + return [4 /*yield*/, cache.add(request)]; + case 3: + _a.sent(); + return [4 /*yield*/, cache.match(request)]; + case 4: + result = _a.sent(); + _a.label = 5; + case 5: + if (!(result === undefined)) return [3 /*break*/, 9]; + this.env.logger("Error: Cannot cache " + jsonUrl + ", reloading will be slow"); + _a.label = 6; + case 6: + _a.trys.push([6, 8, , 9]); + return [4 /*yield*/, fetch(request)]; + case 7: + result = _a.sent(); + return [3 /*break*/, 9]; + case 8: + _a.sent(); + this.env.logger("Cannot fetch " + jsonUrl); + return [3 /*break*/, 9]; + case 9: + if (!(result instanceof Response)) return [3 /*break*/, 11]; + return [4 /*yield*/, result.json()]; + case 10: + list = _a.sent(); + _a.label = 11; + case 11: return [4 /*yield*/, this.fetchNDArrayCacheInternal(ndarrayCacheUrl, list["records"], device)]; + case 12: + _a.sent(); + this.cacheMetadata = __assign(__assign({}, this.cacheMetadata), list["metadata"]); + return [2 /*return*/]; + } + }); + }); + }; + /** + * Fetch list of NDArray into the NDArrayCache. + * + * @param ndarrayCacheUrl The cache url. + * @param list The list of array data. + * @param device The device to store the data to. + */ + Instance.prototype.fetchNDArrayCacheInternal = function (ndarrayCacheUrl, list, device) { + return __awaiter(this, void 0, void 0, function () { + var perf, tstart, totalBytes, i, fetchedBytes, timeElapsed, reportCallback, j, cache, i, dataUrl, request, buffer, result, err_2, shardRecords, _loop_1, this_1, j; + var _this = this; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + perf = getPerformance(); + tstart = perf.now(); + totalBytes = 0; + for (i = 0; i < list.length; ++i) { + totalBytes += list[i].nbytes; + } + fetchedBytes = 0; + timeElapsed = 0; + reportCallback = function (iter) { + // report + for (var j = 0; j < _this.initProgressCallback.length; ++j) { + _this.initProgressCallback[j]({ + type: 'init', + progress: fetchedBytes / totalBytes, + timeElapsed: timeElapsed, + currentChunk: iter, + totalChunks: list.length, + fetchedBytes: fetchedBytes, + totalBytes: totalBytes, + }); + } + }; + for (j = 0; j < this.initProgressCallback.length; ++j) { + this.initProgressCallback[j]({ + type: 'init', + progress: fetchedBytes / totalBytes, + timeElapsed: 0, + currentChunk: 0, + totalChunks: list.length, + fetchedBytes: fetchedBytes, + totalBytes: totalBytes, + }); + } + return [4 /*yield*/, caches.open("tvmjs")]; + case 1: + cache = _a.sent(); + i = 0; + _a.label = 2; + case 2: + if (!(i < list.length)) return [3 /*break*/, 18]; + reportCallback(i); + fetchedBytes += list[i].nbytes; + dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href; + request = new Request(dataUrl); + buffer = void 0; + _a.label = 3; + case 3: + _a.trys.push([3, 11, , 12]); + return [4 /*yield*/, cache.match(request)]; + case 4: + result = _a.sent(); + if (!(result === undefined)) return [3 /*break*/, 7]; + return [4 /*yield*/, cache.add(request)]; + case 5: + _a.sent(); + return [4 /*yield*/, cache.match(request)]; + case 6: + result = _a.sent(); + _a.label = 7; + case 7: + if (!(result == undefined)) return [3 /*break*/, 9]; + this.env.logger("Error: Cannot cache " + dataUrl + ", reloading will be slow"); + return [4 /*yield*/, fetch(request)]; + case 8: + result = _a.sent(); + _a.label = 9; + case 9: return [4 /*yield*/, result.arrayBuffer()]; + case 10: + buffer = _a.sent(); + return [3 /*break*/, 12]; + case 11: + err_2 = _a.sent(); + this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err_2); + throw err_2; + case 12: + shardRecords = list[i].records; + _loop_1 = function (j) { + var rec, cpu_arr, recSource, gpu_arr; + return __generator(this, function (_b) { + switch (_b.label) { + case 0: + rec = shardRecords[j]; + cpu_arr = this_1.withNewScope(function () { + return _this.detachFromCurrentScope(_this.empty(rec.shape, rec.dtype, _this.cpu())); + }); + recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); + // first sync copy to cpu. + this_1.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format); + if (!(device.deviceType == DeviceStrToEnum.cpu)) return [3 /*break*/, 1]; + this_1.ndarrayCacheUpdate(rec.name, cpu_arr, false); + cpu_arr.dispose(); + return [3 /*break*/, 3]; + case 1: + gpu_arr = this_1.withNewScope(function () { + return _this.detachFromCurrentScope(_this.empty(rec.shape, rec.dtype, device)); + }); + gpu_arr.copyFrom(cpu_arr); + return [4 /*yield*/, device.sync()]; + case 2: + _b.sent(); + this_1.ndarrayCacheUpdate(rec.name, gpu_arr, false); + cpu_arr.dispose(); + gpu_arr.dispose(); + _b.label = 3; + case 3: return [2 /*return*/]; + } + }); + }; + this_1 = this; + j = 0; + _a.label = 13; + case 13: + if (!(j < shardRecords.length)) return [3 /*break*/, 16]; + return [5 /*yield**/, _loop_1(j)]; + case 14: + _a.sent(); + _a.label = 15; + case 15: + ++j; + return [3 /*break*/, 13]; + case 16: + timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + _a.label = 17; + case 17: + ++i; + return [3 /*break*/, 2]; + case 18: + reportCallback(list.length); + return [2 /*return*/]; + } + }); + }); + }; + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. + */ + Instance.prototype.toDLDataType = function (dtype) { + if (dtype instanceof DLDataType) + return dtype; + if (typeof dtype == "string") { + var pattern = dtype; + var code = void 0, bits = 32, lanes = 1; + if (pattern.substring(0, 5) == "float") { + pattern = pattern.substring(5, pattern.length); + code = DLDataTypeCode.Float; + } + else if (pattern.substring(0, 3) == "int") { + pattern = pattern.substring(3, pattern.length); + code = DLDataTypeCode.Int; + } + else if (pattern.substring(0, 4) == "uint") { + pattern = pattern.substring(4, pattern.length); + code = DLDataTypeCode.UInt; + } + else if (pattern.substring(0, 6) == "handle") { + pattern = pattern.substring(5, pattern.length); + code = DLDataTypeCode.OpaqueHandle; + bits = 64; + } + else { + throw new Error("Unknown dtype " + dtype); + } + var arr = pattern.split("x"); + if (arr.length >= 1) { + var parsed = parseInt(arr[0]); + if (parsed + "" == arr[0]) { + bits = parsed; + } + } + if (arr.length >= 2) { + lanes = parseInt(arr[1]); + } + return new DLDataType(code, bits, lanes); + } + else { + throw new Error("Unknown dtype " + dtype); + } + }; + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. + */ + Instance.prototype.scalar = function (value, dtype) { + return new Scalar(value, dtype); + }; + /** + * Create a new {@link DLDevice} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created device. + */ + Instance.prototype.device = function (deviceType, deviceId) { + if (deviceId === void 0) { deviceId = 0; } + return new DLDevice(deviceType, deviceId, this.lib); + }; + /** + * Create a new cpu {@link DLDevice} + * @param deviceId The device index. + */ + Instance.prototype.cpu = function (deviceId) { + if (deviceId === void 0) { deviceId = 0; } + return this.device("cpu", deviceId); + }; + /** + * Create a new webgpu {@link DLDevice} + * @param deviceId The device index. + */ + Instance.prototype.webgpu = function (deviceId) { + if (deviceId === void 0) { deviceId = 0; } + return this.device("webgpu", deviceId); + }; + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + Instance.prototype.empty = function (shape, dtype, dev) { + if (dtype === void 0) { dtype = "float32"; } + if (dev === void 0) { dev = this.device("cpu", 0); } + dtype = this.toDLDataType(dtype); + shape = typeof shape == "number" ? [shape] : shape; + var stack = this.lib.getOrAllocCallStack(); + var shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); + for (var i = 0; i < shape.length; ++i) { + stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); + } + var outOffset = stack.allocPtrArray(1); + var outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + this.lib.checkCall(this.exports.TVMArrayAlloc(stack.ptrFromOffset(shapeOffset), shape.length, dtype.code, dtype.bits, dtype.lanes, dev.deviceType, dev.deviceId, outPtr)); + var ret = this.ctx.attachToCurrentScope(new NDArray(this.memory.loadPointer(outPtr), false, this.lib, this.ctx)); + this.lib.recycleCallStack(stack); + return ret; + }; + /** + * Create am uniform {@link NDArray} with given shape. + * + * @param shape The shape of the array. + * @param low The low value. + * @param high The high value. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + Instance.prototype.uniform = function (shape, low, high, dev) { + var ret = this.empty(shape, "float32", dev); + var size = shape.reduce(function (a, b) { + return a * b; + }, 1); + var scale = high - low; + var input = new Float32Array(size); + for (var i = 0; i < input.length; ++i) { + input[i] = low + Math.random() * scale; + } + return ret.copyFrom(input); + }; + /** + * Sample index via top-p sampling. + * + * @param logits The input logits before normalization. + * @param temperature The temperature factor, will take argmax if temperature = 0.0 + * @param top_p The top_p + * @returns The sampled index. + */ + Instance.prototype.sampleTopPFromLogits = function (logits, temperature, top_p) { + return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random()); + }; + /** + * Bind canvas to the current WebGPU context + * @param canvas The canvas. + */ + Instance.prototype.bindCanvas = function (canvas) { + var _a; + (_a = this.lib.webGPUContext) === null || _a === void 0 ? void 0 : _a.bindCanvas(canvas); + }; + /** + * Show image in canvas. + * + * @param dataRGBA Image array in height x width uint32 NDArray RGBA format on GPU. + */ + Instance.prototype.showImage = function (dataRGBA) { + var _a; + if (dataRGBA.shape.length != 2) { + throw Error("Require a height x width uint32 NDArray in RGBA" + + "get shape=" + dataRGBA.shape.toString() + " instead."); + } + if (dataRGBA.device.deviceType != DeviceStrToEnum.webgpu) { + throw new Error("Can only run showImage on WebGPU array, " + + "get " + DeviceEnumToStr[dataRGBA.device.deviceType] + " instead."); + } + if (dataRGBA.dtype != "uint32") { + throw Error("Require a height x width uint32 NDArray in RGBA, " + + "get " + dataRGBA.dtype + " instead."); + } + (_a = this.lib.webGPUContext) === null || _a === void 0 ? void 0 : _a.drawImageFromBuffer(dataRGBA.getDataPtr(), dataRGBA.shape[0], dataRGBA.shape[1]); + }; + /** + * Clear canvas + */ + Instance.prototype.clearCanvas = function () { + var _a; + (_a = this.lib.webGPUContext) === null || _a === void 0 ? void 0 : _a.clearCanvas(); + }; + /** + * Create an tuple {@link TVMArray} input array. + * + * The input array can be passed to tvm runtime function + * and needs to b explicitly disposed. + * + * @param inputs The input array + * @returns The result array. + */ + Instance.prototype.makeTVMArray = function (inputs) { + var _a; + return (_a = this.ctx).arrayMake.apply(_a, inputs); + }; + /** + * Create a shape tuple to pass to runtime. + * @param shape The shape . + * @returns The created shape tuple. + */ + Instance.prototype.makeShapeTuple = function (shape) { + var _a; + var shapeArray = shape.map(function (value) { return new Scalar(value, "int"); }); + return (_a = this.ctx).makeShapeTuple.apply(_a, shapeArray); + }; + /** + * Get type index from type key. + * @param typeKey The type key. + * @returns The corresponding type index. + */ + Instance.prototype.typeKey2Index = function (typeKey) { + var stack = this.lib.getOrAllocCallStack(); + var typeKeyOffset = stack.allocRawBytes(typeKey.length + 1); + stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey)); + var outOffset = stack.allocPtrArray(1); + var outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + this.lib.checkCall(this.lib.exports.TVMObjectTypeKey2Index(stack.ptrFromOffset(typeKeyOffset), outPtr)); + var typeIndex = this.memory.loadU32(outPtr); + this.lib.recycleCallStack(stack); + return typeIndex; + }; + /** + * Register an object constructor. + * @param typeKey The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + */ + Instance.prototype.registerObjectConstructor = function (typeKey, func, override) { + if (override === void 0) { override = false; } + var typeIndex = this.typeKey2Index(typeKey); + if (this.objFactory.has(typeIndex)) { + if (!override) { + throw new Error("Type " + typeKey + " already registered"); + } + } + this.objFactory.set(typeIndex, func); + }; + /** + * Register an asyncfunction to be global function in the server. + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note The async function will only be used for serving remote calls in the rpc. + */ + Instance.prototype.registerAsyncServerFunc = function (name, func, override) { + var _this = this; + if (override === void 0) { override = false; } + var asyncVariant = function () { + var args = []; + for (var _i = 0; _i < arguments.length; _i++) { + args[_i] = arguments[_i]; + } + var fargs = args.slice(0, args.length - 1); + // need to keep it alive until callback is fulfilled. + var callback = _this.detachFromCurrentScope(args[args.length - 1]); + var promise = func.apply(void 0, fargs); + promise.then(function (rv) { + callback(_this.scalar(AyncCallbackCode.kReturn, "int32"), rv); + callback.dispose(); + }); + }; + this.registerFunc("__async." + name, asyncVariant, override); + }; + /** + * Asynchrously load webgpu pipelines when possible. + * @param mod The input module. + */ + Instance.prototype.asyncLoadWebGPUPiplines = function (mod) { + return __awaiter(this, void 0, void 0, function () { + var webgpuContext, fmap_str, fmap, fGetShader, fUpdatePrebuild, perf, tstart, tlastReport, finishCounter, fmapEntries, allEvents, _loop_2, _i, fmapEntries_1, _a, key, finfo; + var _this = this; + return __generator(this, function (_b) { + switch (_b.label) { + case 0: + if (this.lib.webGPUContext == undefined) + throw Error("WebGPU not initialied"); + webgpuContext = this.lib.webGPUContext; + this.beginScope(); + fmap_str = mod.getFunction("webgpu.get_fmap", true)(); + fmap = JSON.parse(fmap_str); + fmap.length; + fGetShader = this.detachFromCurrentScope(mod.getFunction("webgpu.get_shader")); + fUpdatePrebuild = this.detachFromCurrentScope(mod.getFunction("webgpu.update_prebuild")); + this.endScope(); + perf = getPerformance(); + tstart = perf.now(); + tlastReport = tstart; + finishCounter = 0; + fmapEntries = Object.entries(fmap); + allEvents = Promise.resolve(); + _loop_2 = function (key, finfo) { + var code = fGetShader(key); + assert(key == finfo.name); + var event_1 = webgpuContext.createShaderAsync(finfo, code).then(function (func) { + _this.beginScope(); + fUpdatePrebuild(key, func); + _this.endScope(); + }).then(function () { + finishCounter += 1; + var tend = perf.now(); + // skip report if gap is smaller than 1000 + if ((tend - tlastReport) < 1000 && finishCounter != fmapEntries.length) { + return; + } + tlastReport = tend; + var timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + // report + for (var j = 0; j < _this.initProgressCallback.length; ++j) { + var progress = finishCounter / fmapEntries.length; + var text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length + "]: "; + text += Math.floor(progress * 100).toString() + "% completed, "; + text += timeElapsed + " secs elapsed."; + // this.initProgressCallback[j]({ + // progress: progress, + // timeElapsed: timeElapsed, + // text: text + // }); + } + }); + allEvents = Promise.all([allEvents, event_1]).then(function () { }); + }; + for (_i = 0, fmapEntries_1 = fmapEntries; _i < fmapEntries_1.length; _i++) { + _a = fmapEntries_1[_i], key = _a[0], finfo = _a[1]; + _loop_2(key, finfo); + } + return [4 /*yield*/, allEvents]; + case 1: + _b.sent(); + assert(finishCounter == fmapEntries.length); + return [2 /*return*/]; + } + }); + }); + }; + /** + * Initialize webgpu in the runtime. + * @param device The given GPU device. + */ + Instance.prototype.initWebGPU = function (device) { + var _this = this; + var webGPUContext = new WebGPUContext(this.memory, device); + this.registerFunc("wasm.WebGPUDeviceAPI", function (name) { + return webGPUContext.getDeviceAPI(name); + }); + this.registerFunc("wasm.WebGPUCreateShader", function (info, code) { + var finfo = JSON.parse(info); + return webGPUContext.createShader(finfo, code); + }); + this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", function () { return __awaiter(_this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, webGPUContext.sync()]; + case 1: + _a.sent(); + return [2 /*return*/]; + } + }); + }); }); + this.lib.webGPUContext = webGPUContext; + }; + /** Register all object factory */ + Instance.prototype.registerObjectFactoryFuncs = function () { + this.registerObjectConstructor("Array", function (handle, lib, ctx) { + return new TVMArray(handle, lib, ctx); + }); + }; + /** Register global packed functions needed by the backend to the env. */ + Instance.prototype.registerEnvGlobalPackedFuncs = function () { + var _this = this; + // Register the timer function to enable the time_evaluator. + var perf = getPerformance(); + // Helper function to time the finvoke + var timeExecution = function (finvoke, dev, nstep, repeat, minRepeatMs, limitZeroTimeIterations, cooldownIntervalMs, repeatsToCooldown) { return __awaiter(_this, void 0, void 0, function () { + var result, setupNumber, i, durationMs, absoluteZeroTimes, golden_ratio, tstart, tend, speed, ret; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + // detach and explicit dispose when tasks is fullfilled + // the promise will immediately return and we need to makesure + // finvoke do not get recycled. + this.ctx.detachFromCurrentScope(finvoke); + finvoke(this.scalar(1, "int32")); + return [4 /*yield*/, dev.sync()]; + case 1: + _a.sent(); + result = []; + setupNumber = nstep; + i = 0; + _a.label = 2; + case 2: + if (!(i < repeat)) return [3 /*break*/, 9]; + durationMs = 0.0; + absoluteZeroTimes = 0; + _a.label = 3; + case 3: + if (durationMs > 0.0) { + golden_ratio = 1.618; + setupNumber = Math.floor(Math.max(minRepeatMs / (durationMs / setupNumber) + 1, setupNumber * golden_ratio)); + } + tstart = perf.now(); + finvoke(this.scalar(setupNumber, "int32")); + return [4 /*yield*/, dev.sync()]; + case 4: + _a.sent(); + tend = perf.now(); + durationMs = tend - tstart; + if (durationMs == 0) { + absoluteZeroTimes++; + } + _a.label = 5; + case 5: + if (durationMs < minRepeatMs && absoluteZeroTimes < limitZeroTimeIterations) return [3 /*break*/, 3]; + _a.label = 6; + case 6: + speed = durationMs / setupNumber / 1000; + result.push(speed); + if (!(cooldownIntervalMs > 0.0 && (i % repeatsToCooldown) == 0)) return [3 /*break*/, 8]; + return [4 /*yield*/, new Promise(function (r) { return setTimeout(r, cooldownIntervalMs); })]; + case 7: + _a.sent(); + _a.label = 8; + case 8: + ++i; + return [3 /*break*/, 2]; + case 9: + ret = new Float64Array(result.length); + ret.set(result); + // dispose finvoke + finvoke.dispose(); + return [2 /*return*/, new Uint8Array(ret.buffer)]; + } + }); + }); }; + var addOne = function (x) { return __awaiter(_this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, new Promise(function (resolve) { return setTimeout(resolve, 100); })]; + case 1: + _a.sent(); + return [2 /*return*/, x + 1]; + } + }); + }); }; + this.registerAsyncServerFunc("wasm.TimeExecution", timeExecution); + this.registerAsyncServerFunc("testing.asyncAddOne", addOne); + }; + Instance.prototype.createPackedFuncFromCFunc = function (func) { + var findex = this.env.packedCFuncTable.length; + if (this.env.packedCFuncTableFreeId.length != 0) { + findex = this.env.packedCFuncTableFreeId.pop(); + } + else { + this.env.packedCFuncTable.push(undefined); + } + this.env.packedCFuncTable[findex] = func; + var stack = this.lib.getOrAllocCallStack(); + var outOffset = stack.allocPtrArray(1); + var outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall(this.exports + .TVMWasmFuncCreateFromCFunc(findex, outPtr)); + var ret = this.makePackedFunc(this.memory.loadPointer(outPtr)); + this.lib.recycleCallStack(stack); + return ret; + }; + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. + */ + Instance.prototype.setPackedArguments = function (stack, args, argsValue, argsCode) { + for (var i = 0; i < args.length; ++i) { + var val = args[i]; + var tp = typeof val; + var valueOffset = argsValue + i * SizeOf.TVMValue; + var codeOffset = argsCode + i * SizeOf.I32; + if (val instanceof NDArray) { + if (!val.isView) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); + } + else { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMDLTensorHandle); + } + } + else if (val instanceof Scalar) { + if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { + stack.storeI64(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.Int); + } + else if (val.dtype.startsWith("float")) { + stack.storeF64(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.Float); + } + else { + assert(val.dtype == "handle", "Expect handle"); + stack.storePtr(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle); + } + } + else if (val instanceof DLDevice) { + stack.storeI32(valueOffset, val.deviceType); + stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); + stack.storeI32(codeOffset, ArgTypeCode.DLDevice); + } + else if (tp == "number") { + stack.storeF64(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.Float); + // eslint-disable-next-line no-prototype-builtins + } + else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { + stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + } + else if (val === null || val == undefined) { + stack.storePtr(valueOffset, 0); + stack.storeI32(codeOffset, ArgTypeCode.Null); + } + else if (tp == "string") { + stack.allocThenSetArgString(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.TVMStr); + } + else if (val instanceof Uint8Array) { + stack.allocThenSetArgBytes(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); + } + else if (val instanceof Function) { + val = this.toPackedFuncInternal(val, false); + stack.tempArgs.push(val); + stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + } + else if (val instanceof Module) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); + } + else if (val instanceof TVMObject) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMObjectHandle); + } + else { + throw new Error("Unsupported argument type " + tp); + } + } + }; + Instance.prototype.wrapJSFuncAsPackedCFunc = function (func) { + var _this = this; + var lib = this.lib; + return function (argValues, argCodes, nargs, ret, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _handle) { + var jsArgs = []; + // use scope to track js values. + _this.ctx.beginScope(); + for (var i = 0; i < nargs; ++i) { + var valuePtr = argValues + i * SizeOf.TVMValue; + var codePtr = argCodes + i * SizeOf.I32; + var tcode = lib.memory.loadI32(codePtr); + if (tcode == ArgTypeCode.TVMObjectHandle || + tcode == ArgTypeCode.TVMObjectRValueRefArg || + tcode == ArgTypeCode.TVMPackedFuncHandle || + tcode == ArgTypeCode.TVMNDArrayHandle || + tcode == ArgTypeCode.TVMModuleHandle) { + lib.checkCall(lib.exports.TVMCbArgToReturn(valuePtr, codePtr)); + } + tcode = lib.memory.loadI32(codePtr); + jsArgs.push(_this.retValueToJS(valuePtr, tcode, true)); + } + var rv = func.apply(void 0, jsArgs); + // recycle all js object value in function unless we want to retain them. + _this.ctx.endScope(); + if (rv !== undefined && rv !== null) { + var stack = lib.getOrAllocCallStack(); + var valueOffset = stack.allocRawBytes(SizeOf.TVMValue); + var codeOffset = stack.allocRawBytes(SizeOf.I32); + _this.setPackedArguments(stack, [rv], valueOffset, codeOffset); + var valuePtr = stack.ptrFromOffset(valueOffset); + var codePtr = stack.ptrFromOffset(codeOffset); + stack.commitToWasmMemory(); + lib.checkCall(lib.exports.TVMCFuncSetReturn(ret, valuePtr, codePtr, 1)); + lib.recycleCallStack(stack); + } + return 0; + }; + }; + Instance.prototype.makePackedFunc = function (handle) { + var _this = this; + var cell = new PackedFuncCell(handle, this.lib); + var packedFunc = function () { + var args = []; + for (var _i = 0; _i < arguments.length; _i++) { + args[_i] = arguments[_i]; + } + var stack = _this.lib.getOrAllocCallStack(); + var valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); + var tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); + _this.setPackedArguments(stack, args, valueOffset, tcodeOffset); + var rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); + var rcodeOffset = stack.allocRawBytes(SizeOf.I32); + var rvaluePtr = stack.ptrFromOffset(rvalueOffset); + var rcodePtr = stack.ptrFromOffset(rcodeOffset); + // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) + stack.commitToWasmMemory(rvalueOffset); + _this.lib.checkCall(_this.exports.TVMFuncCall(cell.getHandle(), stack.ptrFromOffset(valueOffset), stack.ptrFromOffset(tcodeOffset), args.length, rvaluePtr, rcodePtr)); + var ret = _this.retValueToJS(rvaluePtr, _this.memory.loadI32(rcodePtr), false); + _this.lib.recycleCallStack(stack); + return ret; + }; + // Attach attributes to the function type. + // This is because javascript do not allow us to overload call. + var ret = packedFunc; + ret.dispose = function () { + cell.dispose(); + }; + ret._tvmPackedCell = cell; + return ret; + }; + /** + * Creaye return value of the packed func. The value us auto-tracked for dispose. + * @param rvaluePtr The location of rvalue + * @param tcode The type code. + * @param callbackArg Whether it is being used in callbackArg. + * @returns The JS value. + */ + Instance.prototype.retValueToJS = function (rvaluePtr, tcode, callbackArg) { + var _this = this; + switch (tcode) { + case ArgTypeCode.Int: + case ArgTypeCode.UInt: + return this.memory.loadI64(rvaluePtr); + case ArgTypeCode.Float: + return this.memory.loadF64(rvaluePtr); + case ArgTypeCode.TVMOpaqueHandle: { + return this.memory.loadPointer(rvaluePtr); + } + case ArgTypeCode.TVMNDArrayHandle: { + return this.ctx.attachToCurrentScope(new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, this.ctx)); + } + case ArgTypeCode.TVMDLTensorHandle: { + assert(callbackArg); + // no need to attach as we are only looking at view + return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib, this.ctx); + } + case ArgTypeCode.TVMPackedFuncHandle: { + return this.ctx.attachToCurrentScope(this.makePackedFunc(this.memory.loadPointer(rvaluePtr))); + } + case ArgTypeCode.TVMModuleHandle: { + return this.ctx.attachToCurrentScope(new Module(this.memory.loadPointer(rvaluePtr), this.lib, function (ptr) { + return _this.ctx.attachToCurrentScope(_this.makePackedFunc(ptr)); + })); + } + case ArgTypeCode.TVMObjectHandle: { + var obj = new TVMObject(this.memory.loadPointer(rvaluePtr), this.lib, this.ctx); + var func = this.objFactory.get(obj.typeIndex()); + if (func != undefined) { + return this.ctx.attachToCurrentScope(func(obj.getHandle(), this.lib, this.ctx)); + } + else { + return this.ctx.attachToCurrentScope(obj); + } + } + case ArgTypeCode.Null: return undefined; + case ArgTypeCode.DLDevice: { + var deviceType = this.memory.loadI32(rvaluePtr); + var deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); + return this.device(deviceType, deviceId); + } + case ArgTypeCode.TVMStr: { + var ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + return ret; + } + case ArgTypeCode.TVMBytes: { + return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); + } + default: + throw new Error("Unsupported return type code=" + tcode); + } + }; + return Instance; +}()); +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + * + * @param bufferSource The source to be compiled. + * @param importObject The import objects. + * @param logger The system logger. + */ +function instantiate(bufferSource, importObject, logger) { + if (importObject === void 0) { importObject = {}; } + if (logger === void 0) { logger = console.log; } + var env = new Environment(importObject, logger); + return WebAssembly.instantiate(bufferSource, env.imports).then(function (result) { + return new Instance(result.module, {}, result.instance, env); + }); +} + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +var RPCServerState; +(function (RPCServerState) { + RPCServerState[RPCServerState["InitHeader"] = 0] = "InitHeader"; + RPCServerState[RPCServerState["InitHeaderKey"] = 1] = "InitHeaderKey"; + RPCServerState[RPCServerState["InitServer"] = 2] = "InitServer"; + RPCServerState[RPCServerState["WaitForCallback"] = 3] = "WaitForCallback"; + RPCServerState[RPCServerState["ReceivePacketHeader"] = 4] = "ReceivePacketHeader"; + RPCServerState[RPCServerState["ReceivePacketBody"] = 5] = "ReceivePacketBody"; +})(RPCServerState || (RPCServerState = {})); + +/** + * @license + * Copyright 2019 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ +const proxyMarker = Symbol("Comlink.proxy"); +const createEndpoint = Symbol("Comlink.endpoint"); +const releaseProxy = Symbol("Comlink.releaseProxy"); +const finalizer = Symbol("Comlink.finalizer"); +const throwMarker = Symbol("Comlink.thrown"); +const isObject = val => typeof val === "object" && val !== null || typeof val === "function"; +/** + * Internal transfer handle to handle objects marked to proxy. + */ +const proxyTransferHandler = { + canHandle: val => isObject(val) && val[proxyMarker], + serialize(obj) { + const { + port1, + port2 + } = new MessageChannel(); + expose(obj, port1); + return [port2, [port2]]; + }, + deserialize(port) { + port.start(); + return wrap(port); + } +}; +/** + * Internal transfer handler to handle thrown exceptions. + */ +const throwTransferHandler = { + canHandle: value => isObject(value) && throwMarker in value, + serialize({ + value + }) { + let serialized; + if (value instanceof Error) { + serialized = { + isError: true, + value: { + message: value.message, + name: value.name, + stack: value.stack + } + }; + } else { + serialized = { + isError: false, + value + }; + } + return [serialized, []]; + }, + deserialize(serialized) { + if (serialized.isError) { + throw Object.assign(new Error(serialized.value.message), serialized.value); + } + throw serialized.value; + } +}; +/** + * Allows customizing the serialization of certain values. + */ +const transferHandlers = new Map([["proxy", proxyTransferHandler], ["throw", throwTransferHandler]]); +function isAllowedOrigin(allowedOrigins, origin) { + for (const allowedOrigin of allowedOrigins) { + if (origin === allowedOrigin || allowedOrigin === "*") { + return true; + } + if (allowedOrigin instanceof RegExp && allowedOrigin.test(origin)) { + return true; + } + } + return false; +} +function expose(obj, ep = globalThis, allowedOrigins = ["*"]) { + ep.addEventListener("message", function callback(ev) { + if (!ev || !ev.data) { + return; + } + if (!isAllowedOrigin(allowedOrigins, ev.origin)) { + console.warn(`Invalid origin '${ev.origin}' for comlink proxy`); + return; + } + const { + id, + type, + path + } = Object.assign({ + path: [] + }, ev.data); + const argumentList = (ev.data.argumentList || []).map(fromWireValue); + let returnValue; + try { + const parent = path.slice(0, -1).reduce((obj, prop) => obj[prop], obj); + const rawValue = path.reduce((obj, prop) => obj[prop], obj); + switch (type) { + case "GET" /* MessageType.GET */: + { + returnValue = rawValue; + } + break; + case "SET" /* MessageType.SET */: + { + parent[path.slice(-1)[0]] = fromWireValue(ev.data.value); + returnValue = true; + } + break; + case "APPLY" /* MessageType.APPLY */: + { + returnValue = rawValue.apply(parent, argumentList); + } + break; + case "CONSTRUCT" /* MessageType.CONSTRUCT */: + { + const value = new rawValue(...argumentList); + returnValue = proxy(value); + } + break; + case "ENDPOINT" /* MessageType.ENDPOINT */: + { + const { + port1, + port2 + } = new MessageChannel(); + expose(obj, port2); + returnValue = transfer(port1, [port1]); + } + break; + case "RELEASE" /* MessageType.RELEASE */: + { + returnValue = undefined; + } + break; + default: + return; + } + } catch (value) { + returnValue = { + value, + [throwMarker]: 0 + }; + } + Promise.resolve(returnValue).catch(value => { + return { + value, + [throwMarker]: 0 + }; + }).then(returnValue => { + const [wireValue, transferables] = toWireValue(returnValue); + ep.postMessage(Object.assign(Object.assign({}, wireValue), { + id + }), transferables); + if (type === "RELEASE" /* MessageType.RELEASE */) { + // detach and deactive after sending release response above. + ep.removeEventListener("message", callback); + closeEndPoint(ep); + if (finalizer in obj && typeof obj[finalizer] === "function") { + obj[finalizer](); + } + } + }).catch(error => { + // Send Serialization Error To Caller + const [wireValue, transferables] = toWireValue({ + value: new TypeError("Unserializable return value"), + [throwMarker]: 0 + }); + ep.postMessage(Object.assign(Object.assign({}, wireValue), { + id + }), transferables); + }); + }); + if (ep.start) { + ep.start(); + } +} +function isMessagePort(endpoint) { + return endpoint.constructor.name === "MessagePort"; +} +function closeEndPoint(endpoint) { + if (isMessagePort(endpoint)) endpoint.close(); +} +function wrap(ep, target) { + return createProxy(ep, [], target); +} +function throwIfProxyReleased(isReleased) { + if (isReleased) { + throw new Error("Proxy has been released and is not useable"); + } +} +function releaseEndpoint(ep) { + return requestResponseMessage(ep, { + type: "RELEASE" /* MessageType.RELEASE */ + }).then(() => { + closeEndPoint(ep); + }); +} +const proxyCounter = new WeakMap(); +const proxyFinalizers = "FinalizationRegistry" in globalThis && new FinalizationRegistry(ep => { + const newCount = (proxyCounter.get(ep) || 0) - 1; + proxyCounter.set(ep, newCount); + if (newCount === 0) { + releaseEndpoint(ep); + } +}); +function registerProxy(proxy, ep) { + const newCount = (proxyCounter.get(ep) || 0) + 1; + proxyCounter.set(ep, newCount); + if (proxyFinalizers) { + proxyFinalizers.register(proxy, ep, proxy); + } +} +function unregisterProxy(proxy) { + if (proxyFinalizers) { + proxyFinalizers.unregister(proxy); + } +} +function createProxy(ep, path = [], target = function () {}) { + let isProxyReleased = false; + const proxy = new Proxy(target, { + get(_target, prop) { + throwIfProxyReleased(isProxyReleased); + if (prop === releaseProxy) { + return () => { + unregisterProxy(proxy); + releaseEndpoint(ep); + isProxyReleased = true; + }; + } + if (prop === "then") { + if (path.length === 0) { + return { + then: () => proxy + }; + } + const r = requestResponseMessage(ep, { + type: "GET" /* MessageType.GET */, + path: path.map(p => p.toString()) + }).then(fromWireValue); + return r.then.bind(r); + } + return createProxy(ep, [...path, prop]); + }, + set(_target, prop, rawValue) { + throwIfProxyReleased(isProxyReleased); + // FIXME: ES6 Proxy Handler `set` methods are supposed to return a + // boolean. To show good will, we return true asynchronously ¯\_(ツ)_/¯ + const [value, transferables] = toWireValue(rawValue); + return requestResponseMessage(ep, { + type: "SET" /* MessageType.SET */, + path: [...path, prop].map(p => p.toString()), + value + }, transferables).then(fromWireValue); + }, + apply(_target, _thisArg, rawArgumentList) { + throwIfProxyReleased(isProxyReleased); + const last = path[path.length - 1]; + if (last === createEndpoint) { + return requestResponseMessage(ep, { + type: "ENDPOINT" /* MessageType.ENDPOINT */ + }).then(fromWireValue); + } + // We just pretend that `bind()` didn’t happen. + if (last === "bind") { + return createProxy(ep, path.slice(0, -1)); + } + const [argumentList, transferables] = processArguments(rawArgumentList); + return requestResponseMessage(ep, { + type: "APPLY" /* MessageType.APPLY */, + path: path.map(p => p.toString()), + argumentList + }, transferables).then(fromWireValue); + }, + construct(_target, rawArgumentList) { + throwIfProxyReleased(isProxyReleased); + const [argumentList, transferables] = processArguments(rawArgumentList); + return requestResponseMessage(ep, { + type: "CONSTRUCT" /* MessageType.CONSTRUCT */, + path: path.map(p => p.toString()), + argumentList + }, transferables).then(fromWireValue); + } + }); + registerProxy(proxy, ep); + return proxy; +} +function myFlat(arr) { + return Array.prototype.concat.apply([], arr); +} +function processArguments(argumentList) { + const processed = argumentList.map(toWireValue); + return [processed.map(v => v[0]), myFlat(processed.map(v => v[1]))]; +} +const transferCache = new WeakMap(); +function transfer(obj, transfers) { + transferCache.set(obj, transfers); + return obj; +} +function proxy(obj) { + return Object.assign(obj, { + [proxyMarker]: true + }); +} +function toWireValue(value) { + for (const [name, handler] of transferHandlers) { + if (handler.canHandle(value)) { + const [serializedValue, transferables] = handler.serialize(value); + return [{ + type: "HANDLER" /* WireValueType.HANDLER */, + name, + value: serializedValue + }, transferables]; + } + } + return [{ + type: "RAW" /* WireValueType.RAW */, + value + }, transferCache.get(value) || []]; +} +function fromWireValue(value) { + switch (value.type) { + case "HANDLER" /* WireValueType.HANDLER */: + return transferHandlers.get(value.name).deserialize(value.value); + case "RAW" /* WireValueType.RAW */: + return value.value; + } +} +function requestResponseMessage(ep, msg, transfers) { + return new Promise(resolve => { + const id = generateUUID(); + ep.addEventListener("message", function l(ev) { + if (!ev.data || !ev.data.id || ev.data.id !== id) { + return; + } + ep.removeEventListener("message", l); + resolve(ev.data); + }); + if (ep.start) { + ep.start(); + } + ep.postMessage(Object.assign({ + id + }, msg), transfers); + }); +} +function generateUUID() { + return new Array(4).fill(0).map(() => Math.floor(Math.random() * Number.MAX_SAFE_INTEGER).toString(16)).join("-"); +} + +// Unique ID creation requires a high quality random # generator. In the browser we therefore +// require the crypto API and do not support built-in fallback to lower quality random number +// generators (like Math.random()). +let getRandomValues; +const rnds8 = new Uint8Array(16); +function rng() { + // lazy load so that environments that need to polyfill have a chance to do so + if (!getRandomValues) { + // getRandomValues needs to be invoked in a context where "this" is a Crypto implementation. + getRandomValues = typeof crypto !== 'undefined' && crypto.getRandomValues && crypto.getRandomValues.bind(crypto); + if (!getRandomValues) { + throw new Error('crypto.getRandomValues() not supported. See https://github.com/uuidjs/uuid#getrandomvalues-not-supported'); + } + } + return getRandomValues(rnds8); +} + +/** + * Convert array of 16 byte values to UUID string format of the form: + * XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX + */ + +const byteToHex = []; +for (let i = 0; i < 256; ++i) { + byteToHex.push((i + 0x100).toString(16).slice(1)); +} +function unsafeStringify(arr, offset = 0) { + // Note: Be careful editing this code! It's been tuned for performance + // and works in ways you may not expect. See https://github.com/uuidjs/uuid/pull/434 + return (byteToHex[arr[offset + 0]] + byteToHex[arr[offset + 1]] + byteToHex[arr[offset + 2]] + byteToHex[arr[offset + 3]] + '-' + byteToHex[arr[offset + 4]] + byteToHex[arr[offset + 5]] + '-' + byteToHex[arr[offset + 6]] + byteToHex[arr[offset + 7]] + '-' + byteToHex[arr[offset + 8]] + byteToHex[arr[offset + 9]] + '-' + byteToHex[arr[offset + 10]] + byteToHex[arr[offset + 11]] + byteToHex[arr[offset + 12]] + byteToHex[arr[offset + 13]] + byteToHex[arr[offset + 14]] + byteToHex[arr[offset + 15]]).toLowerCase(); +} + +const randomUUID = typeof crypto !== 'undefined' && crypto.randomUUID && crypto.randomUUID.bind(crypto); +var native = { + randomUUID +}; + +function v4(options, buf, offset) { + if (native.randomUUID && !buf && !options) { + return native.randomUUID(); + } + options = options || {}; + const rnds = options.random || (options.rng || rng)(); // Per 4.4, set bits for version and `clock_seq_hi_and_reserved` + + rnds[6] = rnds[6] & 0x0f | 0x40; + rnds[8] = rnds[8] & 0x3f | 0x80; // Copy bytes to buffer, if provided + + if (buf) { + offset = offset || 0; + for (let i = 0; i < 16; ++i) { + buf[offset + i] = rnds[i]; + } + return buf; + } + return unsafeStringify(rnds); +} + +export { __spreadArray as _, __assign as a, __awaiter as b, __generator as c, detectGPUDevice as d, expose as e, instantiate as i, proxy as p, v4 as v, wrap as w }; diff --git a/packages/headless/dist/worker-cc79b531.js b/packages/headless/dist/worker-cc79b531.js new file mode 100644 index 0000000..d89f262 --- /dev/null +++ b/packages/headless/dist/worker-cc79b531.js @@ -0,0 +1,412 @@ +import { b as __awaiter, c as __generator, d as detectGPUDevice, i as instantiate, v as v4, e as expose } from './v4-2119d9d5.js'; + +var LLMInstance = /** @class */ (function () { + function LLMInstance(config, sentencePieceProcessor) { + this.config = config; + this.tvm = undefined; + this.tokenizer = undefined; + this.model = undefined; + this.spp = sentencePieceProcessor; + this.processing = false; + } + LLMInstance.prototype.isInitialized = function () { + return this.model != undefined; + }; + LLMInstance.prototype.init = function (cb) { + return __awaiter(this, void 0, void 0, function () { + var wasmSource, _a, output, err_1, _b; + var _this = this; + return __generator(this, function (_c) { + switch (_c.label) { + case 0: + if (this.model) { + return [2 /*return*/]; + } + return [4 /*yield*/, fetch(this.config.wasmUrl)]; + case 1: return [4 /*yield*/, (_c.sent()).arrayBuffer()]; + case 2: + wasmSource = _c.sent(); + _a = this; + return [4 /*yield*/, instantiate(new Uint8Array(wasmSource), + //@ts-ignore + new EmccWASI(), console.log)]; + case 3: + _a.tvm = _c.sent(); + _c.label = 4; + case 4: + _c.trys.push([4, 6, , 7]); + return [4 /*yield*/, detectGPUDevice()]; + case 5: + output = _c.sent(); + if (output !== undefined) { + this.tvm.initWebGPU(output.device); + } + else { + throw Error("This browser env do not support WebGPU"); + } + return [3 /*break*/, 7]; + case 6: + err_1 = _c.sent(); + throw Error("Find an error initializing WebGPU: " + err_1.toString()); + case 7: + this.tvm.registerInitProgressCallback(cb); + return [4 /*yield*/, this.tvm.fetchNDArrayCache(this.config.cacheUrl, this.tvm.webgpu())]; + case 8: + _c.sent(); + _b = this; + return [4 /*yield*/, this.spp()(this.config.tokenizerUrl)]; + case 9: + _b.tokenizer = _c.sent(); + this.model = this.tvm.withNewScope(function () { + return new LLMInstanceScope(_this.tvm, _this.tokenizer, _this.config.maxWindowSize); + }); + return [2 /*return*/, this.model.init()]; + } + }); + }); + }; + LLMInstance.prototype.generate = function (request, cb) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (this.processing) { + return [2 /*return*/]; + } + this.processing = true; + return [4 /*yield*/, this.model.generate(request, cb)]; + case 1: + _a.sent(); + this.processing = false; + return [2 /*return*/]; + } + }); + }); + }; + return LLMInstance; +}()); +var LLMInstanceScope = /** @class */ (function () { + function LLMInstanceScope(tvm, tokenizer, maxWindowSize) { + if (maxWindowSize === void 0) { maxWindowSize = 2048; } + this.tvm = tvm; + this.tokenizer = tokenizer; + this.bosTokenId = 1; + this.eosTokenId = 2; + this.maxWindowSize = maxWindowSize; + this.device = this.tvm.webgpu(); + this.vm = this.tvm.detachFromCurrentScope(this.tvm.createVirtualMachine(this.device)); + this.encoding = this.tvm.detachFromCurrentScope(this.vm.getFunction("encoding")); + this.decoding = this.tvm.detachFromCurrentScope(this.vm.getFunction("decoding")); + this.params = this.tvm.detachFromCurrentScope(this.tvm.getParamsFromCache("param", this.tvm.cacheMetadata.ParamSize)); + var fcreateCache = this.vm.getFunction("create_kv_cache"); + this.fclearKVCaches = this.tvm.detachFromCurrentScope(this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")); + // use extern config for now + this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache()); + // fill with pad token + this.logitsOnCPU = undefined; + this.kvCacheLength = 0; + this.lastMessageId = ""; + } + LLMInstanceScope.prototype.init = function () { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule())]; + case 1: + _a.sent(); + return [2 /*return*/]; + } + }); + }); + }; + LLMInstanceScope.prototype.getTokensFromStart = function (conversation, maxTokens) { + return __awaiter(this, void 0, void 0, function () { + var tokens, i, message, text, messageTokens, _a, _b, _c, _d, _e, _f; + return __generator(this, function (_g) { + switch (_g.label) { + case 0: + this.clearKVCache(); + tokens = []; + i = conversation.messages.length - 1; + _g.label = 1; + case 1: + if (!(i >= 0)) return [3 /*break*/, 5]; + message = conversation.messages[i]; + text = "".concat(message.role, ": ").concat(message.text, "\n"); + return [4 /*yield*/, this.tokenizer.encodeIds(text)]; + case 2: + messageTokens = _g.sent(); + if (tokens.length + messageTokens.length + maxTokens > + this.maxWindowSize) { + return [3 /*break*/, 5]; + } + _b = (_a = tokens.unshift).apply; + _c = [tokens]; + return [4 /*yield*/, this.tokenizer.encodeIds(text)]; + case 3: + _b.apply(_a, _c.concat([(_g.sent())])); + _g.label = 4; + case 4: + i--; + return [3 /*break*/, 1]; + case 5: + _e = (_d = tokens.unshift).apply; + _f = [tokens]; + return [4 /*yield*/, this.tokenizer.encodeIds(conversation.systemPrompt)]; + case 6: + _e.apply(_d, _f.concat([(_g.sent())])); + tokens.unshift(this.bosTokenId); + return [2 /*return*/, tokens]; + } + }); + }); + }; + LLMInstanceScope.prototype.getTokens = function (conversation, maxTokens) { + return __awaiter(this, void 0, void 0, function () { + var startMsgIdx, i, tokens, i, message, text, messageTokens, _a, _b, _c; + return __generator(this, function (_d) { + switch (_d.label) { + case 0: + if (!(this.kvCacheLength == 0)) return [3 /*break*/, 2]; + return [4 /*yield*/, this.getTokensFromStart(conversation, maxTokens)]; + case 1: + // Case 1 + return [2 /*return*/, _d.sent()]; + case 2: + startMsgIdx = 0; + for (i = conversation.messages.length - 1; i >= 0; i--) { + if (conversation.messages[i].id == this.lastMessageId) { + startMsgIdx = i + 1; + break; + } + } + if (!(startMsgIdx == 0)) return [3 /*break*/, 4]; + return [4 /*yield*/, this.getTokensFromStart(conversation, maxTokens)]; + case 3: + // Case 2 + return [2 /*return*/, _d.sent()]; + case 4: + tokens = [this.eosTokenId]; + i = startMsgIdx; + _d.label = 5; + case 5: + if (!(i < conversation.messages.length)) return [3 /*break*/, 11]; + message = conversation.messages[i]; + text = "".concat(message.role, ": ").concat(message.text); + return [4 /*yield*/, this.tokenizer.encodeIds(text)]; + case 6: + messageTokens = _d.sent(); + if (!(tokens.length + messageTokens.length + maxTokens > + this.maxWindowSize)) return [3 /*break*/, 8]; + return [4 /*yield*/, this.getTokensFromStart(conversation, maxTokens)]; + case 7: + // Case 4 + return [2 /*return*/, _d.sent()]; + case 8: + _b = (_a = tokens.push).apply; + _c = [tokens]; + return [4 /*yield*/, this.tokenizer.encodeIds(text)]; + case 9: + _b.apply(_a, _c.concat([(_d.sent())])); + _d.label = 10; + case 10: + i++; + return [3 /*break*/, 5]; + case 11: + // Case 3 + return [2 /*return*/, tokens]; + } + }); + }); + }; + LLMInstanceScope.prototype.generate = function (request, cb) { + return __awaiter(this, void 0, void 0, function () { + var conversation, maxTokens, assistantRoleName, stopTexts, tokens, _a, _b, _c, _d, _e, _f, inputTokenLength, outputText, tstart, tend, step, id, input, logits, nextToken, outputTokens, stopPos, stop_1, i; + return __generator(this, function (_g) { + switch (_g.label) { + case 0: + conversation = request.conversation, maxTokens = request.maxTokens, assistantRoleName = request.assistantRoleName, stopTexts = request.stopTexts; + return [4 /*yield*/, this.getTokens(conversation, maxTokens)]; + case 1: + tokens = _g.sent(); + _b = (_a = tokens.push).apply; + _c = [tokens]; + return [4 /*yield*/, this.tokenizer.encodeIds("".concat(assistantRoleName, ":"))]; + case 2: + _b.apply(_a, _c.concat([(_g.sent())])); + _e = (_d = console).log; + _f = ["debug: "]; + return [4 /*yield*/, this.tokenizer.decodeIds(tokens)]; + case 3: + _e.apply(_d, _f.concat([_g.sent()])); + inputTokenLength = tokens.length; + outputText = ""; + tstart = 0, tend = 0, step = 0; + id = v4(); + _g.label = 4; + case 4: + if (!(step < maxTokens)) return [3 /*break*/, 7]; + this.tvm.beginScope(); + tstart = performance.now(); + if (step == 0) { + input = this.tvm.empty([1, tokens.length], "int32", this.device); + input.copyFrom(tokens); + } + else { + input = this.tvm.empty([1, 1], "int32", this.device); + input.copyFrom(tokens.slice(tokens.length - 1)); + } + logits = this.tvm.detachFromCurrentScope(this.forward(input, this.kvCacheLength + inputTokenLength + step)); + this.tvm.endScope(); + return [4 /*yield*/, this.sampleTokenFromLogits(logits)]; + case 5: + nextToken = _g.sent(); + logits.dispose(); + tokens.push(nextToken); + outputTokens = tokens.slice(inputTokenLength); + outputText = this.tokenizer.decodeIds(outputTokens); + tend = performance.now(); + if (nextToken == this.eosTokenId) + return [3 /*break*/, 7]; + stopPos = outputText.lastIndexOf(""); + if (stopPos != -1) { + outputText = outputText.substring(0, stopPos); + return [3 /*break*/, 7]; + } + stop_1 = false; + for (i = 0; i < stopTexts.length; i++) { + if (outputText.endsWith(stopTexts[i])) { + outputText = outputText.substring(0, outputText.length - stopTexts[i].length); + stop_1 = true; + break; + } + } + if (stop_1) + return [3 /*break*/, 7]; + if (step != 0) { + cb({ + requestId: id, + step: step, + outputText: outputText, + stats: { + totalDecodingSeconds: (tend - tstart) / 1000, + totalDecodedTokens: tokens.length - inputTokenLength, + totalEncodedTokens: inputTokenLength, + }, + isFinished: false, + }); + } + _g.label = 6; + case 6: + step++; + return [3 /*break*/, 4]; + case 7: + this.kvCacheLength += tokens.length - 1; + this.lastMessageId = id; + cb({ + requestId: id, + outputText: outputText, + step: step, + stats: { + totalDecodingSeconds: (tend - tstart) / 1000, + totalDecodedTokens: tokens.length - inputTokenLength, + totalEncodedTokens: inputTokenLength, + }, + isFinished: true, + }); + return [2 /*return*/]; + } + }); + }); + }; + LLMInstanceScope.prototype.dispose = function () { + // note: tvm instance is not owned by this class + this.params.dispose(); + this.decoding.dispose(); + this.encoding.dispose(); + this.vm.dispose(); + this.kvCache.dispose(); + this.fclearKVCaches.dispose(); + if (this.logitsOnCPU != undefined) { + this.logitsOnCPU.dispose(); + } + }; + LLMInstanceScope.prototype.clearKVCache = function () { + this.fclearKVCaches(this.kvCache); + this.kvCacheLength = 0; + this.lastMessageId = ""; + }; + LLMInstanceScope.prototype.forward = function (inputs, curPos) { + this.tvm.beginScope(); + var retValue; + var seqLenShape = this.tvm.makeShapeTuple([curPos]); + if (inputs.shape[1] > 1) { + retValue = this.encoding(inputs, seqLenShape, this.kvCache, this.params); + } + else { + retValue = this.decoding(inputs, seqLenShape, this.kvCache, this.params); + } + var logits = this.tvm.detachFromCurrentScope(retValue.get(0)); + this.tvm.endScope(); + this.tvm.attachToCurrentScope(logits); + return logits; + }; + // NOTE: caller must call device.sync() + LLMInstanceScope.prototype.updateLogitsOnCPU = function (logits) { + if (this.logitsOnCPU == undefined) { + this.logitsOnCPU = this.tvm.detachFromCurrentScope(this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())); + } + else { + if (logits.shape[0] != this.logitsOnCPU.shape[0]) { + throw Error("We expect the size of logits to remain unchanged"); + } + } + this.logitsOnCPU.copyFrom(logits); + }; + LLMInstanceScope.prototype.sampleTokenFromLogits = function (logits, temperature, top_p) { + if (temperature === void 0) { temperature = 0.8; } + if (top_p === void 0) { top_p = 0.95; } + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + this.tvm.beginScope(); + this.updateLogitsOnCPU(logits); + this.tvm.endScope(); + return [4 /*yield*/, this.device.sync()]; + case 1: + _a.sent(); + return [2 /*return*/, this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p)]; + } + }); + }); + }; + return LLMInstanceScope; +}()); + +var config = { + kvConfig: { + numLayers: 64, + shape: [32, 32, 128], + dtype: 'float32', + }, + wasmUrl: 'https://huggingface.co/mrick/react-llm/resolve/main/models/vicuna-7b-v1/vicuna-7b-v1_webgpu.wasm', + cacheUrl: 'https://huggingface.co/mrick/react-llm/resolve/main/models/vicuna-7b-v1/params/', + tokenizerUrl: 'https://huggingface.co/mrick/react-llm/resolve/main/models/vicuna-7b-v1/tokenizer.model', + sentencePieceJsUrl: 'https://cdn.matt-rickard.com/code/sentencepiece.js', + tvmRuntimeJsUrl: 'https://cdn.matt-rickard.com/code/tvmjs_runtime.wasi.js', + maxWindowSize: 2048, +}; +var instance = new LLMInstance(config, function () { return globalThis.sentencepiece.sentencePieceProcessor; }); +var worker = { + init: function (callback) { + instance.init(callback); + }, + generate: function (request, cb) { + instance.generate(request, cb); + } +}; +importScripts.apply(void 0, [ + config.sentencePieceJsUrl, config.tvmRuntimeJsUrl +]); +expose(worker); diff --git a/packages/headless/package.json b/packages/headless/package.json new file mode 100644 index 0000000..0fcdd52 --- /dev/null +++ b/packages/headless/package.json @@ -0,0 +1,63 @@ +{ + "name": "@react-llm/headless", + "version": "0.0.3", + "author": "Matt Rickard ", + "license": "MIT", + "module": "dist/index.js", + "type": "module", + "types": "dist/types/index.d.ts", + "keywords": [ + "chatgpt", + "llm", + "headless", + "react" + ], + "scripts": { + "build": "npm run clean && rollup -c", + "dev": "rollup -c -w", + "clean": "rm -rf dist" + }, + "dependencies": { + "@babel/plugin-transform-modules-commonjs": "^7.21.5", + "@surma/rollup-plugin-off-main-thread": "^2.2.3", + "@types/node": "20.1.1", + "@types/react": "18.2.6", + "@types/react-dom": "18.2.4", + "autoprefixer": "10.4.14", + "comlink": "^4.4.1", + "eslint": "8.40.0", + "eslint-config-next": "13.4.1", + "react": "18.2.0", + "react-dom": "18.2.0", + "typescript": "5.0.4", + "uuid": "^9.0.0", + "zustand": "^4.3.8" + }, + "devDependencies": { + "@babel/cli": "^7.21.5", + "@babel/core": "^7.21.8", + "@babel/preset-env": "^7.21.5", + "@babel/preset-react": "^7.18.6", + "@babel/preset-typescript": "^7.21.5", + "@rollup/plugin-alias": "^5.0.0", + "@rollup/plugin-babel": "^6.0.3", + "@rollup/plugin-commonjs": "^24.1.0", + "@rollup/plugin-json": "^6.0.0", + "@rollup/plugin-node-resolve": "^15.0.2", + "@types/uuid": "^9.0.1", + "@webgpu/types": "^0.1.32", + "@zerollup/ts-transform-paths": "^1.7.18", + "babel-plugin-module-resolver": "^5.0.0", + "rollup": "^3.21.6", + "rollup-plugin-dts": "^5.3.0", + "rollup-plugin-node-resolve": "^5.2.0", + "rollup-plugin-terser": "^7.0.2", + "rollup-plugin-tsconfig-paths": "^1.5.0", + "rollup-plugin-typescript2": "^0.34.1", + "rollup-plugin-web-worker-loader": "^1.6.1", + "tsconfig-paths-webpack-plugin": "^4.0.1", + "ttypescript": "^1.5.15", + "webpack": "^5.82.1", + "webpack-cli": "^5.1.1" + } +} diff --git a/packages/headless/rollup.config.js b/packages/headless/rollup.config.js new file mode 100644 index 0000000..8b376f3 --- /dev/null +++ b/packages/headless/rollup.config.js @@ -0,0 +1,35 @@ +import babel from "@rollup/plugin-babel"; +import commonjs from "@rollup/plugin-commonjs"; +import json from "@rollup/plugin-json"; +import resolve, { nodeResolve } from "@rollup/plugin-node-resolve"; +import OMT from "@surma/rollup-plugin-off-main-thread"; +import typescript from "rollup-plugin-typescript2"; + +export default [ + { + input: "src/index.ts", + output: { + dir: "dist", + format: "esm", + sourceMap: true, + }, + external: ["react", "react-dom"], + plugins: [ + babel({ + exclude: "node_modules/**", + presets: ["@babel/preset-react", "@babel/preset-typescript"], + babelHelpers: "bundled", + }), + nodeResolve(), + typescript({ + tsconfig: "tsconfig.json", + sourceMap: false, + useTsconfigDeclarationDir: true, + }), + OMT(), + commonjs(), + resolve({ preferBuiltins: true }), + json(), + ], + }, +]; diff --git a/packages/headless/src/hooks/useConversationStore.tsx b/packages/headless/src/hooks/useConversationStore.tsx new file mode 100644 index 0000000..bd066a6 --- /dev/null +++ b/packages/headless/src/hooks/useConversationStore.tsx @@ -0,0 +1,178 @@ +import { v4 as uuidv4 } from "uuid"; +import { create } from "zustand"; +import { persist } from "zustand/middleware"; +import { Conversation, Message } from "../types/chat"; + +export interface ConversationStore { + conversations: Conversation[]; + currentConversationId: string; + setConversationId: (conversationId: string) => void; + + addMessage: (conversationId: string, message: Message) => void; + getConversation: (conversationId: string) => Conversation | undefined; + + setConversationTitle: (conversationId: string, title: string) => void; + + getAllConversations: () => Conversation[]; + deleteMessages: (conversationId: string) => void; + + deleteConversation: (conversationId: string) => void; + createConversation: (conversation: Conversation) => void; + deleteAllConversations: () => void; +} + +export const defaultSystemPrompt = + "A chat between a curious user and a AI chatbot named SmartestChild on AIM who responds with lowercase, frequent emojis, and 2000s internet abbreviations."; + +const useConversationStore = create()( + persist( + (set, get) => { + const initialConversation = { + id: uuidv4(), + title: "Untitled", + updatedAt: new Date().getTime(), + systemPrompt: defaultSystemPrompt, + createdAt: new Date().getTime(), + messages: [] as Message[], + }; + + return { + conversations: [initialConversation], + currentConversationId: initialConversation.id, + createConversation: (conversation: Conversation) => { + set((state) => { + return { + currentConversationId: conversation.id, + conversations: [...state.conversations, conversation], + }; + }); + }, + setConversationTitle(conversationId, title) { + set((state) => { + const conversation = state.conversations.find( + (c) => c.id === conversationId + ); + if (!conversation) { + return state; + } + return { + conversations: [ + ...state.conversations.filter((c) => c.id !== conversationId), + { + ...conversation, + title, + }, + ], + }; + }); + }, + deleteConversation(conversationId: string) { + set((state) => { + return { + conversations: state.conversations.filter( + (c) => c.id !== conversationId + ), + }; + }); + }, + setConversationId: (conversationId: string) => { + const conversationExists = get().conversations.some( + (c) => c.id === conversationId + ); + if (!conversationExists) { + throw new Error("Invalid conversation id"); + } + + set((state) => { + return { + ...state, + currentConversationId: conversationId, + }; + }); + }, + deleteAllConversations: () => { + set((state) => { + return { + conversations: [], + }; + }); + }, + deleteMessages: (conversationId) => { + set((state) => { + const conversation = state.conversations.find( + (c) => c.id === conversationId + ); + if (!conversation) { + return state; + } + return { + conversations: [ + ...state.conversations.filter((c) => c.id !== conversationId), + { + ...conversation, + updatedAt: new Date().getTime(), + messages: [], + }, + ], + }; + }); + }, + getConversation(conversationId) { + return get().conversations.find((c) => c.id === conversationId); + }, + getAllConversations() { + return get().conversations; + }, + addMessage: (conversationId, message) => { + set((state) => { + const conversation = state.conversations.find( + (c) => c.id === conversationId + ); + if (!conversation) { + return state; + } + const existingMessage = conversation.messages.find( + (m) => m.id === message.id + ); + if (existingMessage) { + // Update message + return { + conversations: [ + ...state.conversations.filter((c) => c.id !== conversationId), + { + ...conversation, + updatedAt: new Date().getTime(), + messages: [ + ...conversation.messages.filter( + (m) => m.id !== message.id + ), + message, + ], + }, + ], + }; + } + // Add message + return { + conversations: [ + ...state.conversations.filter((c) => c.id !== conversationId), + { + ...conversation, + updatedAt: new Date().getTime(), + messages: [...conversation.messages, message], + }, + ], + }; + }); + }, + }; + }, + + { + name: "chat-store", + getStorage: () => sessionStorage, + } + ) +); + +export default useConversationStore; diff --git a/packages/headless/src/hooks/useLLM.tsx b/packages/headless/src/hooks/useLLM.tsx new file mode 100644 index 0000000..c66c58a --- /dev/null +++ b/packages/headless/src/hooks/useLLM.tsx @@ -0,0 +1,257 @@ +import { detectGPUDevice } from "@/worker/lib/tvm"; +import { InitProgressReport } from "@/worker/lib/tvm/runtime"; +import * as Comlink from "comlink"; +import { Remote } from "comlink"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { v4 as uuidv4 } from "uuid"; +import { Conversation } from "../types/chat"; +import { + GenerateTextRequest, + GenerateTextResponse, + ModelWorker, +} from "../types/worker_message"; +import useConversationStore, { + defaultSystemPrompt, +} from "./useConversationStore"; +import useStore from "./useStore"; + +export type UseLLMParams = { + autoInit?: boolean; +}; + +const initialProgress = { + type: "init" as const, + progress: 0, + timeElapsed: 0, + currentChunk: 0, + totalChunks: 0, + fetchedBytes: 0, + totalBytes: 0, +}; + +export type GPUDeviceInfo = { + adapter: GPUAdapter | null; + device: GPUDevice | null; + adapterInfo: GPUAdapterInfo | null; + checked: boolean; + unsupportedReason: string | null; +}; + +export type UseLLMResponse = { + // Conversation returns the current conversation object. + conversation: Conversation | undefined; + + // AllConversations returns all conversations sorted by updatedAt. + allConversations: Conversation[] | undefined; + + // LoadingStatus returns the current loading status. + loadingStatus: InitProgressReport; + + // IsGenerating returns whether the model is currently generating. Concurrent generation is not supported. + isGenerating: boolean; + + // CreateConversation creates a new conversation and sets it as the current conversation. + createConversation: (title?: string, prompt?: string) => void; + + // SetConversationId sets the current conversation id. + setConversationId: (conversationId: string) => void; + + // DeleteConversation deletes a conversation. + deleteConversation: (conversationId: string) => void; + + // DeleteAllConversations deletes all conversations. + deleteAllConversations: () => void; + + // DeleteMessages deletes all messages in the current conversation. + deleteMessages: () => void; + + // SetConversationTitle sets the title of a conversation. + setConversationTitle: (conversationId: string, title: string) => void; + + // OnMessage returns the current onMessage callback. + onMessage: (msg: GenerateTextResponse) => void; + + // SetOnMessage sets the onMessage callback. This callback is called whenever a new message is generated by the model. + setOnMessage: (cb: (msg: GenerateTextResponse) => void) => void; + + // UserRoleName returns the current user role name. The default is "user". + userRoleName: string; + + // SetUserRoleName sets the user role name. + setUserRoleName: (roleName: string) => void; + + // AssistantRoleName returns the current assistant role name. The default is "assistant". + assistantRoleName: string; + + // SetAssistantRoleName sets the assistant role name. + setAssistantRoleName: (roleName: string) => void; + + // GpuDevice returns the current GPU device info. If GPU is not supported, this will return an object with unsupportedReason set. + gpuDevice: GPUDeviceInfo; + + // Send sends a message to the model for generation. + send: (text: string, maxToken: number, stopSequences: string[]) => void; + + // Init initializes the model. + init: () => void; +}; + +export const useLLMContext = (): UseLLMResponse => { + const [loadingStatus, setLoadingStatus] = + useState(initialProgress); + const [isGenerating, setIsGenerating] = useState(false); + const workerRef = useRef>(); + const cStore = useStore(useConversationStore, (state) => state); + const [userRoleName, setUserRoleName] = useState("user"); + const [assistantRoleName, setAssistantRoleName] = + useState("assistant"); + + const [gpuDevice, setGpuDevice] = useState({ + adapter: null, + device: null, + adapterInfo: null, + checked: false, + unsupportedReason: null, + }); + + useEffect(() => { + if (!gpuDevice || !gpuDevice.checked) { + detectGPUDevice() + .then((resp) => { + if (resp) { + setGpuDevice({ + unsupportedReason: null, + checked: true, + adapter: resp.adapter, + device: resp.device, + adapterInfo: resp.adapterInfo, + }); + } else { + setGpuDevice({ + ...gpuDevice, + checked: true, + unsupportedReason: "GPU is not supported", + }); + } + }) + .catch((err) => { + setGpuDevice({ + adapter: null, + device: null, + adapterInfo: null, + checked: true, + unsupportedReason: err.message, + }); + }); + } + }, []); + + const [onMessage, setOnMessage] = useState(); + + const addMessage = useCallback( + (resp: GenerateTextResponse) => { + if (resp.isFinished) { + setIsGenerating(false); + } + if (onMessage) onMessage(resp); + cStore?.addMessage(cStore?.currentConversationId, { + id: resp.requestId, + createdAt: new Date().getTime(), + updatedAt: new Date().getTime(), + role: assistantRoleName, + text: resp.outputText, + }); + }, + [cStore, cStore?.currentConversationId, onMessage, setOnMessage] + ); + + useEffect(() => { + if (!workerRef.current) { + workerRef.current = Comlink.wrap( + new Worker(new URL("../worker/worker", import.meta.url)) + ); + } + }, []); + + const send = ( + text: string, + maxTokens = 100, + stopStrings = [userRoleName, assistantRoleName] as string[] + ) => { + const currentConversation = cStore?.getConversation( + cStore?.currentConversationId + ); + if (!currentConversation) { + throw new Error("Invalid conversation id"); + } + currentConversation?.messages.push({ + id: uuidv4(), + createdAt: new Date().getTime(), + updatedAt: new Date().getTime(), + role: userRoleName, + text, + }); + setIsGenerating(true); + workerRef?.current?.generate( + { + conversation: currentConversation, + stopTexts: stopStrings, + maxTokens, + assistantRoleName, + } as GenerateTextRequest, + Comlink.proxy(addMessage) + ); + }; + + return { + conversation: cStore?.getConversation(cStore?.currentConversationId), + + allConversations: cStore?.conversations.sort( + (a: Conversation, b: Conversation) => b.updatedAt - a.updatedAt + ), + + createConversation: (title?: string, prompt?: string) => { + const id = uuidv4(); + cStore?.createConversation({ + id, + title: title ?? "Untitled", + systemPrompt: prompt ?? defaultSystemPrompt, + messages: [], + createdAt: new Date().getTime(), + updatedAt: new Date().getTime(), + }); + }, + + setConversationTitle: (id: string, title: string) => { + cStore?.setConversationTitle(id, title); + }, + + setConversationId: (id: string) => { + cStore?.setConversationId(id); + }, + + deleteConversation: (id: string) => { + cStore?.deleteConversation(id); + }, + deleteMessages: () => cStore?.deleteMessages(cStore?.currentConversationId), + + onMessage, + setOnMessage, + + loadingStatus, + isGenerating, + + userRoleName, + setUserRoleName, + + assistantRoleName, + setAssistantRoleName, + + gpuDevice, + + send, + init: () => workerRef?.current?.init(Comlink.proxy(setLoadingStatus)), + + deleteAllConversations: () => cStore?.deleteAllConversations(), + }; +}; diff --git a/packages/headless/src/hooks/useStore.tsx b/packages/headless/src/hooks/useStore.tsx new file mode 100644 index 0000000..63272bc --- /dev/null +++ b/packages/headless/src/hooks/useStore.tsx @@ -0,0 +1,18 @@ +import { useEffect, useState } from "react"; + +// https://github.com/pmndrs/zustand/blob/65d2bc0660ab0d542cf9f97a3b004754ffa73f3e/docs/integrations/persisting-store-data.md?plain=1#L471-L488 +const useStore = ( + store: (callback: (state: T) => unknown) => unknown, + callback: (state: T) => F +) => { + const result = store(callback) as F; + const [data, setData] = useState(); + + useEffect(() => { + setData(result); + }, [result]); + + return data; +}; + +export default useStore; diff --git a/packages/headless/src/index.d.ts b/packages/headless/src/index.d.ts new file mode 100644 index 0000000..f869329 --- /dev/null +++ b/packages/headless/src/index.d.ts @@ -0,0 +1,6 @@ +export type { + ModelProvider, + ModelProviderProps +} from './providers/ModelProvider' +export type * from './types/chat' + diff --git a/packages/headless/src/index.ts b/packages/headless/src/index.ts new file mode 100644 index 0000000..6cf2936 --- /dev/null +++ b/packages/headless/src/index.ts @@ -0,0 +1,5 @@ + +import { ModelProvider, useLLM } from './providers/ModelProvider'; + +export { ModelProvider }; +export default useLLM; diff --git a/packages/headless/src/providers/ModelProvider.tsx b/packages/headless/src/providers/ModelProvider.tsx new file mode 100644 index 0000000..4be20c0 --- /dev/null +++ b/packages/headless/src/providers/ModelProvider.tsx @@ -0,0 +1,24 @@ +import React, { createContext, useContext } from "react"; +import { UseLLMParams, UseLLMResponse, useLLMContext } from "../hooks/useLLM"; + +export interface ModelProviderProps { + children: React.ReactNode; + props?: UseLLMParams; +} + +const ModelContext = createContext(null); + +export const ModelProvider: React.FC = ({ children }) => { + const LLMValue = useLLMContext(); + return ( + {children} + ); +}; + +export const useLLM = (): UseLLMResponse => { + const context = useContext(ModelContext); + if (context === null) { + throw new Error("useLLMContext must be used within a LLMProvider"); + } + return context; +}; diff --git a/packages/headless/src/types/chat.ts b/packages/headless/src/types/chat.ts new file mode 100644 index 0000000..4cab670 --- /dev/null +++ b/packages/headless/src/types/chat.ts @@ -0,0 +1,16 @@ +export interface Conversation { + id: string; + title: string; + systemPrompt: string; + createdAt: number; + updatedAt: number; + messages: Message[]; +} + +export interface Message { + id: string; + role: string; + text: string; + createdAt: number; + updatedAt: number; +} \ No newline at end of file diff --git a/packages/headless/src/types/worker_message.ts b/packages/headless/src/types/worker_message.ts new file mode 100644 index 0000000..c131941 --- /dev/null +++ b/packages/headless/src/types/worker_message.ts @@ -0,0 +1,32 @@ +import * as Comlink from 'comlink'; +import { InitProgressCallback } from "../worker/lib/tvm/runtime"; +import { Conversation } from "./chat"; + +export type ModelWorker = { + init(callback: Comlink.ProxyOrClone): void; + generate(request: GenerateTextRequest, callback: Comlink.ProxyOrClone): void; +} + +export type InitCallback = InitProgressCallback; +export type GenerateTextCallback = (data: GenerateTextResponse) => void; + +export type GenerateTextRequest = { + conversation: Conversation, + stopTexts: string[], + maxTokens: number, + assistantRoleName: string, +} + +export type GenerateTextResponse = { + requestId: string, + step: number, + outputText: string, + stats: { + totalDecodingSeconds: number, + totalDecodedTokens: number, + totalEncodedTokens: number, + } + isFinished: boolean, +} + + diff --git a/packages/headless/src/worker/lib/tvm/compact.d.ts b/packages/headless/src/worker/lib/tvm/compact.d.ts new file mode 100644 index 0000000..866a6af --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/compact.d.ts @@ -0,0 +1,10 @@ +/** NodeJS and Web compact layer */ +/** + * Get performance measurement. + */ +export declare function getPerformance(): Performance; +/** + * Create a new websocket for a given URL + * @param url The url. + */ +export declare function createWebSocket(url: string): WebSocket; diff --git a/packages/headless/src/worker/lib/tvm/compact.ts b/packages/headless/src/worker/lib/tvm/compact.ts new file mode 100644 index 0000000..0a0ec51 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/compact.ts @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** NodeJS and Web compact layer */ + +/** + * Get performance measurement. + */ +export function getPerformance(): Performance { + return performance; +} + +/** + * Create a new websocket for a given URL + * @param url The url. + */ +export function createWebSocket(url: string): WebSocket { + return new WebSocket(url); +} diff --git a/packages/headless/src/worker/lib/tvm/ctypes.d.ts b/packages/headless/src/worker/lib/tvm/ctypes.d.ts new file mode 100644 index 0000000..fac9218 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/ctypes.d.ts @@ -0,0 +1,180 @@ +/** + * Types for C API. + */ +/** A pointer to points to the raw address space. */ +export type Pointer = number; +/** A pointer offset, need to add a base address to get a valid ptr. */ +export type PtrOffset = number; +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = (mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = (func: Pointer, argValues: Pointer, typeCode: Pointer, nargs: number, retValue: Pointer, retCode: Pointer) => number; +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = (ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = (name: Pointer, f: Pointer, override: number) => number; +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = (shape: Pointer, ndim: number, dtypeCode: number, dtypeBits: number, dtypeLanes: number, deviceType: number, deviceId: number, out: Pointer) => number; +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = (handle: Pointer, data: Pointer, nbytes: number) => number; +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = (handle: Pointer, data: Pointer, nbytes: number) => number; +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = (from: Pointer, to: Pointer, stream: Pointer) => number; +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = (deviceType: number, deviceId: number, stream: Pointer) => number; +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = (argValues: Pointer, argCodes: Pointer, nargs: number, outValue: Pointer, outCode: Pointer) => number; +/** + * int TVMObjectFree(TVMObjectHandle obj); + */ +export type FTVMObjectFree = (obj: Pointer) => number; +/** + * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + */ +export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number; +/** + * int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + */ +export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number; +/** + * int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + */ +export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number; +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = (args: Pointer, typeCodes: Pointer, nargs: number, ret: Pointer, resourceHandle: Pointer) => number; +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = (resource: Pointer, out: Pointer) => number; +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; +/** + * Size of common data types. + */ +export declare const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = 4, + DLDevice = 8 +} +/** + * Argument Type code in TVM FFI. + */ +export declare const enum ArgTypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + DLDevice = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} diff --git a/packages/headless/src/worker/lib/tvm/ctypes.ts b/packages/headless/src/worker/lib/tvm/ctypes.ts new file mode 100644 index 0000000..282679f --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/ctypes.ts @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Types for C API. + */ + +/** A pointer to points to the raw address space. */ +export type Pointer = number; + +/** A pointer offset, need to add a base address to get a valid ptr. */ +export type PtrOffset = number; + +// -- TVM runtime C API -- +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; + +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = ( + mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; + +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; + +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; + +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = ( + func: Pointer, argValues: Pointer, typeCode: Pointer, + nargs: number, retValue: Pointer, retCode: Pointer) => number; + +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = ( + ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; + +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; + +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; + +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = ( + name: Pointer, f: Pointer, override: number) => number; + +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; + +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = ( + shape: Pointer, ndim: number, + dtypeCode: number, dtypeBits: number, + dtypeLanes: number, deviceType: number, deviceId: number, + out: Pointer) => number; + +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; + +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = ( + from: Pointer, to: Pointer, stream: Pointer) => number; + +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = ( + deviceType: number, deviceId: number, stream: Pointer) => number; + +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = ( + argValues: Pointer, argCodes: Pointer, nargs: number, + outValue: Pointer, outCode: Pointer) => number; + + +/** + * int TVMObjectFree(TVMObjectHandle obj); + */ + export type FTVMObjectFree = (obj: Pointer) => number; + +/** + * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + */ +export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number; + +/** + * int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + */ +export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number; + +/** + * int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + */ +export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number; + +// -- TVM Wasm Auxiliary C API -- + +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; + +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; + +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = ( + args: Pointer, typeCodes: Pointer, nargs: number, + ret: Pointer, resourceHandle: Pointer) => number; + +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = ( + resource: Pointer, out: Pointer) => number; + +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; + +/** + * Size of common data types. + */ +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = I32, + DLDevice = I32 + I32, +} + +/** + * Argument Type code in TVM FFI. + */ +export const enum ArgTypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + DLDevice = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} diff --git a/packages/headless/src/worker/lib/tvm/environment.d.ts b/packages/headless/src/worker/lib/tvm/environment.d.ts new file mode 100644 index 0000000..5e935f7 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/environment.d.ts @@ -0,0 +1,26 @@ +import { LibraryProvider } from "./types"; +import * as ctypes from "./ctypes"; +/** + * Environment to impelement most of the JS library functions. + */ +export declare class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array; + private libProvider?; + constructor(importObject?: Record, logger?: (msg: string) => void); + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void; + private environment; +} diff --git a/packages/headless/src/worker/lib/tvm/environment.ts b/packages/headless/src/worker/lib/tvm/environment.ts new file mode 100644 index 0000000..24126c0 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/environment.ts @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Runtime environment that provide js libaries calls. + */ +import { Pointer } from "./ctypes"; +import { LibraryProvider } from "./types"; +import { assert } from "./support"; +import * as ctypes from "./ctypes"; + +/** + * Detect library provider from the importObject. + * + * @param importObject The import object. + */ +function detectLibraryProvider( + importObject: Record +): LibraryProvider | undefined { + if ( + importObject["wasmLibraryProvider"] && + importObject["wasmLibraryProvider"]["start"] && + importObject["wasmLibraryProvider"]["imports"] !== undefined + ) { + const item = importObject as { wasmLibraryProvider: LibraryProvider }; + // create provider so that we capture imports in the provider. + return { + imports: item.wasmLibraryProvider.imports, + start: (inst: WebAssembly.Instance): void => { + item.wasmLibraryProvider.start(inst); + }, + }; + } else if (importObject["imports"] && importObject["start"] !== undefined) { + return importObject as LibraryProvider; + } else if (importObject["wasiImport"] && importObject["start"] !== undefined) { + // WASI + return { + imports: { + "wasi_snapshot_preview1": importObject["wasiImport"], + }, + start: (inst: WebAssembly.Instance): void => { + importObject["start"](inst); + } + }; + } else { + return undefined; + } +} + +/** + * Environment to impelement most of the JS library functions. + */ +export class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array = [ + undefined, + ]; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array = []; + + private libProvider?: LibraryProvider; + + constructor( + importObject: Record = {}, + logger: (msg: string) => void = console.log + ) { + this.logger = logger; + this.libProvider = detectLibraryProvider(importObject); + // get imports from the provider + if (this.libProvider !== undefined) { + this.imports = this.libProvider.imports; + } else { + this.imports = importObject; + } + // update with more functions + this.imports.env = this.environment(this.imports.env); + } + + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void { + if (this.libProvider !== undefined) { + this.libProvider.start(inst); + } + } + + private environment(initEnv: Record): Record { + // default env can be be overriden by libraries. + const defaultEnv = { + "__cxa_thread_atexit": (): void => {}, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + "emscripten_notify_memory_growth": (index: number): void => {} + }; + const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + args: Pointer, + typeCodes: Pointer, + nargs: number, + ret: Pointer, + resourceHandle: Pointer + ): number => { + const cfunc = this.packedCFuncTable[resourceHandle]; + assert(cfunc !== undefined); + return cfunc(args, typeCodes, nargs, ret, resourceHandle); + }; + + const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( + resourceHandle: Pointer + ): void => { + this.packedCFuncTable[resourceHandle] = undefined; + this.packedCFuncTableFreeId.push(resourceHandle); + }; + + const newEnv = { + TVMWasmPackedCFunc: wasmPackedCFunc, + TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "__console_log": (msg: string): void => { + this.logger(msg); + } + }; + return Object.assign(defaultEnv, initEnv, newEnv); + } +} diff --git a/packages/headless/src/worker/lib/tvm/index.d.ts b/packages/headless/src/worker/lib/tvm/index.d.ts new file mode 100644 index 0000000..2196fda --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/index.d.ts @@ -0,0 +1,6 @@ +export { RPCServer } from "./rpc_server"; +export { DLDataType, DLDevice, Instance, Module, NDArray, Scalar, TVMArray, instantiate } from "./runtime"; +export type { PackedFunc } from "./runtime"; +export { assert, wasmPath } from "./support"; +export type { Disposable, LibraryProvider } from "./types"; +export { detectGPUDevice } from "./webgpu"; diff --git a/packages/headless/src/worker/lib/tvm/index.ts b/packages/headless/src/worker/lib/tvm/index.ts new file mode 100644 index 0000000..4772e1b --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/index.ts @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { RPCServer } from "./rpc_server"; +export { DLDataType, DLDevice, Instance, Module, NDArray, Scalar, TVMArray, instantiate } from "./runtime"; +export type { PackedFunc } from "./runtime"; +export { assert, wasmPath } from "./support"; +export type { Disposable, LibraryProvider } from "./types"; +export { detectGPUDevice } from "./webgpu"; + diff --git a/packages/headless/src/worker/lib/tvm/memory.d.ts b/packages/headless/src/worker/lib/tvm/memory.d.ts new file mode 100644 index 0000000..9c16d6b --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/memory.d.ts @@ -0,0 +1,144 @@ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset } from "./ctypes"; +import { Disposable } from "./types"; +import * as ctypes from "./ctypes"; +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export declare class Memory { + memory: WebAssembly.Memory; + wasm32: boolean; + private buffer; + private viewU8; + private viewU16; + private viewI32; + private viewU32; + private viewF32; + private viewF64; + constructor(memory: WebAssembly.Memory); + loadU8(ptr: Pointer): number; + loadU16(ptr: Pointer): number; + loadU32(ptr: Pointer): number; + loadI32(ptr: Pointer): number; + loadI64(ptr: Pointer): number; + loadF32(ptr: Pointer): number; + loadF64(ptr: Pointer): number; + loadPointer(ptr: Pointer): Pointer; + loadUSize(ptr: Pointer): Pointer; + sizeofPtr(): number; + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array; + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array; + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string; + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void; + /** + * Update memory view after the memory growth. + */ + private updateViews; +} +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export declare class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. */ + tempArgs: Array; + private memory; + private cAllocSpace; + private cFreeSpace; + private buffer; + private viewU8; + private viewI32; + private viewU32; + private viewF64; + private stackTop; + private basePtr; + private addressToSetTargetValue; + constructor(memory: Memory, allocSpace: ctypes.FTVMWasmAllocSpace, freeSpace: ctypes.FTVMWasmFreeSpace); + dispose(): void; + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void; + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes?: number): void; + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset; + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset; + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. + */ + ptrFromOffset(offset: PtrOffset): Pointer; + storePtr(offset: PtrOffset, value: Pointer): void; + storeUSize(offset: PtrOffset, value: Pointer): void; + storeI32(offset: PtrOffset, value: number): void; + storeU32(offset: PtrOffset, value: number): void; + storeI64(offset: PtrOffset, value: number): void; + storeF64(offset: PtrOffset, value: number): void; + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void; + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void; + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void; + /** + * Update internal cache views. + */ + private updateViews; +} diff --git a/packages/headless/src/worker/lib/tvm/memory.ts b/packages/headless/src/worker/lib/tvm/memory.ts new file mode 100644 index 0000000..ac737b7 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/memory.ts @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset, SizeOf } from "./ctypes"; +import { Disposable } from "./types"; +import { assert, StringToUint8Array } from "./support"; + +import * as ctypes from "./ctypes"; + +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export class Memory { + memory: WebAssembly.Memory; + wasm32 = true; + private buffer: ArrayBuffer | SharedArrayBuffer; + private viewU8: Uint8Array; + private viewU16: Uint16Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF32: Float32Array; + private viewF64: Float64Array; + + constructor(memory: WebAssembly.Memory) { + this.memory = memory; + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } + + loadU8(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU8[ptr >> 0]; + } + + loadU16(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU16[ptr >> 1]; + } + + loadU32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU32[ptr >> 2]; + } + + loadI32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewI32[ptr >> 2]; + } + + loadI64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const base = ptr >> 2; + // assumes little endian, for now truncate high. + return this.viewI32[base]; + } + + loadF32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF32[ptr >> 2]; + } + + loadF64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF64[ptr >> 3]; + } + + loadPointer(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + loadUSize(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + sizeofPtr(): number { + return this.wasm32 ? SizeOf.I32 : SizeOf.I64; + } + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const result = new Uint8Array(numBytes); + result.set(this.viewU8.slice(ptr, ptr + numBytes)); + return result; + } + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array { + const data = this.loadPointer(ptr); + const length = this.loadUSize(ptr + this.sizeofPtr()); + return this.loadRawBytes(data, length); + } + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + // NOTE: the views are still valid for read. + const ret = []; + let ch = 1; + while (ch != 0) { + ch = this.viewU8[ptr]; + if (ch != 0) { + ret.push(String.fromCharCode(ch)); + } + ++ptr; + } + return ret.join(""); + } + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + this.viewU8.set(bytes, ptr); + } + + /** + * Update memory view after the memory growth. + */ + private updateViews(): void { + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} + +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. */ + tempArgs: Array = []; + + private memory: Memory; + private cAllocSpace: ctypes.FTVMWasmAllocSpace; + private cFreeSpace: ctypes.FTVMWasmFreeSpace; + + private buffer: ArrayBuffer; + private viewU8: Uint8Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF64: Float64Array; + + private stackTop: PtrOffset = 0; + private basePtr: Pointer = 0; + + private addressToSetTargetValue: Array<[PtrOffset, PtrOffset]> = []; + + constructor( + memory: Memory, + allocSpace: ctypes.FTVMWasmAllocSpace, + freeSpace: ctypes.FTVMWasmFreeSpace + ) { + const initCallStackSize = 128; + this.memory = memory; + this.cAllocSpace = allocSpace; + this.cFreeSpace = freeSpace; + this.buffer = new ArrayBuffer(initCallStackSize); + this.basePtr = this.cAllocSpace(initCallStackSize); + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + this.updateViews(); + } + + dispose(): void { + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + this.basePtr = 0; + } + } + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void { + this.stackTop = 0; + assert(this.addressToSetTargetValue.length == 0); + while (this.tempArgs.length != 0) { + (this.tempArgs.pop() as Disposable).dispose(); + } + } + + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes: number = this.stackTop): void { + // commit all pointer values. + while (this.addressToSetTargetValue.length != 0) { + const [targetOffset, valueOffset] = this.addressToSetTargetValue.pop() as [ + number, + number + ]; + this.storePtr(targetOffset, this.ptrFromOffset(valueOffset)); + } + this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes)); + } + + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset { + // always aligns to 64bit + nbytes = ((nbytes + 7) >> 3) << 3; + + if (this.stackTop + nbytes > this.buffer.byteLength) { + const newSize = Math.max( + this.buffer.byteLength * 2, + this.stackTop + nbytes + ); + const oldU8 = this.viewU8; + this.buffer = new ArrayBuffer(newSize); + this.updateViews(); + this.viewU8.set(oldU8); + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + } + this.basePtr = this.cAllocSpace(newSize); + } + const retOffset = this.stackTop; + this.stackTop += nbytes; + return retOffset; + } + + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset { + return this.allocRawBytes(this.memory.sizeofPtr() * count); + } + + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. + */ + ptrFromOffset(offset: PtrOffset): Pointer { + return this.basePtr + offset; + } + + // Store APIs + storePtr(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeUSize(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeI32(offset: PtrOffset, value: number): void { + this.viewI32[offset >> 2] = value; + } + + storeU32(offset: PtrOffset, value: number): void { + this.viewU32[offset >> 2] = value; + } + + storeI64(offset: PtrOffset, value: number): void { + // For now, just store as 32bit + // NOTE: wasm always uses little endian. + const low = value & 0xffffffff; + const base = offset >> 2; + this.viewI32[base] = low; + this.viewI32[base + 1] = 0; + } + + storeF64(offset: PtrOffset, value: number): void { + this.viewF64[offset >> 3] = value; + } + + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void { + this.viewU8.set(bytes, offset); + } + + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void { + const strOffset = this.allocRawBytes(data.length + 1); + this.storeRawBytes(strOffset, StringToUint8Array(data)); + this.addressToSetTargetValue.push([offset, strOffset]); + } + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void { + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(data.length); + this.storeRawBytes(dataOffset, data); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + + this.addressToSetTargetValue.push([offset, headerOffset]); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + } + + /** + * Update internal cache views. + */ + private updateViews(): void { + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} diff --git a/packages/headless/src/worker/lib/tvm/rpc_server.d.ts b/packages/headless/src/worker/lib/tvm/rpc_server.d.ts new file mode 100644 index 0000000..af9928a --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/rpc_server.d.ts @@ -0,0 +1,54 @@ +import * as runtime from "./runtime"; +declare enum RPCServerState { + InitHeader = 0, + InitHeaderKey = 1, + InitServer = 2, + WaitForCallback = 3, + ReceivePacketHeader = 4, + ReceivePacketBody = 5 +} +/** + * A websocket based RPC + */ +export declare class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState; + logger: (msg: string) => void; + getImports: () => Record; + private ndarrayCacheUrl; + private ndarrayCacheDevice; + private initProgressCallback?; + private asyncOnServerLoad?; + private pendingSend; + private name; + private inst?; + private globalObjects; + private serverRecvData?; + private currPacketHeader?; + private currPacketLength; + private remoteKeyLength; + private pendingBytes; + private buffredBytes; + private messageQueue; + constructor(url: string, key: string, getImports: () => Record, logger?: (msg: string) => void, ndarrayCacheUrl?: string, ndarrayCacheDevice?: string, initProgressCallback?: runtime.InitProgressCallback | undefined, asyncOnServerLoad?: ((inst: runtime.Instance) => Promise) | undefined); + private onClose; + private onOpen; + /** Handler for raw message. */ + private onMessage; + /** Process ready events. */ + private processEvents; + /** State machine to handle each request */ + private onDataReady; + private onPacketReady; + /** Event handler during server initialization. */ + private onInitServer; + private log; + private handleInitHeader; + private handleInitHeaderKey; + private checkLittleEndian; + private requestBytes; + private readFromBuffer; +} +export {}; diff --git a/packages/headless/src/worker/lib/tvm/rpc_server.ts b/packages/headless/src/worker/lib/tvm/rpc_server.ts new file mode 100644 index 0000000..f920c2c --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/rpc_server.ts @@ -0,0 +1,457 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import * as compact from "./compact"; +import { ArgTypeCode, SizeOf } from "./ctypes"; +import * as runtime from "./runtime"; +import { StringToUint8Array, Uint8ArrayToString, assert } from "./support"; +import { Disposable } from "./types"; +import { GPUDeviceDetectOutput, detectGPUDevice } from "./webgpu"; + +enum RPCServerState { + InitHeader, + InitHeaderKey, + InitServer, + WaitForCallback, + ReceivePacketHeader, + ReceivePacketBody, +} + +/** RPC magic header */ +const RPC_MAGIC = 0xff271; + +/** + * An utility class to read from binary bytes. + */ +class ByteStreamReader { + offset = 0; + bytes: Uint8Array; + + constructor(bytes: Uint8Array) { + this.bytes = bytes; + } + + readU32(): number { + const i = this.offset; + const b = this.bytes; + const val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24); + this.offset += 4; + return val; + } + + readU64(): number { + const val = this.readU32(); + this.offset += 4; + return val; + } + + readByteArray(): Uint8Array { + const len = this.readU64(); + assert(this.offset + len <= this.bytes.byteLength); + const ret = new Uint8Array(len); + ret.set(this.bytes.slice(this.offset, this.offset + len)); + this.offset += len; + return ret; + } +} + +/** + * A websocket based RPC + */ +export class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState = RPCServerState.InitHeader; + logger: (msg: string) => void; + getImports: () => Record; + private ndarrayCacheUrl: string; + private ndarrayCacheDevice: string; + private initProgressCallback?: runtime.InitProgressCallback; + private asyncOnServerLoad?: (inst: runtime.Instance) => Promise; + private pendingSend: Promise = Promise.resolve(); + private name: string; + private inst?: runtime.Instance = undefined; + private globalObjects: Array = []; + private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; + private currPacketHeader?: Uint8Array; + private currPacketLength = 0; + private remoteKeyLength = 0; + private pendingBytes = 0; + private buffredBytes = 0; + private messageQueue: Array = []; + + constructor( + url: string, + key: string, + getImports: () => Record, + logger: (msg: string) => void = console.log, + ndarrayCacheUrl: string = "", + ndarrayCacheDevice: string = "cpu", + initProgressCallback: runtime.InitProgressCallback | undefined = undefined, + asyncOnServerLoad: ((inst: runtime.Instance) => Promise) | undefined = undefined, + ) { + this.url = url; + this.key = key; + this.name = "WebSocketRPCServer[" + this.key + "]: "; + this.getImports = getImports; + this.logger = logger; + this.ndarrayCacheUrl = ndarrayCacheUrl; + this.ndarrayCacheDevice = ndarrayCacheDevice; + this.initProgressCallback = initProgressCallback; + this.asyncOnServerLoad = asyncOnServerLoad; + this.checkLittleEndian(); + this.socket = compact.createWebSocket(url); + this.socket.binaryType = "arraybuffer"; + + this.socket.addEventListener("open", (event: Event) => { + return this.onOpen(event); + }); + this.socket.addEventListener("message", (event: MessageEvent) => { + return this.onMessage(event); + }); + this.socket.addEventListener("close", (event: CloseEvent) => { + return this.onClose(event); + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onClose(_event: CloseEvent): void { + if (this.inst !== undefined) { + this.globalObjects.forEach(obj => { + obj.dispose(); + }); + this.log(this.inst.runtimeStatsText()); + this.inst.dispose(); + } + if (this.state == RPCServerState.ReceivePacketHeader) { + this.log("Closing the server in clean state"); + this.log("Automatic reconnecting.."); + new RPCServer( + this.url, this.key, this.getImports, this.logger, + this.ndarrayCacheUrl, this.ndarrayCacheDevice, + this.initProgressCallback, this.asyncOnServerLoad); + } else { + this.log("Closing the server, final state=" + this.state); + } + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onOpen(_event: Event): void { + // Send the headers + let bkey = StringToUint8Array("server:" + this.key); + bkey = bkey.slice(0, bkey.length - 1); + const intbuf = new Int32Array(1); + intbuf[0] = RPC_MAGIC; + this.socket.send(intbuf); + intbuf[0] = bkey.length; + this.socket.send(intbuf); + this.socket.send(bkey); + this.log("connected..."); + // request bytes: magic + keylen + this.requestBytes(SizeOf.I32 + SizeOf.I32); + this.state = RPCServerState.InitHeader; + } + + /** Handler for raw message. */ + private onMessage(event: MessageEvent): void { + const buffer = event.data; + this.buffredBytes += buffer.byteLength; + this.messageQueue.push(new Uint8Array(buffer)); + this.processEvents(); + } + /** Process ready events. */ + private processEvents(): void { + while (this.buffredBytes >= this.pendingBytes && this.pendingBytes != 0) { + this.onDataReady(); + } + } + /** State machine to handle each request */ + private onDataReady(): void { + switch (this.state) { + case RPCServerState.InitHeader: { + this.handleInitHeader(); + break; + } + case RPCServerState.InitHeaderKey: { + this.handleInitHeaderKey(); + break; + } + case RPCServerState.ReceivePacketHeader: { + this.currPacketHeader = this.readFromBuffer(SizeOf.I64); + const reader = new ByteStreamReader(this.currPacketHeader); + this.currPacketLength = reader.readU64(); + assert(this.pendingBytes == 0); + this.requestBytes(this.currPacketLength); + this.state = RPCServerState.ReceivePacketBody; + break; + } + case RPCServerState.ReceivePacketBody: { + const body = this.readFromBuffer(this.currPacketLength); + assert(this.pendingBytes == 0); + assert(this.currPacketHeader !== undefined); + this.onPacketReady(this.currPacketHeader, body); + break; + } + case RPCServerState.WaitForCallback: { + assert(this.pendingBytes == 0); + break; + } + default: { + throw new Error("Cannot handle state " + this.state); + } + } + } + + private onPacketReady(header: Uint8Array, body: Uint8Array): void { + if (this.inst === undefined) { + // initialize server. + const reader = new ByteStreamReader(body); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const code = reader.readU32(); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const ver = Uint8ArrayToString(reader.readByteArray()); + const nargs = reader.readU32(); + const tcodes = []; + const args = []; + for (let i = 0; i < nargs; ++i) { + tcodes.push(reader.readU32()); + } + + for (let i = 0; i < nargs; ++i) { + const tcode = tcodes[i]; + if (tcode == ArgTypeCode.TVMStr) { + const str = Uint8ArrayToString(reader.readByteArray()); + args.push(str); + } else if (tcode == ArgTypeCode.TVMBytes) { + args.push(reader.readByteArray()); + } else { + throw new Error("cannot support type code " + tcode); + } + } + this.onInitServer(args, header, body); + } else { + assert(this.serverRecvData !== undefined); + this.serverRecvData(header, body); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + } + + /** Event handler during server initialization. */ + private onInitServer( + args: Array, + header: Uint8Array, + body: Uint8Array + ): void { + // start the server + assert(args[0] == "rpc.WasmSession"); + assert(this.pendingBytes == 0); + + const asyncInitServer = async (): Promise => { + assert(args[1] instanceof Uint8Array); + const inst = await runtime.instantiate( + args[1].buffer, + this.getImports(), + this.logger + ); + + try { + const output: GPUDeviceDetectOutput | undefined = await detectGPUDevice(); + if (output !== undefined) { + const label = "WebGPU: " + output.adapterInfo.description; + this.log("Initialize GPU device: " + label); + inst.initWebGPU(output.device); + } else { + this.log("Cannot find WebGPU device in the env"); + } + } catch (err: any) { + this.log("Cannnot initialize WebGPU, " + err.toString()); + } + + this.inst = inst; + // begin scope to allow handling of objects + this.inst.beginScope(); + if (this.initProgressCallback !== undefined) { + this.inst.registerInitProgressCallback(this.initProgressCallback); + } + + if (this.ndarrayCacheUrl.length != 0) { + if (this.ndarrayCacheDevice == "cpu") { + await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.cpu()); + } else { + assert(this.ndarrayCacheDevice == "webgpu"); + await this.inst.fetchNDArrayCache(this.ndarrayCacheUrl, this.inst.webgpu()); + } + } + + assert(this.inst !== undefined); + if (this.asyncOnServerLoad !== undefined) { + await this.asyncOnServerLoad(this.inst); + } + const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); + const messageHandler = fcreate( + (cbytes: Uint8Array): runtime.Scalar => { + assert(this.inst !== undefined); + if (this.socket.readyState == 1) { + // WebSocket will automatically close the socket + // if we burst send data that exceeds its internal buffer + // wait a bit before we send next one. + const sendDataWithCongestionControl = async (): Promise => { + const packetSize = 4 << 10; + const maxBufferAmount = 4 * packetSize; + const waitTimeMs = 20; + for ( + let offset = 0; + offset < cbytes.length; + offset += packetSize + ) { + const end = Math.min(offset + packetSize, cbytes.length); + while (this.socket.bufferedAmount >= maxBufferAmount) { + await new Promise((r) => setTimeout(r, waitTimeMs)); + } + this.socket.send(cbytes.slice(offset, end)); + } + }; + // Chain up the pending send so that the async send is always in-order. + this.pendingSend = this.pendingSend.then( + sendDataWithCongestionControl + ); + // Directly return since the data are "sent" from the caller's pov. + return this.inst.scalar(cbytes.length, "int32"); + } else { + return this.inst.scalar(0, "int32"); + } + }, + this.name, + this.key + ); + // message handler should persist across RPC runs + this.globalObjects.push( + this.inst.detachFromCurrentScope(messageHandler) + ); + const writeFlag = this.inst.scalar(3, "int32"); + + this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => { + if (messageHandler(header, writeFlag) == 0) { + this.socket.close(); + } + if (messageHandler(body, writeFlag) == 0) { + this.socket.close(); + } + }; + + // Forward the same init sequence to the wasm RPC. + // The RPC will look for "rpc.wasmSession" + // and we will redirect it to the correct local session. + // register the callback to redirect the session to local. + const flocal = this.inst.getGlobalFunc("wasm.LocalSession"); + const localSession = flocal(); + assert(localSession instanceof runtime.Module); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + this.inst.registerFunc( + "rpc.WasmSession", + // eslint-disable-next-line @typescript-eslint/no-unused-vars + (_args: unknown): runtime.Module => { + return localSession; + } + ); + messageHandler(header, writeFlag); + messageHandler(body, writeFlag); + + this.log("Finish initializing the Wasm Server.."); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + // call process events in case there are bufferred data. + this.processEvents(); + // recycle all values. + this.inst.endScope(); + }; + + this.state = RPCServerState.WaitForCallback; + asyncInitServer(); + } + + private log(msg: string): void { + this.logger(this.name + msg); + } + + private handleInitHeader(): void { + const reader = new ByteStreamReader(this.readFromBuffer(SizeOf.I32 * 2)); + const magic = reader.readU32(); + if (magic == RPC_MAGIC + 1) { + throw new Error("key: " + this.key + " has already been used in proxy"); + } else if (magic == RPC_MAGIC + 2) { + throw new Error("RPCProxy do not have matching client key " + this.key); + } + assert(magic == RPC_MAGIC, this.url + " is not an RPC Proxy"); + this.remoteKeyLength = reader.readU32(); + assert(this.pendingBytes == 0); + this.requestBytes(this.remoteKeyLength); + this.state = RPCServerState.InitHeaderKey; + } + + private handleInitHeaderKey(): void { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const remoteKey = Uint8ArrayToString( + this.readFromBuffer(this.remoteKeyLength) + ); + assert(this.pendingBytes == 0); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + + private checkLittleEndian(): void { + const a = new ArrayBuffer(4); + const b = new Uint8Array(a); + const c = new Uint32Array(a); + b[0] = 0x11; + b[1] = 0x22; + b[2] = 0x33; + b[3] = 0x44; + assert(c[0] === 0x44332211, "RPCServer little endian to work"); + } + + private requestBytes(nbytes: number): void { + this.pendingBytes += nbytes; + } + + private readFromBuffer(nbytes: number): Uint8Array { + const ret = new Uint8Array(nbytes); + let ptr = 0; + while (ptr < nbytes) { + assert(this.messageQueue.length != 0); + const nleft = nbytes - ptr; + if (this.messageQueue[0].byteLength <= nleft) { + const buffer = this.messageQueue.shift() as Uint8Array; + ret.set(buffer, ptr); + ptr += buffer.byteLength; + } else { + const buffer = this.messageQueue[0]; + ret.set(buffer.slice(0, nleft), ptr); + this.messageQueue[0] = buffer.slice(nleft, buffer.byteLength); + ptr += nleft; + } + } + this.buffredBytes -= nbytes; + this.pendingBytes -= nbytes; + return ret; + } +} diff --git a/packages/headless/src/worker/lib/tvm/runtime.d.ts b/packages/headless/src/worker/lib/tvm/runtime.d.ts new file mode 100644 index 0000000..8a73838 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/runtime.d.ts @@ -0,0 +1,700 @@ +/// +/** + * TVM JS Wasm Runtime library. + */ +import { Pointer, PtrOffset } from "./ctypes"; +import { Environment } from "./environment"; +import { CachedCallStack, Memory } from "./memory"; +import { Disposable } from "./types"; +import { WebGPUContext } from "./webgpu"; +/** + * Type for PackedFunc inthe TVMRuntime. + */ +export type PackedFunc = ((...args: any) => any) & Disposable & { + _tvmPackedCell: PackedFuncCell; +}; +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +declare class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + webGPUContext?: WebGPUContext; + private wasmInstance; + private recycledCallStacks; + constructor(wasmInstance: WebAssembly.Instance, imports: Record); + dispose(): void; + sizeofPtr(): number; + checkCall(code: number): void; + getOrAllocCallStack(): CachedCallStack; + recycleCallStack(callstack: CachedCallStack): void; + private validateInstance; + private checkExports; + private detectWasmMemory; +} +/** + * @internal + * Manages extra runtime context for the runtime. + */ +declare class RuntimeContext implements Disposable { + arrayGetItem: PackedFunc; + arrayGetSize: PackedFunc; + arrayMake: PackedFunc; + getSysLib: PackedFunc; + arrayCacheGet: PackedFunc; + arrayCacheUpdate: PackedFunc; + arrayCacheRemove: PackedFunc; + arrayCacheClear: PackedFunc; + arrayDecodeStorage: PackedFunc; + paramModuleFromCache: PackedFunc; + makeShapeTuple: PackedFunc; + ndarrayCreateView: PackedFunc; + sampleTopPFromLogits: PackedFunc; + private autoDisposeScope; + constructor(getGlobalFunc: (name: string) => PackedFunc); + dispose(): void; + beginScope(): void; + endScope(): void; + /** + * Track object for dispose in current scope. + * + * @param obj The object to be tracked. + * @returns the same object. + * @note This function only needs to be called for raw system C API values. + * The return value of PackedFunc will be automatically tracked. + */ + attachToCurrentScope(obj: T): T; + moveToParentScope(obj: T): T; + detachFromCurrentScope(obj: T): T; +} +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export declare class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + constructor(value: number, dtype: string); +} +/** + * Cell holds the PackedFunc object. + */ +declare class PackedFuncCell implements Disposable { + private handle; + private lib; + constructor(handle: Pointer, lib: FFILibrary); + dispose(): void; + getHandle(requireNotNull?: boolean): Pointer; +} +/** + * Represent a runtime context where a NDArray can reside. + */ +export declare class DLDevice { + /** The device type code of the device. */ + deviceType: number; + /** The device index. */ + deviceId: number; + private lib; + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary); + /** + * Synchronize the device + */ + sync(): Promise; + toString(): string; +} +/** + * The data type code in DLDataType + */ +export declare const enum DLDataTypeCode { + Int = 0, + UInt = 1, + Float = 2, + OpaqueHandle = 3 +} +/** + * Runtime data type of NDArray. + */ +export declare class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + constructor(code: number, bits: number, lanes: number); + toString(): string; + numStorageBytes(): number; +} +/** + * n-dimnesional array. + */ +export declare class NDArray implements Disposable { + /** Internal array handle. */ + private handle; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Device of the array. */ + device: DLDevice; + /** Whether it is a temporary view that can become invalid after the call. */ + isView: boolean; + private byteOffset; + private dltensor; + private dataPtr; + private lib; + private ctx; + private dlDataType; + constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx: RuntimeContext); + /** + * Create a view of the array. + * @param shape The shape of the view. + * @returns The new sliced ndarray. + */ + view(shape: Array): NDArray; + /** + * Get handle of ndarray, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull?: boolean): Pointer; + /** + * Get dataPtr of NDarray + * + * @returns The handle. + */ + getDataPtr(): Pointer; + dispose(): void; + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array | Float32Array): this; + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. + * @returns this + */ + copyFromRawBytes(data: Uint8Array): this; + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + toRawBytes(): Uint8Array; + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array; + private getDLTensorFromArrayHandle; +} +/** + * Runtime Module. + */ +export declare class Module implements Disposable { + private handle; + private lib; + private makePackedFunc; + constructor(handle: Pointer, lib: FFILibrary, makePackedFunc: (ptr: Pointer) => PackedFunc); + dispose(): void; + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull?: boolean): Pointer; + /** + * Get a function in the module. + * @param name The name of the function. + * @param queryImports Whether to also query imports + * @returns The result function. + */ + getFunction(name: string, queryImports?: boolean): PackedFunc; + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + importModule(mod: Module): void; +} +/** + * Generic object base + */ +export declare class TVMObject implements Disposable { + private handle; + private lib; + protected ctx: RuntimeContext; + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext); + dispose(): void; + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull?: boolean): Pointer; + /** get the type index of the object */ + typeIndex(): number; + /** get the type key of the object */ + typeKey(): string; +} +/** Objectconstructor */ +type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) => TVMObject; +/** All possible object types. */ +type TVMObjectBase = TVMObject | NDArray | Module | PackedFunc; +/** Runtime array object. */ +export declare class TVMArray extends TVMObject { + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext); + /** + * @returns the size of the array. + */ + size(): number; + /** + * Get index-th element of the array + * @param index the array index. + * @returns The element. + */ + get(index: number): TVMObjectBase; +} +export declare const enum VMAllocatorKind { + NAIVE_ALLOCATOR = 1, + POOLED_ALLOCATOR = 2 +} +/** + * VirtualMachine Executor. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +export declare class VirtualMachine implements Disposable { + private mod; + /** + * Constructor + * @param mod The underlying module, need to be detached. + * @param device The main device ro run VM on. + */ + constructor(mod: Module, device: DLDevice); + dispose(): void; + /** + * Get a function in the VM module. + * @param name The name of the function. + * @returns The result function. + */ + getFunction(name: string): PackedFunc; + /** + * Get the internal module. + */ + getInternalModule(): Module; +} +export interface NDArrayCacheEntry { + name: string; + shape: Array; + dtype: string; + format: "f32-to-bf16" | "raw"; + byteOffset: number; + nbytes: number; +} +export interface NDArrayShardEntry { + dataPath: string; + format: "raw-shard"; + nbytes: number; + records: Array; +} +export interface InitProgressReport { + type: 'init'; + progress: number; + timeElapsed: number; + currentChunk: number; + totalChunks: number; + fetchedBytes: number; + totalBytes: number; +} +export type InitProgressCallback = (report: InitProgressReport) => void; +/** + * TVM runtime instance. + * + * All objects(NDArray, Module, PackedFunc) returned by TVM runtim function call + * and PackedFunc instance are tracked through a scope mechanism that will get + * auto-released when we call EndScope. + * + * This is necessarily to be able to release the underlying WASM and WebGPU memory that + * are not tracked through JS native garbage collection mechanism. + * + * This does mean that we have to get familar with the following functions: + * - {@link beginScope} + * - {@link endScope} + * - {@link withNewScope} + * - {@link attachToCurrentScope} + * - {@link detachFromCurrentScope} + */ +export declare class Instance implements Disposable { + memory: Memory; + exports: Record; + cacheMetadata: Record; + private lib; + private env; + private objFactory; + private ctx; + private initProgressCallback; + /** + * Internal function(registered by the runtime) + */ + private wasmCreateLibraryModule?; + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + constructor(wasmModule: WebAssembly.Module, importObject?: Record, wasmInstance?: WebAssembly.Instance, env?: Environment); + /** + * Benchmark stable execution of the run function. + * + * @params run The run function + * @params dev The device to sync during each run. + * @number The number of times to compute the average. + * @repeat The number of times to repeat the run. + */ + benchmark(run: () => void, dev: DLDevice, number?: number, repeat?: number): Promise; + dispose(): void; + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string; + /** + * Begin a new scope for tracking object disposal. + */ + beginScope(): void; + /** + * End a scope and release all created TVM objects + * under the current scope. + * + * Exception: one can call {@link moveToParentScope} to move + * a value to parent scope. + */ + endScope(): void; + /** + * Perform action under a new scope. + * + * @param action The action function. + * @returns The result value. + * + * @note For action to return a valid value, + * we will need to call {@link moveToParentScope} + * for the objects that are created in the scope. + */ + withNewScope(action: () => T): T; + /** + * Attach a detached obj to the auto-release pool of the current scope. + * + * @param obj The input obj. + * @note Normally user do not need to call this function explicitly, as + * all library call return values are explicitly attached to + * the current scope. You only need to do so when you call + * {@link detachFromCurrentScope} to create a detached object. + */ + attachToCurrentScope(obj: T): T; + /** + * Move obj's attachment to the parent scope. + * + * This function is useful to make sure objects are still + * alive when exit the current scope. + * + * @param obj The object to be moved. + * @returns The input obj. + */ + moveToParentScope(obj: T): T; + /** + * Detach the object from the current scope + * so it won't be released via auto-release during endscope. + * + * User needs to either explicitly call obj.dispose(), or + * {@link attachToCurrentScope} to re-attach to the current scope. + * + * This function can be used to return values to the parent scope. + * @param obj The object. + */ + detachFromCurrentScope(obj: T): T; + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + systemLib(): Module; + /** + * List all the global function names registered in the runtime. + * @returns The name list. + */ + listGlobalFuncNames(): Array; + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerFunc(name: string, func: PackedFunc | Function, override?: boolean): void; + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @param autoAttachToScope Whether to track it via autoDispose + * @returns The result function. + */ + getGlobalFunc(name: string): PackedFunc; + private getGlobalFuncInternal; + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + isPackedFunc(func: unknown): boolean; + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + toPackedFunc(func: Function): PackedFunc; + private toPackedFuncInternal; + /** + * Setup a virtual machine module with given device. + * + * @param dev DLDevice the device. + * @returns The created virtual machime. + */ + createVirtualMachine(dev: DLDevice): VirtualMachine; + /** + * Register a call back for fetch progress. + * + * @param cb the fetch progress callback. + */ + registerInitProgressCallback(cb: InitProgressCallback): void; + /** + * Get parameters in the form of prefix_i + * + * @param prefix The parameter prefix. + * @param numParams Number of parameters. + * @returns + */ + getParamsFromCache(prefix: string, numParams: number): TVMObject; + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheGet(name: string): NDArray | undefined; + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheRemove(name: string): NDArray | undefined; + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheUpdate(name: string, arr: NDArray, override?: boolean): void; + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheClear(): void; + /** + * Fetch NDArray cache from url. + * + * @param ndarrayCacheUrl The cache url. + * @param device The device to be fetched to. + * @returns The meta data + */ + fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice): Promise; + /** + * Fetch list of NDArray into the NDArrayCache. + * + * @param ndarrayCacheUrl The cache url. + * @param list The list of array data. + * @param device The device to store the data to. + */ + private fetchNDArrayCacheInternal; + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. + */ + toDLDataType(dtype: string | DLDataType): DLDataType; + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. + */ + scalar(value: number, dtype: string): Scalar; + /** + * Create a new {@link DLDevice} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created device. + */ + device(deviceType: number | string, deviceId?: number): DLDevice; + /** + * Create a new cpu {@link DLDevice} + * @param deviceId The device index. + */ + cpu(deviceId?: number): DLDevice; + /** + * Create a new webgpu {@link DLDevice} + * @param deviceId The device index. + */ + webgpu(deviceId?: number): DLDevice; + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + empty(shape: Array | number, dtype?: string | DLDataType, dev?: DLDevice): NDArray; + /** + * Create am uniform {@link NDArray} with given shape. + * + * @param shape The shape of the array. + * @param low The low value. + * @param high The high value. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + uniform(shape: Array, low: number, high: number, dev: DLDevice): NDArray; + /** + * Sample index via top-p sampling. + * + * @param logits The input logits before normalization. + * @param temperature The temperature factor, will take argmax if temperature = 0.0 + * @param top_p The top_p + * @returns The sampled index. + */ + sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number; + /** + * Bind canvas to the current WebGPU context + * @param canvas The canvas. + */ + bindCanvas(canvas: HTMLCanvasElement): void; + /** + * Show image in canvas. + * + * @param dataRGBA Image array in height x width uint32 NDArray RGBA format on GPU. + */ + showImage(dataRGBA: NDArray): void; + /** + * Clear canvas + */ + clearCanvas(): void; + /** + * Create an tuple {@link TVMArray} input array. + * + * The input array can be passed to tvm runtime function + * and needs to b explicitly disposed. + * + * @param inputs The input array + * @returns The result array. + */ + makeTVMArray(inputs: Array): TVMArray; + /** + * Create a shape tuple to pass to runtime. + * @param shape The shape . + * @returns The created shape tuple. + */ + makeShapeTuple(shape: Array): TVMObject; + /** + * Get type index from type key. + * @param typeKey The type key. + * @returns The corresponding type index. + */ + typeKey2Index(typeKey: string): number; + /** + * Register an object constructor. + * @param typeKey The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerObjectConstructor(typeKey: string, func: FObjectConstructor, override?: boolean): void; + /** + * Register an asyncfunction to be global function in the server. + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note The async function will only be used for serving remote calls in the rpc. + */ + registerAsyncServerFunc(name: string, func: Function, override?: boolean): void; + /** + * Asynchrously load webgpu pipelines when possible. + * @param mod The input module. + */ + asyncLoadWebGPUPiplines(mod: Module): Promise; + /** + * Initialize webgpu in the runtime. + * @param device The given GPU device. + */ + initWebGPU(device: GPUDevice): void; + /** Register all object factory */ + private registerObjectFactoryFuncs; + /** Register global packed functions needed by the backend to the env. */ + private registerEnvGlobalPackedFuncs; + private createPackedFuncFromCFunc; + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. + */ + setPackedArguments(stack: CachedCallStack, args: Array, argsValue: PtrOffset, argsCode: PtrOffset): void; + private wrapJSFuncAsPackedCFunc; + private makePackedFunc; + /** + * Creaye return value of the packed func. The value us auto-tracked for dispose. + * @param rvaluePtr The location of rvalue + * @param tcode The type code. + * @param callbackArg Whether it is being used in callbackArg. + * @returns The JS value. + */ + private retValueToJS; +} +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + * + * @param bufferSource The source to be compiled. + * @param importObject The import objects. + * @param logger The system logger. + */ +export declare function instantiate(bufferSource: ArrayBuffer, importObject?: Record, logger?: (msg: string) => void): Promise; +export {}; diff --git a/packages/headless/src/worker/lib/tvm/runtime.ts b/packages/headless/src/worker/lib/tvm/runtime.ts new file mode 100644 index 0000000..af54b72 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/runtime.ts @@ -0,0 +1,2280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * TVM JS Wasm Runtime library. + */ +import { ArgTypeCode, Pointer, PtrOffset, SizeOf } from "./ctypes"; +import { Environment } from "./environment"; +import { CachedCallStack, Memory } from "./memory"; +import { StringToUint8Array, assert } from "./support"; +import { Disposable } from "./types"; +import { FunctionInfo, WebGPUContext } from "./webgpu"; + +import * as compact from "./compact"; +import * as ctypes from "./ctypes"; + +/** + * Type for PackedFunc inthe TVMRuntime. + */ +export type PackedFunc = ((...args: any) => any) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + webGPUContext?: WebGPUContext; + private wasmInstance: WebAssembly.Instance; + private recycledCallStacks: Array = []; + + constructor( + wasmInstance: WebAssembly.Instance, + imports: Record + ) { + this.wasmInstance = wasmInstance; + this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); + assert( + this.wasmInstance.exports !== undefined, + "Expect the library module contains exports" + ); + this.exports = this.wasmInstance.exports as Record; + this.wasm32 = this.memory.wasm32; + this.validateInstance(); + } + + dispose(): void { + while (this.recycledCallStacks.length != 0) { + (this.recycledCallStacks.pop() as Disposable).dispose(); + } + this.webGPUContext?.dispose(); + } + + sizeofPtr(): number { + return this.memory.sizeofPtr(); + } + + checkCall(code: number): void { + if (code != 0) { + const msgPtr = (this.exports + .TVMGetLastError as ctypes.FTVMGetLastError)(); + throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + } + } + + getOrAllocCallStack(): CachedCallStack { + if (this.recycledCallStacks.length != 0) { + return this.recycledCallStacks.pop() as CachedCallStack; + } + return new CachedCallStack( + this.memory, + this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace, + this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace + ); + } + + recycleCallStack(callstack: CachedCallStack): void { + callstack.reset(); + this.recycledCallStacks.push(callstack); + } + + private validateInstance(): void { + this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]); + } + + private checkExports(funcNames: Array): void { + const missList = []; + for (const name of funcNames) { + const f = this.exports[name]; + if (!(f instanceof Function)) { + missList.push(name); + } + } + if (missList.length != 0) { + throw new Error("Cannot find " + missList + " in exports"); + } + } + + private detectWasmMemory( + instance: WebAssembly.Instance, + imports: Record + ): WebAssembly.Memory { + if (instance.exports.memory instanceof WebAssembly.Memory) { + return instance.exports.memory; + } + if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { + return imports.env.memory; + } + + throw new Error( + "Cannt detect wasm memory from imports " + + imports + + " or exports" + + instance.exports + ); + } +} + +/** + * @internal + * Manages extra runtime context for the runtime. + */ +class RuntimeContext implements Disposable { + arrayGetItem: PackedFunc; + arrayGetSize: PackedFunc; + arrayMake: PackedFunc; + getSysLib: PackedFunc; + arrayCacheGet: PackedFunc; + arrayCacheUpdate: PackedFunc; + arrayCacheRemove: PackedFunc; + arrayCacheClear: PackedFunc; + arrayDecodeStorage: PackedFunc; + paramModuleFromCache: PackedFunc; + makeShapeTuple: PackedFunc; + ndarrayCreateView: PackedFunc; + sampleTopPFromLogits: PackedFunc; + + private autoDisposeScope: Array> = []; + + constructor(getGlobalFunc: (name: string) => PackedFunc) { + this.arrayGetItem = getGlobalFunc("runtime.ArrayGetItem"); + this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); + this.arrayMake = getGlobalFunc("runtime.Array"); + this.getSysLib = getGlobalFunc("runtime.SystemLib"); + this.arrayCacheGet = getGlobalFunc("vm.builtin.ndarray_cache.get"); + this.arrayCacheRemove = getGlobalFunc("vm.builtin.ndarray_cache.remove"); + this.arrayCacheUpdate = getGlobalFunc("vm.builtin.ndarray_cache.update"); + this.arrayCacheClear = getGlobalFunc("vm.builtin.ndarray_cache.clear"); + this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); + this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); + this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple"); + this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); + this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); + } + + dispose(): void { + // call array cache clear to clear all cached items + this.arrayCacheClear.dispose(); + this.arrayGetItem.dispose(); + this.arrayGetSize.dispose(); + this.arrayMake.dispose(); + this.arrayCacheGet.dispose(); + this.arrayCacheRemove.dispose(); + this.arrayCacheUpdate.dispose(); + this.arrayCacheClear.dispose(); + this.arrayDecodeStorage.dispose(); + this.paramModuleFromCache.dispose(); + this.makeShapeTuple.dispose(); + this.ndarrayCreateView.dispose(); + this.sampleTopPFromLogits.dispose(); + } + + beginScope(): void { + this.autoDisposeScope.push([]); + } + + endScope(): void { + if (this.autoDisposeScope.length == 0) { + throw Error("tvm.endScope called when the stack is empty."); + } + // automatically dispose all the tracked values in the current scope. + const currScope = this.autoDisposeScope.pop() as Array; + for (let i = 0; i < currScope.length; ++i) { + const val = currScope[i]; + if (val !== undefined) { + val.dispose(); + } + } + } + + /** + * Track object for dispose in current scope. + * + * @param obj The object to be tracked. + * @returns the same object. + * @note This function only needs to be called for raw system C API values. + * The return value of PackedFunc will be automatically tracked. + */ + attachToCurrentScope(obj: T): T { + if (this.autoDisposeScope.length == 0) { + throw Error("Must call beginScope to use functions that returns TVM objects"); + } + const currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; + currScope.push(obj); + return obj; + } + + moveToParentScope(obj: T): T { + this.detachFromCurrentScope(obj); + if (this.autoDisposeScope.length < 2) { + throw Error("moveToParentScope: Parent scope do not exist"); + } + const parentScope = this.autoDisposeScope[this.autoDisposeScope.length - 2]; + parentScope.push(obj); + return obj; + } + + detachFromCurrentScope(obj: T): T { + const currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; + let occurance = 0; + for (let i = 0; i < currScope.length; ++i) { + if (currScope[i] === obj) { + occurance += 1; + currScope[i] = undefined; + } + } + if (occurance == 0) { + throw Error("Cannot find obj in the current auto conversion pool"); + } + if (occurance > 1) { + throw Error("Value attached to scope multiple times"); + } + return obj; + } +} + +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + + constructor(value: number, dtype: string) { + this.value = value; + this.dtype = dtype; + } +} + +/** + * Cell holds the PackedFunc object. + */ +class PackedFuncCell implements Disposable { + private handle: Pointer; + private lib: FFILibrary; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) + ); + this.handle = 0; + } + } + + getHandle(requireNotNull: boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("PackedFunc has already been disposed"); + } + return this.handle; + } +} + +const DeviceEnumToStr: Record = { + 1: "cpu", + 2: "cuda", + 4: "opencl", + 8: "metal", + 15: "webgpu" +}; + +const DeviceStrToEnum: Record = { + cpu: 1, + cuda: 2, + cl: 4, + opencl: 4, + vulkan: 7, + metal: 8, + webgpu: 15 +}; + +/** + * Represent a runtime context where a NDArray can reside. + */ +export class DLDevice { + /** The device type code of the device. */ + deviceType: number; + /** The device index. */ + deviceId: number; + + private lib: FFILibrary; + + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) { + const tp = typeof deviceType; + if (tp == "string") { + this.deviceType = DeviceStrToEnum[deviceType]; + if (this.deviceType == undefined) { + throw new Error("Cannot recogonize deviceType " + deviceType); + } + } else if (tp == "number") { + this.deviceType = deviceType as number; + } else { + throw new Error("Cannot take type " + tp + " as deviceType"); + } + this.deviceId = deviceId; + this.lib = lib; + } + + /** + * Synchronize the device + */ + async sync(): Promise { + if (this.deviceType == DeviceStrToEnum.webgpu) { + assert(this.lib.webGPUContext !== undefined); + await this.lib.webGPUContext.sync(); + } + } + + toString(): string { + return ( + DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + ); + } +} +/** + * The data type code in DLDataType + */ +export const enum DLDataTypeCode { + Int = 0, + UInt = 1, + Float = 2, + OpaqueHandle = 3 +} + +const DLDataTypeCodeToStr: Record = { + 0: "int", + 1: "uint", + 2: "float", + 3: "handle", +}; + +/** + * Runtime data type of NDArray. + */ +export class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + + constructor(code: number, bits: number, lanes: number) { + this.code = code; + this.bits = bits; + this.lanes = lanes; + } + + toString(): string { + const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); + if (this.lanes != 1) { + return ret + "x" + this.lanes.toString(); + } else { + return ret; + } + } + + numStorageBytes(): number { + return (this.bits * this.lanes + 7) >> 3; + } +} + +/** + * n-dimnesional array. + */ +export class NDArray implements Disposable { + /** Internal array handle. */ + private handle: Pointer; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Device of the array. */ + device: DLDevice; + /** Whether it is a temporary view that can become invalid after the call. */ + isView: boolean; + private byteOffset: number; + private dltensor: Pointer; + private dataPtr: Pointer; + private lib: FFILibrary; + private ctx: RuntimeContext; + private dlDataType: DLDataType; + + constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx: RuntimeContext) { + this.handle = handle; + this.isView = isView; + this.lib = lib; + this.ctx = ctx; + + if (this.isView) { + this.dltensor = handle; + } else { + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + } + // constant offsets. + const arrayOffsetData = 0; + const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); + const arrayOffsetDevType = arrayOffsetContext; + const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; + const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLDevice; + const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; + const arrayOffsetDtypeCode = arrayOffsetDtype; + const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; + const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; + const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; + const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); + const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // dataPtr + this.dataPtr = lib.memory.loadPointer(this.dltensor); + // ndim + this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); + // shape + const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); + this.shape = []; + for (let i = 0; i < this.ndim; ++i) { + this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); + } + // dtype + const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); + const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); + const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); + this.dlDataType = new DLDataType(code, bits, lanes); + this.dtype = this.dlDataType.toString(); + + // device + const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); + const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); + this.device = new DLDevice(deviceType, deviceId, lib); + + // byte_offset + this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); + } + + /** + * Create a view of the array. + * @param shape The shape of the view. + * @returns The new sliced ndarray. + */ + view(shape: Array): NDArray { + const shapeArray = shape.map((value) => new Scalar(value, "int")); + return this.ctx.ndarrayCreateView(this, this.ctx.makeShapeTuple(...shapeArray)); + } + + /** + * Get handle of ndarray, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull: boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("NDArray has already been disposed"); + } + return this.handle; + } + + /** + * Get dataPtr of NDarray + * + * @returns The handle. + */ + getDataPtr(): Pointer { + if (this.handle == 0) { + throw Error("NDArray has already been disposed"); + } + return this.dataPtr; + } + + dispose(): void { + if (this.handle != 0 && !this.isView) { + this.lib.checkCall( + (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) + ); + this.handle = 0; + } + } + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array | Float32Array): this { + if (data instanceof NDArray) { + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( + data.getHandle(), + this.getHandle(), + 0 + ) + ); + return this; + } else { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + if (data.length != size) { + throw new Error( + "data size and shape mismatch data.length" + + data.length + + " vs " + + size + ); + } + let buffer: ArrayBuffer; + if (this.dtype == "float32") { + buffer = Float32Array.from(data).buffer; + } else if (this.dtype == "float64") { + buffer = Float64Array.from(data).buffer; + } else if (this.dtype == "int32") { + buffer = Int32Array.from(data).buffer; + } else if (this.dtype == "int8") { + buffer = Int8Array.from(data).buffer; + } else if (this.dtype == "uint8") { + buffer = Uint8Array.from(data).buffer; + } else { + throw new Error("Unsupported data type " + this.dtype); + } + return this.copyFromRawBytes(new Uint8Array(buffer)); + } + } + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. + * @returns this + */ + copyFromRawBytes(data: Uint8Array): this { + // short cut for gpu copy + if (this.device.deviceType == DeviceStrToEnum.webgpu) { + this.lib.webGPUContext?.copyRawBytesToBuffer(data, this.getDataPtr(), 0, data.length); + return this; + } + // CPU copy + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + if (nbytes != data.length) { + throw new Error("Expect the data's length equals nbytes=" + nbytes); + } + + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.memory.storeRawBytes(tempPtr, data); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( + this.getHandle(), + tempPtr, + nbytes + ) + ); + + this.lib.recycleCallStack(stack); + return this; + } + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + toRawBytes(): Uint8Array { + if (this.device.deviceType != DeviceStrToEnum.cpu) { + throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); + } + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + + const nbytes = this.dlDataType.numStorageBytes() * size; + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( + this.getHandle(), + tempPtr, + nbytes + ) + ); + const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); + + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array { + const stype = this.dtype; + if (stype == "float32") { + return new Float32Array(this.toRawBytes().buffer); + } else if (stype == "float64") { + return new Float64Array(this.toRawBytes().buffer); + } else if (stype == "int32") { + return new Int32Array(this.toRawBytes().buffer); + } else if (stype == "int8") { + return new Int8Array(this.toRawBytes().buffer); + } else if (stype == "uint8") { + return new Uint8Array(this.toRawBytes().buffer); + } else { + throw new Error("Unsupported data type " + this.dtype); + } + } + + private getDLTensorFromArrayHandle(handle: Pointer): Pointer { + // Note: this depends on the NDArray C ABI. + // keep this function in case of ABI change. + return handle; + } +} + +/** + * Runtime Module. + */ +export class Module implements Disposable { + private handle: Pointer; + private lib: FFILibrary; + private makePackedFunc: (ptr: Pointer) => PackedFunc; + + constructor( + handle: Pointer, + lib: FFILibrary, + makePackedFunc: (ptr: Pointer) => PackedFunc + ) { + this.handle = handle; + this.lib = lib; + this.makePackedFunc = makePackedFunc; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull: boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("Module has already been disposed"); + } + return this.handle; + } + + /** + * Get a function in the module. + * @param name The name of the function. + * @param queryImports Whether to also query imports + * @returns The result function. + */ + getFunction(name: string, queryImports: boolean = true): PackedFunc { + if (this.handle == 0) { + throw Error("Module has already been disposed"); + } + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( + this.getHandle(), + stack.ptrFromOffset(nameOffset), + queryImports ? 1 : 0, + outPtr + ) + ); + const handle = this.lib.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + importModule(mod: Module): void { + this.lib.checkCall( + (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( + this.getHandle(), + mod.getHandle() + ) + ); + } +} + +/** + * Generic object base + */ +export class TVMObject implements Disposable { + private handle: Pointer; + private lib: FFILibrary; + protected ctx: RuntimeContext; + + constructor( + handle: Pointer, + lib: FFILibrary, + ctx: RuntimeContext + ) { + this.handle = handle; + this.lib = lib; + this.ctx = ctx; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMObjectFree as ctypes.FTVMObjectFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull: boolean = true): Pointer { + if (requireNotNull && this.handle == 0) { + throw Error("Module has already been disposed"); + } + return this.handle; + } + + /** get the type index of the object */ + typeIndex(): number { + if (this.handle == 0) { + throw Error("The current Object has already been disposed"); + } + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMObjectGetTypeIndex as ctypes.FTVMObjectGetTypeIndex)( + this.getHandle(), + outPtr + ) + ); + const result = this.lib.memory.loadU32(outPtr); + this.lib.recycleCallStack(stack); + return result; + } + + /** get the type key of the object */ + typeKey(): string { + const type_index = this.typeIndex(); + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall( + (this.lib.exports.TVMObjectTypeIndex2Key as ctypes.FTVMObjectTypeIndex2Key)( + type_index, + outPtr + ) + ); + const result = this.lib.memory.loadCString( + this.lib.memory.loadPointer(outPtr) + ); + this.lib.recycleCallStack(stack); + return result; + } +} + +/** Objectconstructor */ +type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) => TVMObject; + +/** All possible object types. */ +type TVMObjectBase = TVMObject | NDArray | Module | PackedFunc; + +/** Runtime array object. */ +export class TVMArray extends TVMObject { + constructor( + handle: Pointer, + lib: FFILibrary, + ctx: RuntimeContext + ) { + super(handle, lib, ctx); + } + + /** + * @returns the size of the array. + */ + size(): number { + return this.ctx.arrayGetSize(this) as number; + } + /** + * Get index-th element of the array + * @param index the array index. + * @returns The element. + */ + get(index: number): TVMObjectBase { + return this.ctx.arrayGetItem(this, new Scalar(index, "int32")) as TVMObjectBase; + } +} + +export const enum VMAllocatorKind { + NAIVE_ALLOCATOR = 1, + POOLED_ALLOCATOR = 2, +} + +/** + * VirtualMachine Executor. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +export class VirtualMachine implements Disposable { + private mod: Module; + /** + * Constructor + * @param mod The underlying module, need to be detached. + * @param device The main device ro run VM on. + */ + constructor(mod: Module, device: DLDevice) { + this.mod = mod; + this.mod.getFunction("vm_initialization")( + new Scalar(device.deviceType, "int"), + new Scalar(device.deviceId, "int"), + new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"), + // explicitly specify host device type + new Scalar(DeviceStrToEnum.cpu, "int"), + new Scalar(0, "int"), + new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"), + ); + } + + dispose(): void { + this.mod.dispose(); + } + /** + * Get a function in the VM module. + * @param name The name of the function. + * @returns The result function. + */ + getFunction(name: string): PackedFunc { + return this.mod.getFunction(name); + } + + /** + * Get the internal module. + */ + getInternalModule(): Module { + return this.mod; + } +} + +/** Code used as the first argument of the async callback. */ +const enum AyncCallbackCode { + kReturn = 4, + kException = 5, +} +export interface NDArrayCacheEntry { + name: string; + shape: Array; + dtype: string; + format: "f32-to-bf16" | "raw"; + byteOffset: number; + nbytes: number; +} + +export interface NDArrayShardEntry { + dataPath: string; + format: "raw-shard"; + nbytes: number; + records: Array; +} + +export interface InitProgressReport { + type: 'init'; + progress: number; + timeElapsed: number; + currentChunk: number; + totalChunks: number; + fetchedBytes: number; + totalBytes: number; +} + +export type InitProgressCallback = (report: InitProgressReport) => void; + +/** + * TVM runtime instance. + * + * All objects(NDArray, Module, PackedFunc) returned by TVM runtim function call + * and PackedFunc instance are tracked through a scope mechanism that will get + * auto-released when we call EndScope. + * + * This is necessarily to be able to release the underlying WASM and WebGPU memory that + * are not tracked through JS native garbage collection mechanism. + * + * This does mean that we have to get familar with the following functions: + * - {@link beginScope} + * - {@link endScope} + * - {@link withNewScope} + * - {@link attachToCurrentScope} + * - {@link detachFromCurrentScope} + */ +export class Instance implements Disposable { + memory: Memory; + exports: Record; + cacheMetadata: Record = {}; + private lib: FFILibrary; + private env: Environment; + private objFactory: Map; + private ctx: RuntimeContext; + private initProgressCallback: Array = []; + + /** + * Internal function(registered by the runtime) + */ + private wasmCreateLibraryModule?: PackedFunc & + ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc); + + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + constructor( + wasmModule: WebAssembly.Module, + importObject: Record = {}, + wasmInstance?: WebAssembly.Instance, + env?: Environment + ) { + if (wasmInstance instanceof WebAssembly.Instance) { + assert( + env instanceof Environment, + "env must be provided when passing in instance" + ); + } else { + assert(env === undefined); + env = new Environment(importObject); + wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); + } + env.start(wasmInstance); + this.env = env; + this.lib = new FFILibrary(wasmInstance, env.imports); + this.memory = this.lib.memory; + this.exports = this.lib.exports; + this.objFactory = new Map(); + this.ctx = new RuntimeContext( + (name: string) => { + const autoAttachToScope = false; + // runtime context function do not auto-release. + return this.getGlobalFuncInternal(name, autoAttachToScope); + } + ); + this.registerEnvGlobalPackedFuncs(); + this.registerObjectFactoryFuncs(); + } + + /** + * Benchmark stable execution of the run function. + * + * @params run The run function + * @params dev The device to sync during each run. + * @number The number of times to compute the average. + * @repeat The number of times to repeat the run. + */ + async benchmark(run: () => void, dev: DLDevice, number = 10, repeat = 1): Promise { + // Skip first run as it can involve GPU warmup and module loading time. + const perf = compact.getPerformance(); + const results = []; + + // run with new scope + this.withNewScope(run); + await dev.sync(); + + for (let k = 0; k < repeat; ++k) { + const tstart = perf.now(); + for (let i = 0; i < number; ++i) { + this.withNewScope(run); + } + await dev.sync(); + const tend = perf.now(); + results.push((tend - tstart) / number); + } + return results; + } + + dispose(): void { + // order matters + // ctx release goes back into lib. + this.ctx.dispose(); + this.lib.dispose(); + } + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string { + if (this.lib.webGPUContext !== undefined) { + return this.lib.webGPUContext.runtimeStatsText(); + } else { + return ""; + } + } + + /** + * Begin a new scope for tracking object disposal. + */ + beginScope(): void { + this.ctx.beginScope(); + } + + /** + * End a scope and release all created TVM objects + * under the current scope. + * + * Exception: one can call {@link moveToParentScope} to move + * a value to parent scope. + */ + endScope(): void { + this.ctx.endScope(); + } + + /** + * Perform action under a new scope. + * + * @param action The action function. + * @returns The result value. + * + * @note For action to return a valid value, + * we will need to call {@link moveToParentScope} + * for the objects that are created in the scope. + */ + withNewScope(action: () => T): T { + this.beginScope(); + const val = action(); + this.endScope(); + return val; + } + + /** + * Attach a detached obj to the auto-release pool of the current scope. + * + * @param obj The input obj. + * @note Normally user do not need to call this function explicitly, as + * all library call return values are explicitly attached to + * the current scope. You only need to do so when you call + * {@link detachFromCurrentScope} to create a detached object. + */ + attachToCurrentScope(obj: T): T { + return this.ctx.attachToCurrentScope(obj); + } + + /** + * Move obj's attachment to the parent scope. + * + * This function is useful to make sure objects are still + * alive when exit the current scope. + * + * @param obj The object to be moved. + * @returns The input obj. + */ + moveToParentScope(obj: T): T { + return this.ctx.moveToParentScope(obj); + } + + /** + * Detach the object from the current scope + * so it won't be released via auto-release during endscope. + * + * User needs to either explicitly call obj.dispose(), or + * {@link attachToCurrentScope} to re-attach to the current scope. + * + * This function can be used to return values to the parent scope. + * @param obj The object. + */ + detachFromCurrentScope(obj: T): T { + return this.ctx.detachFromCurrentScope(obj); + } + + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + systemLib(): Module { + return this.ctx.getSysLib() as Module; + } + /** + * List all the global function names registered in the runtime. + * @returns The name list. + */ + listGlobalFuncNames(): Array { + const stack = this.lib.getOrAllocCallStack(); + + const outSizeOffset = stack.allocPtrArray(2); + + const outSizePtr = stack.ptrFromOffset(outSizeOffset); + const outArrayPtr = stack.ptrFromOffset( + outSizeOffset + this.lib.sizeofPtr() + ); + + this.lib.checkCall( + (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)( + outSizePtr, + outArrayPtr + ) + ); + + const size = this.memory.loadI32(outSizePtr); + const array = this.memory.loadPointer(outArrayPtr); + const names: Array = []; + + for (let i = 0; i < size; ++i) { + names.push( + this.memory.loadCString( + this.memory.loadPointer(array + this.lib.sizeofPtr() * i) + ) + ); + } + + this.lib.recycleCallStack(stack); + return names; + } + + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerFunc( + name: string, + func: PackedFunc | Function, + override = false + ): void { + this.withNewScope(() => { + const autoAttachToScope = true; + // packed func can be released once it is registered + const packedFunc = this.toPackedFuncInternal(func, autoAttachToScope); + const ioverride = override ? 1 : 0; + + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + stack.commitToWasmMemory(); + + this.lib.checkCall( + (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + stack.ptrFromOffset(nameOffset), + packedFunc._tvmPackedCell.getHandle(), + ioverride + ) + ); + this.lib.recycleCallStack(stack); + }); + } + + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @param autoAttachToScope Whether to track it via autoDispose + * @returns The result function. + */ + getGlobalFunc(name: string): PackedFunc { + return this.getGlobalFuncInternal(name, true); + } + + private getGlobalFuncInternal(name: string, autoAttachToScope: boolean = true): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)( + stack.ptrFromOffset(nameOffset), + outPtr + ) + ); + const handle = this.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find global function " + name); + } + const ret = this.makePackedFunc(handle); + if (autoAttachToScope) this.ctx.attachToCurrentScope(ret); + return ret; + } + + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + isPackedFunc(func: unknown): boolean { + // eslint-disable-next-line no-prototype-builtins + return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell"); + } + + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + toPackedFunc(func: Function): PackedFunc { + return this.toPackedFuncInternal(func, true); + } + + private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): PackedFunc { + if (this.isPackedFunc(func)) return func as PackedFunc; + const ret = this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + if (autoAttachToScope) return this.ctx.attachToCurrentScope(ret); + return ret; + } + + /** + * Setup a virtual machine module with given device. + * + * @param dev DLDevice the device. + * @returns The created virtual machime. + */ + createVirtualMachine(dev: DLDevice): VirtualMachine { + const mod = this.ctx.detachFromCurrentScope( + this.systemLib().getFunction("vm_load_executable")() + ); + return this.ctx.attachToCurrentScope( + new VirtualMachine(mod, dev) + ); + } + + //----------------------------------------------- + // Native NDArray Cache Support + //----------------------------------------------- + /** + * Register a call back for fetch progress. + * + * @param cb the fetch progress callback. + */ + registerInitProgressCallback(cb: InitProgressCallback) { + this.initProgressCallback.push(cb); + } + + /** + * Get parameters in the form of prefix_i + * + * @param prefix The parameter prefix. + * @param numParams Number of parameters. + * @returns + */ + getParamsFromCache(prefix: string, numParams: number): TVMObject { + return (this.ctx.paramModuleFromCache( + prefix, new Scalar(numParams, "int32")) as Module).getFunction("get_params")(); + } + + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheGet(name: string): NDArray | undefined { + return this.ctx.arrayCacheGet(name); + } + + /** + * Get NDArray from cache. + * @param name The name of array. + * @returns The result. + */ + ndarrayCacheRemove(name: string): NDArray | undefined { + return this.ctx.arrayCacheRemove(name); + } + + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheUpdate(name: string, arr: NDArray, override: boolean = false) { + this.ctx.arrayCacheUpdate(name, arr, this.scalar(override ? 1 : 0, "int32")); + } + + /** + * Update the ndarray cache. + * @param name The name of the array. + * @param arr The content. + */ + ndarrayCacheClear() { + this.ctx.arrayCacheClear(); + } + + /** + * Fetch NDArray cache from url. + * + * @param ndarrayCacheUrl The cache url. + * @param device The device to be fetched to. + * @returns The meta data + */ + async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice): Promise { + const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href; + const request = new Request(jsonUrl); + const cache = await caches.open("tvmjs"); + let result = await cache.match(request); + if (result === undefined) { + await cache.add(request); + result = await cache.match(request); + } + if (result === undefined) { + this.env.logger("Error: Cannot cache " + jsonUrl + ", reloading will be slow"); + try { + result = await fetch(request); + } catch (err) { + this.env.logger("Cannot fetch " + jsonUrl); + } + } + let list; + if (result instanceof Response) { + list = await result.json(); + } + await this.fetchNDArrayCacheInternal( + ndarrayCacheUrl, + list["records"] as Array, device); + this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record) }; + } + + /** + * Fetch list of NDArray into the NDArrayCache. + * + * @param ndarrayCacheUrl The cache url. + * @param list The list of array data. + * @param device The device to store the data to. + */ + private async fetchNDArrayCacheInternal(ndarrayCacheUrl: string, list: Array, device: DLDevice) { + const perf = compact.getPerformance(); + let tstart = perf.now(); + + let totalBytes = 0; + for (let i = 0; i < list.length; ++i) { + totalBytes += list[i].nbytes; + }; + let fetchedBytes = 0; + let timeElapsed = 0; + + const reportCallback = (iter: number) => { + // report + for (let j = 0; j < this.initProgressCallback.length; ++j) { + this.initProgressCallback[j]({ + type: 'init', + progress: fetchedBytes / totalBytes, + timeElapsed: timeElapsed, + currentChunk: iter, + totalChunks: list.length, + fetchedBytes, + totalBytes, + }); + } + }; + + for (let j = 0; j < this.initProgressCallback.length; ++j) { + this.initProgressCallback[j]({ + type: 'init', + progress: fetchedBytes / totalBytes, + timeElapsed: 0, + currentChunk: 0, + totalChunks: list.length, + fetchedBytes, + totalBytes, + }); + } + const cache = await caches.open("tvmjs"); + + for (let i = 0; i < list.length; ++i) { + reportCallback(i); + fetchedBytes += list[i].nbytes; + const dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href; + const request = new Request(dataUrl); + let buffer; + try { + // use native cache + let result = await cache.match(request); + if (result === undefined) { + await cache.add(request); + result = await cache.match(request); + } + if (result == undefined) { + this.env.logger("Error: Cannot cache " + dataUrl + ", reloading will be slow"); + result = await fetch(request); + } + buffer = await result.arrayBuffer(); + } catch (err) { + this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); + throw err; + } + const shardRecords = list[i].records; + for (let j = 0; j < shardRecords.length; ++j) { + const rec = shardRecords[j]; + const cpu_arr = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.empty(rec.shape, rec.dtype, this.cpu()) + ) + }); + const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); + // first sync copy to cpu. + this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format); + // then async stream into GPU if needed + if (device.deviceType == DeviceStrToEnum.cpu) { + this.ndarrayCacheUpdate(rec.name, cpu_arr, false); + cpu_arr.dispose(); + } else { + // allocate a gpu arr and async copy to it. + const gpu_arr = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.empty(rec.shape, rec.dtype, device) + ) + }); + gpu_arr.copyFrom(cpu_arr); + await device.sync(); + this.ndarrayCacheUpdate(rec.name, gpu_arr, false); + cpu_arr.dispose(); + gpu_arr.dispose(); + } + } + timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + } + reportCallback(list.length); + } + + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. + */ + toDLDataType(dtype: string | DLDataType): DLDataType { + if (dtype instanceof DLDataType) return dtype; + if (typeof dtype == "string") { + let pattern = dtype; + let code, + bits = 32, + lanes = 1; + if (pattern.substring(0, 5) == "float") { + pattern = pattern.substring(5, pattern.length); + code = DLDataTypeCode.Float; + } else if (pattern.substring(0, 3) == "int") { + pattern = pattern.substring(3, pattern.length); + code = DLDataTypeCode.Int; + } else if (pattern.substring(0, 4) == "uint") { + pattern = pattern.substring(4, pattern.length); + code = DLDataTypeCode.UInt; + } else if (pattern.substring(0, 6) == "handle") { + pattern = pattern.substring(5, pattern.length); + code = DLDataTypeCode.OpaqueHandle; + bits = 64; + } else { + throw new Error("Unknown dtype " + dtype); + } + + const arr = pattern.split("x"); + if (arr.length >= 1) { + const parsed = parseInt(arr[0]); + if (parsed + "" == arr[0]) { + bits = parsed; + } + } + if (arr.length >= 2) { + lanes = parseInt(arr[1]); + } + return new DLDataType(code, bits, lanes); + } else { + throw new Error("Unknown dtype " + dtype); + } + } + + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. + */ + scalar(value: number, dtype: string): Scalar { + return new Scalar(value, dtype); + } + + /** + * Create a new {@link DLDevice} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created device. + */ + device(deviceType: number | string, deviceId = 0): DLDevice { + return new DLDevice(deviceType, deviceId, this.lib); + } + + /** + * Create a new cpu {@link DLDevice} + * @param deviceId The device index. + */ + cpu(deviceId = 0): DLDevice { + return this.device("cpu", deviceId); + } + + /** + * Create a new webgpu {@link DLDevice} + * @param deviceId The device index. + */ + webgpu(deviceId = 0): DLDevice { + return this.device("webgpu", deviceId); + } + + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + empty( + shape: Array | number, + dtype: string | DLDataType = "float32", + dev: DLDevice = this.device("cpu", 0) + ): NDArray { + dtype = this.toDLDataType(dtype); + shape = typeof shape == "number" ? [shape] : shape; + + const stack = this.lib.getOrAllocCallStack(); + const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); + for (let i = 0; i < shape.length; ++i) { + stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); + } + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)( + stack.ptrFromOffset(shapeOffset), + shape.length, + dtype.code, + dtype.bits, + dtype.lanes, + dev.deviceType, + dev.deviceId, + outPtr + ) + ); + const ret = this.ctx.attachToCurrentScope( + new NDArray(this.memory.loadPointer(outPtr), false, this.lib, this.ctx) + ); + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Create am uniform {@link NDArray} with given shape. + * + * @param shape The shape of the array. + * @param low The low value. + * @param high The high value. + * @param dev The device of the ndarray. + * @returns The created ndarray. + */ + uniform( + shape: Array, + low: number, + high: number, + dev: DLDevice + ): NDArray { + const ret = this.empty(shape, "float32", dev); + const size = shape.reduce((a, b) => { + return a * b; + }, 1); + const scale = high - low; + const input = new Float32Array(size); + for (let i = 0; i < input.length; ++i) { + input[i] = low + Math.random() * scale; + } + return ret.copyFrom(input); + } + + /** + * Sample index via top-p sampling. + * + * @param logits The input logits before normalization. + * @param temperature The temperature factor, will take argmax if temperature = 0.0 + * @param top_p The top_p + * @returns The sampled index. + */ + sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number { + return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random()); + } + + /** + * Bind canvas to the current WebGPU context + * @param canvas The canvas. + */ + bindCanvas(canvas: HTMLCanvasElement) { + this.lib.webGPUContext?.bindCanvas(canvas); + } + + /** + * Show image in canvas. + * + * @param dataRGBA Image array in height x width uint32 NDArray RGBA format on GPU. + */ + showImage(dataRGBA: NDArray) { + if (dataRGBA.shape.length != 2) { + throw Error("Require a height x width uint32 NDArray in RGBA" + + "get shape=" + dataRGBA.shape.toString() + " instead." + ); + } + if (dataRGBA.device.deviceType != DeviceStrToEnum.webgpu) { + throw new Error("Can only run showImage on WebGPU array, " + + "get " + DeviceEnumToStr[dataRGBA.device.deviceType] + " instead."); + } + if (dataRGBA.dtype != "uint32") { + throw Error("Require a height x width uint32 NDArray in RGBA, " + + "get " + dataRGBA.dtype + " instead."); + } + this.lib.webGPUContext?.drawImageFromBuffer( + dataRGBA.getDataPtr(), dataRGBA.shape[0], dataRGBA.shape[1] + ); + } + + /** + * Clear canvas + */ + clearCanvas() { + this.lib.webGPUContext?.clearCanvas(); + } + + /** + * Create an tuple {@link TVMArray} input array. + * + * The input array can be passed to tvm runtime function + * and needs to b explicitly disposed. + * + * @param inputs The input array + * @returns The result array. + */ + makeTVMArray( + inputs: Array + ): TVMArray { + return this.ctx.arrayMake(...inputs) as TVMArray; + } + + /** + * Create a shape tuple to pass to runtime. + * @param shape The shape . + * @returns The created shape tuple. + */ + makeShapeTuple(shape: Array): TVMObject { + const shapeArray = shape.map((value) => new Scalar(value, "int")); + return this.ctx.makeShapeTuple(...shapeArray); + } + /** + * Get type index from type key. + * @param typeKey The type key. + * @returns The corresponding type index. + */ + typeKey2Index( + typeKey: string + ): number { + const stack = this.lib.getOrAllocCallStack(); + const typeKeyOffset = stack.allocRawBytes(typeKey.length + 1); + stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey)); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMObjectTypeKey2Index as ctypes.FTVMObjectTypeKey2Index)( + stack.ptrFromOffset(typeKeyOffset), + outPtr + ) + ); + const typeIndex = this.memory.loadU32(outPtr); + this.lib.recycleCallStack(stack); + return typeIndex; + } + + /** + * Register an object constructor. + * @param typeKey The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + */ + registerObjectConstructor( + typeKey: string, + func: FObjectConstructor, + override = false + ): void { + const typeIndex = this.typeKey2Index(typeKey); + if (this.objFactory.has(typeIndex)) { + if (!override) { + throw new Error("Type " + typeKey + " already registered"); + } + } + this.objFactory.set(typeIndex, func); + } + /** + * Register an asyncfunction to be global function in the server. + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note The async function will only be used for serving remote calls in the rpc. + */ + registerAsyncServerFunc( + name: string, + func: Function, + override = false + ): void { + const asyncVariant = (...args: Array): void => { + const fargs = args.slice(0, args.length - 1); + // need to keep it alive until callback is fulfilled. + const callback = this.detachFromCurrentScope(args[args.length - 1] as PackedFunc); + const promise: Promise = func(...fargs); + promise.then((rv: any) => { + callback(this.scalar(AyncCallbackCode.kReturn, "int32"), rv); + callback.dispose(); + }); + }; + this.registerFunc("__async." + name, asyncVariant, override); + } + + /** + * Asynchrously load webgpu pipelines when possible. + * @param mod The input module. + */ + async asyncLoadWebGPUPiplines(mod: Module): Promise { + if (this.lib.webGPUContext == undefined) throw Error("WebGPU not initialied"); + const webgpuContext = this.lib.webGPUContext; + + this.beginScope(); + const fmap_str = mod.getFunction("webgpu.get_fmap", true)() as string; + let fmap: Record = JSON.parse(fmap_str); + const totalFuncs = fmap.length; + const fGetShader = this.detachFromCurrentScope( + mod.getFunction("webgpu.get_shader") + ); + const fUpdatePrebuild = this.detachFromCurrentScope( + mod.getFunction("webgpu.update_prebuild") + ); + this.endScope(); + + const perf = compact.getPerformance(); + const tstart = perf.now(); + let tlastReport = tstart; + let finishCounter = 0; + const fmapEntries = Object.entries(fmap); + + let allEvents = Promise.resolve(); + + for (const [key, finfo] of fmapEntries) { + const code = fGetShader(key); + assert(key == finfo.name); + const event = webgpuContext.createShaderAsync(finfo, code).then((func: Function) => { + this.beginScope(); + fUpdatePrebuild(key, func); + this.endScope(); + + }).then(() => { + finishCounter += 1; + const tend = perf.now(); + const timeReportGap = 1000; + // skip report if gap is smaller than 1000 + if ((tend - tlastReport) < 1000 && finishCounter != fmapEntries.length) { + return; + } + tlastReport = tend; + const timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + // report + for (let j = 0; j < this.initProgressCallback.length; ++j) { + const progress = finishCounter / fmapEntries.length; + let text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length + "]: "; + text += Math.floor(progress * 100).toString() + "% completed, " + text += timeElapsed + " secs elapsed."; + // this.initProgressCallback[j]({ + // progress: progress, + // timeElapsed: timeElapsed, + // text: text + // }); + } + }); + allEvents = Promise.all([allEvents, event]).then(() => { }); + } + await allEvents; + assert(finishCounter == fmapEntries.length); + } + + /** + * Initialize webgpu in the runtime. + * @param device The given GPU device. + */ + initWebGPU(device: GPUDevice): void { + const webGPUContext = new WebGPUContext( + this.memory, device + ); + this.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => { + return webGPUContext.getDeviceAPI(name); + }); + this.registerFunc("wasm.WebGPUCreateShader", (info: string, code: string) => { + const finfo = JSON.parse(info) as FunctionInfo; + return webGPUContext.createShader(finfo, code); + }); + this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { + await webGPUContext.sync(); + }); + this.lib.webGPUContext = webGPUContext; + } + + /** Register all object factory */ + private registerObjectFactoryFuncs(): void { + this.registerObjectConstructor("Array", + (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { + return new TVMArray(handle, lib, ctx); + }); + } + + /** Register global packed functions needed by the backend to the env. */ + private registerEnvGlobalPackedFuncs(): void { + // Register the timer function to enable the time_evaluator. + const perf = compact.getPerformance(); + + // Helper function to time the finvoke + const timeExecution = async ( + finvoke: PackedFunc, + dev: DLDevice, + nstep: number, + repeat: number, + minRepeatMs: number, + limitZeroTimeIterations: number, + cooldownIntervalMs: number, + repeatsToCooldown: number + ): Promise => { + // detach and explicit dispose when tasks is fullfilled + // the promise will immediately return and we need to makesure + // finvoke do not get recycled. + this.ctx.detachFromCurrentScope(finvoke); + + finvoke(this.scalar(1, "int32")); + await dev.sync(); + const result = []; + let setupNumber: number = nstep; + + for (let i = 0; i < repeat; ++i) { + let durationMs = 0.0; + let absoluteZeroTimes = 0; + do { + if (durationMs > 0.0) { + let golden_ratio = 1.618; + setupNumber = Math.floor( + Math.max(minRepeatMs / (durationMs / setupNumber) + 1, setupNumber * golden_ratio) + ); + } + const tstart: number = perf.now(); + finvoke(this.scalar(setupNumber, "int32")); + await dev.sync(); + const tend: number = perf.now(); + + durationMs = tend - tstart; + if (durationMs == 0) { + absoluteZeroTimes++; + } + } while (durationMs < minRepeatMs && absoluteZeroTimes < limitZeroTimeIterations); + const speed = durationMs / setupNumber / 1000; + result.push(speed); + if (cooldownIntervalMs > 0.0 && (i % repeatsToCooldown) == 0) { + await new Promise(r => setTimeout(r, cooldownIntervalMs)); + } + } + const ret = new Float64Array(result.length); + ret.set(result); + + // dispose finvoke + finvoke.dispose(); + return new Uint8Array(ret.buffer); + }; + + const addOne = async (x: number): Promise => { + await new Promise(resolve => setTimeout(resolve, 100)); + return x + 1; + }; + + this.registerAsyncServerFunc("wasm.TimeExecution", timeExecution); + this.registerAsyncServerFunc("testing.asyncAddOne", addOne); + } + + private createPackedFuncFromCFunc( + func: ctypes.FTVMWasmPackedCFunc + ): PackedFunc { + let findex = this.env.packedCFuncTable.length; + if (this.env.packedCFuncTableFreeId.length != 0) { + findex = this.env.packedCFuncTableFreeId.pop() as number; + } else { + this.env.packedCFuncTable.push(undefined); + } + this.env.packedCFuncTable[findex] = func; + + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall( + (this.exports + .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)( + findex, + outPtr + ) + ); + const ret = this.makePackedFunc(this.memory.loadPointer(outPtr)); + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. + */ + setPackedArguments( + stack: CachedCallStack, + args: Array, + argsValue: PtrOffset, + argsCode: PtrOffset + ): void { + for (let i = 0; i < args.length; ++i) { + let val = args[i]; + const tp = typeof val; + const valueOffset = argsValue + i * SizeOf.TVMValue; + const codeOffset = argsCode + i * SizeOf.I32; + if (val instanceof NDArray) { + if (!val.isView) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); + } else { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMDLTensorHandle); + } + } else if (val instanceof Scalar) { + if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { + stack.storeI64(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.Int); + } else if (val.dtype.startsWith("float")) { + stack.storeF64(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.Float); + } else { + assert(val.dtype == "handle", "Expect handle"); + stack.storePtr(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle); + } + } else if (val instanceof DLDevice) { + stack.storeI32(valueOffset, val.deviceType); + stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); + stack.storeI32(codeOffset, ArgTypeCode.DLDevice); + } else if (tp == "number") { + stack.storeF64(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.Float); + // eslint-disable-next-line no-prototype-builtins + } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { + stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + } else if (val === null || val == undefined) { + stack.storePtr(valueOffset, 0); + stack.storeI32(codeOffset, ArgTypeCode.Null); + } else if (tp == "string") { + stack.allocThenSetArgString(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.TVMStr); + } else if (val instanceof Uint8Array) { + stack.allocThenSetArgBytes(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); + } else if (val instanceof Function) { + val = this.toPackedFuncInternal(val, false); + stack.tempArgs.push(val); + stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + } else if (val instanceof Module) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); + } else if (val instanceof TVMObject) { + stack.storePtr(valueOffset, val.getHandle()); + stack.storeI32(codeOffset, ArgTypeCode.TVMObjectHandle); + } else { + throw new Error("Unsupported argument type " + tp); + } + } + } + + private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc { + const lib = this.lib; + return ( + argValues: Pointer, + argCodes: Pointer, + nargs: number, + ret: Pointer, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _handle: Pointer + ): number => { + const jsArgs = []; + // use scope to track js values. + this.ctx.beginScope(); + for (let i = 0; i < nargs; ++i) { + const valuePtr = argValues + i * SizeOf.TVMValue; + const codePtr = argCodes + i * SizeOf.I32; + let tcode = lib.memory.loadI32(codePtr); + + if ( + tcode == ArgTypeCode.TVMObjectHandle || + tcode == ArgTypeCode.TVMObjectRValueRefArg || + tcode == ArgTypeCode.TVMPackedFuncHandle || + tcode == ArgTypeCode.TVMNDArrayHandle || + tcode == ArgTypeCode.TVMModuleHandle + ) { + lib.checkCall( + (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( + valuePtr, + codePtr + ) + ); + } + tcode = lib.memory.loadI32(codePtr); + jsArgs.push(this.retValueToJS(valuePtr, tcode, true)); + } + + const rv = func(...jsArgs); + // recycle all js object value in function unless we want to retain them. + this.ctx.endScope(); + + if (rv !== undefined && rv !== null) { + const stack = lib.getOrAllocCallStack(); + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const codeOffset = stack.allocRawBytes(SizeOf.I32); + this.setPackedArguments(stack, [rv], valueOffset, codeOffset); + const valuePtr = stack.ptrFromOffset(valueOffset); + const codePtr = stack.ptrFromOffset(codeOffset); + stack.commitToWasmMemory(); + lib.checkCall( + (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)( + ret, + valuePtr, + codePtr, + 1 + ) + ); + lib.recycleCallStack(stack); + } + return 0; + }; + } + + private makePackedFunc(handle: Pointer): PackedFunc { + const cell = new PackedFuncCell(handle, this.lib); + + const packedFunc = (...args: any): any => { + const stack = this.lib.getOrAllocCallStack(); + + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); + const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); + + this.setPackedArguments(stack, args, valueOffset, tcodeOffset); + + const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const rcodeOffset = stack.allocRawBytes(SizeOf.I32); + const rvaluePtr = stack.ptrFromOffset(rvalueOffset); + const rcodePtr = stack.ptrFromOffset(rcodeOffset); + + // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) + stack.commitToWasmMemory(rvalueOffset); + + this.lib.checkCall( + (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( + cell.getHandle(), + stack.ptrFromOffset(valueOffset), + stack.ptrFromOffset(tcodeOffset), + args.length, + rvaluePtr, + rcodePtr + ) + ); + + const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr), false); + this.lib.recycleCallStack(stack); + return ret; + }; + // Attach attributes to the function type. + // This is because javascript do not allow us to overload call. + const ret: any = packedFunc; + ret.dispose = (): void => { + cell.dispose(); + }; + ret._tvmPackedCell = cell; + return ret as PackedFunc; + } + + /** + * Creaye return value of the packed func. The value us auto-tracked for dispose. + * @param rvaluePtr The location of rvalue + * @param tcode The type code. + * @param callbackArg Whether it is being used in callbackArg. + * @returns The JS value. + */ + private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { + switch (tcode) { + case ArgTypeCode.Int: + case ArgTypeCode.UInt: + return this.memory.loadI64(rvaluePtr); + case ArgTypeCode.Float: + return this.memory.loadF64(rvaluePtr); + case ArgTypeCode.TVMOpaqueHandle: { + return this.memory.loadPointer(rvaluePtr); + } + case ArgTypeCode.TVMNDArrayHandle: { + return this.ctx.attachToCurrentScope( + new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, this.ctx) + ); + } + case ArgTypeCode.TVMDLTensorHandle: { + assert(callbackArg); + // no need to attach as we are only looking at view + return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib, this.ctx); + } + case ArgTypeCode.TVMPackedFuncHandle: { + return this.ctx.attachToCurrentScope( + this.makePackedFunc(this.memory.loadPointer(rvaluePtr)) + ); + } + case ArgTypeCode.TVMModuleHandle: { + return this.ctx.attachToCurrentScope( + new Module( + this.memory.loadPointer(rvaluePtr), + this.lib, + (ptr: Pointer) => { + return this.ctx.attachToCurrentScope(this.makePackedFunc(ptr)); + } + ) + ); + } + case ArgTypeCode.TVMObjectHandle: { + const obj = new TVMObject( + this.memory.loadPointer(rvaluePtr), + this.lib, + this.ctx + ); + const func = this.objFactory.get(obj.typeIndex()) + if (func != undefined) { + return this.ctx.attachToCurrentScope( + func(obj.getHandle(), this.lib, this.ctx) + ); + } else { + return this.ctx.attachToCurrentScope(obj); + } + } + case ArgTypeCode.Null: return undefined; + case ArgTypeCode.DLDevice: { + const deviceType = this.memory.loadI32(rvaluePtr); + const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); + return this.device(deviceType, deviceId); + } + case ArgTypeCode.TVMStr: { + const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + return ret; + } + case ArgTypeCode.TVMBytes: { + return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); + } + default: + throw new Error("Unsupported return type code=" + tcode); + } + } +} + +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + * + * @param bufferSource The source to be compiled. + * @param importObject The import objects. + * @param logger The system logger. + */ +export function instantiate( + bufferSource: ArrayBuffer, + importObject: Record = {}, + logger: (msg: string) => void = console.log +): Promise { + const env = new Environment(importObject, logger); + + return WebAssembly.instantiate(bufferSource, env.imports).then( + (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => { + return new Instance(result.module, {}, result.instance, env); + } + ); +} diff --git a/packages/headless/src/worker/lib/tvm/support.d.ts b/packages/headless/src/worker/lib/tvm/support.d.ts new file mode 100644 index 0000000..d5d1b3e --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/support.d.ts @@ -0,0 +1,23 @@ +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +export declare function StringToUint8Array(str: string): Uint8Array; +/** + * Convert Uint8array to string. + * @param array The array. + * @returns The corresponding string. + */ +export declare function Uint8ArrayToString(arr: Uint8Array): string; +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +export declare function assert(condition: boolean, msg?: string): asserts condition; +/** + * Get the path to the wasm library in nodejs. + * @return The wasm path. + */ +export declare function wasmPath(): string; diff --git a/packages/headless/src/worker/lib/tvm/support.ts b/packages/headless/src/worker/lib/tvm/support.ts new file mode 100644 index 0000000..7e2b6e6 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/support.ts @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +export function StringToUint8Array(str: string): Uint8Array { + const arr = new Uint8Array(str.length + 1); + for (let i = 0; i < str.length; ++i) { + arr[i] = str.charCodeAt(i); + } + arr[str.length] = 0; + return arr; +} + +/** + * Convert Uint8array to string. + * @param array The array. + * @returns The corresponding string. + */ +export function Uint8ArrayToString(arr: Uint8Array): string { + const ret = []; + for (const ch of Array.from(arr)) { + ret.push(String.fromCharCode(ch)); + } + return ret.join(""); +} + +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +export function assert(condition: boolean, msg?: string): asserts condition { + if (!condition) { + throw new Error("AssertError:" + (msg || "")); + } +} + +/** + * Get the path to the wasm library in nodejs. + * @return The wasm path. + */ +export function wasmPath(): string { + return "." + // return __dirname + "/wasm"; +} diff --git a/packages/headless/src/worker/lib/tvm/types.d.ts b/packages/headless/src/worker/lib/tvm/types.d.ts new file mode 100644 index 0000000..c8986c5 --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/types.d.ts @@ -0,0 +1,33 @@ +/** Common type definitions. */ +/** + * Library interface provider that can provide + * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime. + * + * It can be viewed as a generalization of imports used in WebAssembly instance creation. + * + * The {@link LibraryProvider.start} callback will be called + * to allow the library provider to initialize related resources during startup time. + * + * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }. + */ +export interface LibraryProvider { + /** The imports that can be passed to WebAssembly instance creation. */ + imports: Record; + /** + * Callback function to notify the provider the created instance. + * @param inst The created instance. + */ + start: (inst: WebAssembly.Instance) => void; +} +/** + * Disposable classes that contains resources (WasmMemory, GPU buffer) + * which needs to be explicitly disposed. + */ +export interface Disposable { + /** + * Dispose the internal resource + * This function can be called multiple times, + * only the first call will take effect. + */ + dispose: () => void; +} diff --git a/packages/headless/src/worker/lib/tvm/types.ts b/packages/headless/src/worker/lib/tvm/types.ts new file mode 100644 index 0000000..621375a --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/types.ts @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** Common type definitions. */ + +/** + * Library interface provider that can provide + * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime. + * + * It can be viewed as a generalization of imports used in WebAssembly instance creation. + * + * The {@link LibraryProvider.start} callback will be called + * to allow the library provider to initialize related resources during startup time. + * + * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }. + */ +export interface LibraryProvider { + /** The imports that can be passed to WebAssembly instance creation. */ + imports: Record; + /** + * Callback function to notify the provider the created instance. + * @param inst The created instance. + */ + start: (inst: WebAssembly.Instance) => void; +} + +/** + * Disposable classes that contains resources (WasmMemory, GPU buffer) + * which needs to be explicitly disposed. + */ +export interface Disposable { + /** + * Dispose the internal resource + * This function can be called multiple times, + * only the first call will take effect. + */ + dispose: () => void; +} diff --git a/packages/headless/src/worker/lib/tvm/webgpu.d.ts b/packages/headless/src/worker/lib/tvm/webgpu.d.ts new file mode 100644 index 0000000..1a53f4e --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/webgpu.d.ts @@ -0,0 +1,124 @@ +/// +import { Memory } from "./memory"; +/** A pointer to points to the raw address space. */ +export type GPUPointer = number; +export interface GPUDeviceDetectOutput { + adapter: GPUAdapter; + adapterInfo: GPUAdapterInfo; + device: GPUDevice; +} +/** + * DetectGPU device in the environment. + */ +export declare function detectGPUDevice(): Promise; +/** + * Function info from the API + */ +export interface FunctionInfo { + name: string; + arg_types: Array; + launch_param_tags: Array; +} +/** + * WebGPU context + * Manages all the webgpu resources here. + */ +export declare class WebGPUContext { + device: GPUDevice; + memory: Memory; + private bufferTable; + private bufferTableFreeId; + private podArgStagingBuffers; + private canvasRenderManager?; + private maxNumPodArgsStagingBuffers; + private peakAllocatedBytes; + private currAllocatedBytes; + private allAllocatedBytes; + private shaderSubmitCounter; + protected debugShaderSubmitLimit: number; + protected debugLogFinish: boolean; + constructor(memory: Memory, device: GPUDevice); + /** + * Dispose context. + */ + dispose(): void; + /** + * Wait for all pending GPU tasks to complete + */ + sync(): Promise; + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string; + /** + * Draw image from data in storage buffer. + * @param ptr The GPU ptr + * @param height The height of the image. + * @param width The width of the image. + */ + drawImageFromBuffer(ptr: GPUPointer, height: number, width: number): void; + /** + * Copy raw bytes into buffer ptr. + * + * @param rawBytes The raw bytes + * @param toPtr The target gpu buffer ptr + * @param toOffset The beginning offset + * @param nbytes Number of bytes + */ + copyRawBytesToBuffer(rawBytes: Uint8Array, toPtr: GPUPointer, toOffset: number, nbytes: number): void; + /** + * Clear canvas + */ + clearCanvas(): void; + /** + * Bind a canvas element to the runtime. + * @param canvas The HTML canvas/ + */ + bindCanvas(canvas: HTMLCanvasElement): void; + /** + * Create a PackedFunc that runs the given shader + * via createComputePipeline + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + createShader(finfo: FunctionInfo, code: string): Function; + /** + * Create a PackedFunc that runs the given shader asynchrously + * via createComputePipelineAsync + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + createShaderAsync(finfo: FunctionInfo, code: string): Promise; + /** + * Get the pod arg staging buffer + * \param nbytes The minimum size. + * \return The allocated buffer + */ + private getPodArgsBuffer; + /** + * Internal impl of createShader for both async and sync mode. + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @param asyncMode Whether use async mode. + * @returns The shader function or promise of shader func. + */ + private createShadeInternal; + /** + * Get the device API according to its name + * @param The name of the API. + * @returns The corresponding device api. + */ + getDeviceAPI(name: string): Function; + private deviceAllocDataSpace; + private deviceFreeDataSpace; + private deviceCopyToGPU; + private deviceCopyFromGPU; + private deviceCopyWithinGPU; + private gpuBufferFromPtr; + private attachToBufferTable; +} diff --git a/packages/headless/src/worker/lib/tvm/webgpu.ts b/packages/headless/src/worker/lib/tvm/webgpu.ts new file mode 100644 index 0000000..16b269f --- /dev/null +++ b/packages/headless/src/worker/lib/tvm/webgpu.ts @@ -0,0 +1,862 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import { Pointer } from "./ctypes"; +import { Memory } from "./memory"; +import { assert } from "./support"; +import { Disposable } from "./types"; + + + +/** A pointer to points to the raw address space. */ +export type GPUPointer = number; + +export interface GPUDeviceDetectOutput { + adapter: GPUAdapter; + adapterInfo: GPUAdapterInfo; + device: GPUDevice; +} + +/** + * DetectGPU device in the environment. + */ +export async function detectGPUDevice(): Promise { + if (typeof navigator !== "undefined" && navigator.gpu !== undefined) { + const adapter = await navigator.gpu.requestAdapter({ "powerPreference": "high-performance" }); + if (adapter == null) { + throw Error("Cannot find adapter that matches the request"); + } + const computeMB = (value: number) => { + return Math.ceil(value / (1 << 20)) + "MB"; + } + + // more detailed error message + const requiedMaxBufferSize = 1 << 30; + if (requiedMaxBufferSize > adapter.limits.maxBufferSize) { + throw Error( + `Cannot initialize runtime because of requested maxBufferSize ` + + `exceeds limit. requested=${computeMB(requiedMaxBufferSize)}, ` + + `limit=${computeMB(adapter.limits.maxBufferSize)}. ` + + `This error may be caused by an older version of the browser (e.g. Chrome 112). ` + + `You can try to upgrade your browser to Chrome 113 or later.` + ); + } + + const requiredMaxStorageBufferBindingSize = 1 << 30; + if (requiredMaxStorageBufferBindingSize > adapter.limits.maxStorageBufferBindingSize) { + throw Error( + `Cannot initialize runtime because of requested maxStorageBufferBindingSize ` + + `exceeds limit. requested=${computeMB(requiredMaxStorageBufferBindingSize)}, ` + + `limit=${computeMB(adapter.limits.maxStorageBufferBindingSize)}. ` + ); + } + + const requiredMaxComputeWorkgroupStorageSize = 32 << 10; + if (requiredMaxComputeWorkgroupStorageSize > adapter.limits.maxComputeWorkgroupStorageSize) { + throw Error( + `Cannot initialize runtime because of requested maxComputeWorkgroupStorageSize ` + + `exceeds limit. requested=${requiredMaxComputeWorkgroupStorageSize}, ` + + `limit=${adapter.limits.maxComputeWorkgroupStorageSize}. ` + ); + } + + const adapterInfo = await adapter.requestAdapterInfo(); + const device = await adapter.requestDevice({ + requiredLimits: { + maxBufferSize: requiedMaxBufferSize, + maxStorageBufferBindingSize: requiredMaxStorageBufferBindingSize, + maxComputeWorkgroupStorageSize: requiredMaxComputeWorkgroupStorageSize, + } + }); + return { + adapter: adapter, + adapterInfo: adapterInfo, + device: device + }; + } else { + return undefined; + } +} + +const canvasRenderWGSL = ` +@group(0) @binding(0) var my_sampler : sampler; +@group(0) @binding(1) var my_texture : texture_2d; + +struct VertexOutput { + @builtin(position) position : vec4, + @location(0) uv : vec2, +} + +@vertex +fn vertex_main(@builtin(vertex_index) vidx : u32) -> VertexOutput { + const pos = array( + vec2( 1.0, 1.0), + vec2( 1.0, -1.0), + vec2(-1.0, -1.0), + vec2( 1.0, 1.0), + vec2(-1.0, -1.0), + vec2(-1.0, 1.0), + ); + + const uv = array( + vec2(1.0, 0.0), + vec2(1.0, 1.0), + vec2(0.0, 1.0), + vec2(1.0, 0.0), + vec2(0.0, 1.0), + vec2(0.0, 0.0), + ); + + var output : VertexOutput; + output.position = vec4(pos[vidx], 0.0, 1.0); + output.uv = uv[vidx]; + return output; +} + +@fragment +fn fragment_main(@location(0) uv : vec2) -> @location(0) vec4 { + return textureSample(my_texture, my_sampler, uv); +} + +@fragment +fn fragment_clear(@location(0) uv : vec2) -> @location(0) vec4 { + return vec4(1.0, 1.0, 1.0, 1.0); +} +` +class CanvaRenderManager implements Disposable { + private device: GPUDevice; + private canvasContext: GPUCanvasContext; + private stagingTexture: GPUTexture; + private renderSampler: GPUSampler; + private renderPipeline: GPURenderPipeline; + private clearPipeline: GPURenderPipeline; + private canvasTextureFormat: GPUTextureFormat; + + constructor(device: GPUDevice, canvas: HTMLCanvasElement) { + this.device = device; + const ctx = canvas.getContext("webgpu"); + if (ctx == null) { + throw Error("Cannot bind WebGPU context"); + } + // @ts-ignore + this.canvasContext = ctx; + this.canvasTextureFormat = navigator.gpu.getPreferredCanvasFormat(); + this.canvasContext.configure({ + device: this.device, + format: this.canvasTextureFormat, + alphaMode: "opaque", + }); + + this.renderPipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "vertex_main", + }, + fragment: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "fragment_main", + targets: [{ + format: this.canvasTextureFormat, + }], + }, + primitive: { + topology: "triangle-list", + }, + }); + + this.clearPipeline = device.createRenderPipeline({ + layout: "auto", + vertex: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "vertex_main", + }, + fragment: { + module: device.createShaderModule({ + code: canvasRenderWGSL, + }), + entryPoint: "fragment_clear", + targets: [{ + format: this.canvasTextureFormat, + }], + }, + primitive: { + topology: "triangle-list", + }, + }); + + this.renderSampler = device.createSampler({ + magFilter: "linear", + minFilter: "linear", + }); + // staging texture always be in RGBA + this.stagingTexture = device.createTexture({ + size: [canvas.height, canvas.width, 1], + format: "rgba8unorm", + usage: + GPUTextureUsage.TEXTURE_BINDING | + GPUTextureUsage.COPY_DST | + GPUTextureUsage.RENDER_ATTACHMENT, + }); + } + + clear() { + const commandEncoder = this.device.createCommandEncoder(); + const passEncoder = commandEncoder.beginRenderPass({ + //@ts-ignore + colorAttachments: [ + { + view: this.canvasContext.getCurrentTexture().createView(), + clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 }, + loadOp: "clear", + storeOp: "store", + }, + ], + }); + passEncoder.setPipeline(this.clearPipeline); + const renderBindingGroup = this.device.createBindGroup({ + layout: this.renderPipeline.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.renderSampler }, + { binding: 1, resource: this.stagingTexture.createView() }, + ], + }); + passEncoder.setBindGroup(0, renderBindingGroup); + passEncoder.draw(6, 1, 0, 0); + passEncoder.end(); + this.device.queue.submit([commandEncoder.finish()]); + } + + draw(buffer: GPUBuffer, height: number, width: number) { + // resize the staging texture + if (height != this.stagingTexture.height || width != this.stagingTexture.width) { + this.stagingTexture.destroy(); + this.stagingTexture = this.device.createTexture({ + size: [height, width, 1], + format: "rgba8unorm", + usage: + GPUTextureUsage.TEXTURE_BINDING | + GPUTextureUsage.COPY_DST | + GPUTextureUsage.RENDER_ATTACHMENT, + }); + } + + const commandEncoder = this.device.createCommandEncoder(); + commandEncoder.copyBufferToTexture({ + buffer: buffer, + offset: 0, + bytesPerRow: this.stagingTexture.width * 4 + }, { + texture: this.stagingTexture + }, { + width: this.stagingTexture.width, + height: this.stagingTexture.height + }); + + const passEncoder = commandEncoder.beginRenderPass({ + //@ts-ignore + colorAttachments: [ + { + view: this.canvasContext.getCurrentTexture().createView(), + clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 }, + loadOp: "clear", + storeOp: "store", + }, + ], + }); + passEncoder.setPipeline(this.renderPipeline); + const renderBindingGroup = this.device.createBindGroup({ + layout: this.renderPipeline.getBindGroupLayout(0), + entries: [ + { binding: 0, resource: this.renderSampler }, + { binding: 1, resource: this.stagingTexture.createView() }, + ], + }); + passEncoder.setBindGroup(0, renderBindingGroup); + passEncoder.draw(6, 1, 0, 0); + passEncoder.end(); + this.device.queue.submit([commandEncoder.finish()]); + } + + dispose(): void { + this.stagingTexture.destroy(); + } +} + +/** + * Function info from the API + */ +export interface FunctionInfo { + name: string; + arg_types: Array; + launch_param_tags: Array; +} + +/** + * WebGPU context + * Manages all the webgpu resources here. + */ +export class WebGPUContext { + device: GPUDevice; + memory: Memory; + // internal data + private bufferTable: Array = [undefined]; + private bufferTableFreeId: Array = []; + private podArgStagingBuffers: Array = []; + private canvasRenderManager?: CanvaRenderManager = undefined; + // number of pod arg staging buffers + private maxNumPodArgsStagingBuffers: number = 2; + // flags for debugging + // stats of the runtime. + // peak allocation + private peakAllocatedBytes: number = 0; + // current allocation + private currAllocatedBytes: number = 0; + // all allocation(ignoring free) + private allAllocatedBytes: number = 0; + // shader submit counter + private shaderSubmitCounter: number = 0; + // limite number of shaders to be submitted, useful for debugging, default to -1 + protected debugShaderSubmitLimit: number = -1; + // log and sync each step + protected debugLogFinish: boolean = false; + + constructor(memory: Memory, device: GPUDevice) { + this.memory = memory; + this.device = device; + } + + /** + * Dispose context. + */ + dispose() { + this.canvasRenderManager?.dispose(); + this.bufferTableFreeId = []; + while (this.bufferTable.length != 0) { + this.bufferTable.pop()?.destroy(); + } + while (this.podArgStagingBuffers.length != 0) { + this.podArgStagingBuffers.pop()?.destroy(); + } + this.device.destroy(); + } + + /** + * Wait for all pending GPU tasks to complete + */ + async sync(): Promise { + await this.device.queue.onSubmittedWorkDone(); + } + + /** + * Obtain the runtime information in readable format. + */ + runtimeStatsText(): string { + let info = "peak-memory=" + Math.ceil(this.peakAllocatedBytes / (1 << 20)) + " MB"; + info += ", all-memory=" + Math.ceil(this.allAllocatedBytes / (1 << 20)) + " MB"; + info += ", shader-submissions=" + this.shaderSubmitCounter; + return info; + } + + /** + * Draw image from data in storage buffer. + * @param ptr The GPU ptr + * @param height The height of the image. + * @param width The width of the image. + */ + drawImageFromBuffer(ptr: GPUPointer, height: number, width: number) { + if (this.canvasRenderManager == undefined) { + throw Error("Do not have a canvas context, call bindCanvas first"); + } + this.canvasRenderManager.draw(this.gpuBufferFromPtr(ptr), height, width); + } + + /** + * Copy raw bytes into buffer ptr. + * + * @param rawBytes The raw bytes + * @param toPtr The target gpu buffer ptr + * @param toOffset The beginning offset + * @param nbytes Number of bytes + */ + copyRawBytesToBuffer( + rawBytes: Uint8Array, + toPtr: GPUPointer, + toOffset: number, + nbytes: number + ): void { + // Perhaps it would be more useful to use a staging buffer? + this.device.queue.writeBuffer( + this.gpuBufferFromPtr(toPtr), + toOffset, + rawBytes, + 0, + nbytes + ); + } + /** + * Clear canvas + */ + clearCanvas() { + this.canvasRenderManager?.clear(); + } + + /** + * Bind a canvas element to the runtime. + * @param canvas The HTML canvas/ + */ + bindCanvas(canvas: HTMLCanvasElement) { + this.canvasRenderManager = new CanvaRenderManager(this.device, canvas); + } + + /** + * Create a PackedFunc that runs the given shader + * via createComputePipeline + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + createShader(finfo: FunctionInfo, code: string): Function { + return this.createShadeInternal(finfo, code, false) as Function; + } + + /** + * Create a PackedFunc that runs the given shader asynchrously + * via createComputePipelineAsync + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @returns The shader + */ + async createShaderAsync(finfo: FunctionInfo, code: string): Promise { + return await (this.createShadeInternal(finfo, code, true) as Promise); + } + + /** + * Get the pod arg staging buffer + * \param nbytes The minimum size. + * \return The allocated buffer + */ + private getPodArgsBuffer(nbytes: number): GPUBuffer { + let buffer: GPUBuffer | undefined = undefined; + if (this.podArgStagingBuffers.length >= this.maxNumPodArgsStagingBuffers) { + buffer = this.podArgStagingBuffers.shift(); + } + // minimum of 16 bytes + let allocSize = 16; + if (buffer !== undefined) { + allocSize = buffer.size; + if (buffer.size < nbytes) { + buffer.destroy(); + buffer = undefined; + } + } + while (allocSize < nbytes) { + allocSize *= 2; + } + + if (buffer == undefined) { + // create uniform buffer + buffer = this.device.createBuffer({ + size: allocSize, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, + }); + } + assert(nbytes <= buffer.size); + return buffer; + } + + /** + * Internal impl of createShader for both async and sync mode. + * + * @param info The function information already parsed as a record. + * @param code The shader data(in WGSL) + * @param asyncMode Whether use async mode. + * @returns The shader function or promise of shader func. + */ + private createShadeInternal( + finfo: FunctionInfo, + code: string, + asyncMode: boolean + ): Function | Promise { + const dispatchToDim: Array = []; + let paramWriteAccess: Array = []; + + for (let i = 0; i < finfo.launch_param_tags.length; ++i) { + const tag: string = finfo.launch_param_tags[i]; + if (tag.startsWith("blockIdx.")) { + const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target); + } else if (tag.startsWith("threadIdx.")) { + const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target + 3); + } else if (tag.startsWith("paramWriteAccess:")) { + paramWriteAccess = JSON.parse(tag.substring(17)); + } else { + throw new Error("Cannot handle thread_axis " + tag); + } + } + + + const layoutEntries: Array = []; + const bufferArgIndices: Array = []; + const podArgIndices: Array = []; + + for (let i = 0; i < finfo.arg_types.length; ++i) { + const dtype = finfo.arg_types[i]; + if (dtype == "handle") { + layoutEntries.push({ + binding: bufferArgIndices.length, + visibility: GPUShaderStage.COMPUTE, + buffer: { + type: paramWriteAccess[bufferArgIndices.length] ? "storage" : "read-only-storage" + } + }); + bufferArgIndices.push(i); + } else if (dtype.startsWith("int") || dtype.startsWith("uint") || dtype.startsWith("float")) { + podArgIndices.push(i); + } else { + throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader"); + } + } + + assert(paramWriteAccess.length == bufferArgIndices.length); + // POD arguments are pass in the end + layoutEntries.push({ + binding: bufferArgIndices.length, + visibility: GPUShaderStage.COMPUTE, + buffer: { + type: "uniform" + } + }); + + const bindGroupLayout = this.device.createBindGroupLayout({ + entries: layoutEntries + }); + const pipelineLayout = this.device.createPipelineLayout({ + bindGroupLayouts: [bindGroupLayout] + }); + + // Function to create the pipeline. + const createShaderFunc = (pipeline: GPUComputePipeline): Function => { + const submitShader = (...args: Array): void => { + if (this.debugShaderSubmitLimit != -1 && + this.shaderSubmitCounter >= this.debugShaderSubmitLimit) { + this.shaderSubmitCounter += 1; + return; + } + + const commandEncoder = this.device.createCommandEncoder(); + const compute = commandEncoder.beginComputePass(); + compute.setPipeline(pipeline); + const bindGroupEntries: Array = []; + const numBufferOrPodArgs = bufferArgIndices.length + podArgIndices.length; + + assert(args.length == numBufferOrPodArgs + dispatchToDim.length); + + const workDim: Array = [1, 1, 1, 1, 1, 1]; + for (let i = 0; i < dispatchToDim.length; ++i) { + workDim[dispatchToDim[i]] = args[numBufferOrPodArgs + i]; + } + + // get around 65535 restriction of blockIdx.x + if (workDim[2] != 1) { + throw Error("WebGPU: blockIdx.z is reserved for internal use"); + } + const packDimX = workDim[0]; + // spread thinsg out into blockIdx.z + if (workDim[0] >= (1 << 16)) { + let wl_x = workDim[0]; + let wl_z = workDim[2]; + + while (wl_x >= (1 << 16)) { + if (wl_x % 2 == 0) { + wl_x = wl_x / 2; + } else { + // pad up + wl_x = (wl_x + 1) / 2; + } + wl_z *= 2; + } + workDim[0] = wl_x; + workDim[2] = wl_z; + assert(wl_x * wl_z >= packDimX); + } + + for (let i = 0; i < bufferArgIndices.length; ++i) { + bindGroupEntries.push({ + binding: i, + resource: { + buffer: this.gpuBufferFromPtr(args[bufferArgIndices[i]]) + } + }); + } + + // push pod buffer + const sizeOfI32 = 4; + const podArgBuffer = this.getPodArgsBuffer((podArgIndices.length + 1) * sizeOfI32); + const i32View = new Int32Array(podArgIndices.length + 1); + const u32View = new Uint32Array(i32View.buffer); + const f32View = new Float32Array(i32View.buffer); + + for (let i = 0; i < podArgIndices.length; ++i) { + const value = args[podArgIndices[i]]; + const dtype = finfo.arg_types[podArgIndices[i]]; + if (dtype.startsWith("int")) { + i32View[i] = value; + } else if (dtype.startsWith("uint")) { + u32View[i] = value; + } else if (dtype.startsWith("float")) { + f32View[i] = value; + } else { + throw Error("Unknown pod dtype " + dtype); + } + } + // always pass in dim z launching grid size in + u32View[podArgIndices.length] = packDimX; + this.device.queue.writeBuffer(podArgBuffer, 0, i32View.buffer); + + bindGroupEntries.push({ + binding: bufferArgIndices.length, + resource: { + buffer: podArgBuffer, + size: i32View.buffer.byteLength + } + }); + + compute.setBindGroup(0, this.device.createBindGroup({ + layout: bindGroupLayout, + entries: bindGroupEntries + })); + + compute.dispatchWorkgroups(workDim[0], workDim[1], workDim[2]) + compute.end() + const command = commandEncoder.finish(); + this.device.queue.submit([command]); + + if (this.debugLogFinish) { + const currCounter = this.shaderSubmitCounter; + this.device.queue.onSubmittedWorkDone().then(() => { + // console.log("[" + currCounter + "][Debug] finish shader" + finfo.name); + }); + } + this.shaderSubmitCounter += 1; + }; + return submitShader; + }; + + const shaderModule = this.device.createShaderModule({ + code: code, + hints: { + main: { + layout: pipelineLayout + } + } + }); + + if (asyncMode) { + return this.device.createComputePipelineAsync({ + layout: pipelineLayout, + compute: { + module: shaderModule, + entryPoint: finfo.name + } + }).then((pipeline: GPUComputePipeline) => { + return createShaderFunc(pipeline); + }); + } else { + const pipeline = this.device.createComputePipeline({ + layout: pipelineLayout, + compute: { + module: shaderModule, + entryPoint: finfo.name + } + }); + return createShaderFunc(pipeline); + } + } + + /** + * Get the device API according to its name + * @param The name of the API. + * @returns The corresponding device api. + */ + getDeviceAPI(name: string): Function { + if (name == "deviceAllocDataSpace") { + return (nbytes: number): GPUPointer => { + return this.deviceAllocDataSpace(nbytes); + }; + } else if (name == "deviceFreeDataSpace") { + return (ptr: GPUPointer): void => { + return this.deviceFreeDataSpace(ptr); + }; + } else if (name == "deviceCopyToGPU") { + return ( + from: Pointer, + to: GPUPointer, + toOffset: number, + nbytes: number + ): void => { + this.deviceCopyToGPU(from, to, toOffset, nbytes); + }; + } else if (name == "deviceCopyFromGPU") { + return ( + from: GPUPointer, + fromOffset: number, + to: Pointer, + nbytes: number + ): void => { + this.deviceCopyFromGPU(from, fromOffset, to, nbytes); + }; + } else if (name == "deviceCopyWithinGPU") { + return ( + from: GPUPointer, + fromOffset: number, + to: Pointer, + toOffset: number, + nbytes: number + ): void => { + this.deviceCopyWithinGPU(from, fromOffset, to, toOffset, nbytes); + }; + } else { + throw new Error("Unknown DeviceAPI function " + name); + } + } + + // DeviceAPI + private deviceAllocDataSpace(nbytes: number): GPUPointer { + // allocate 0 bytes buffer as 1 bytes buffer. + if (nbytes == 0) { + nbytes = 1; + } + const buffer = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + this.currAllocatedBytes += nbytes; + this.allAllocatedBytes += nbytes; + if (this.currAllocatedBytes > this.peakAllocatedBytes) { + this.peakAllocatedBytes = this.currAllocatedBytes; + } + const ptr = this.attachToBufferTable(buffer); + return ptr; + } + + private deviceFreeDataSpace(ptr: GPUPointer): void { + const idx = ptr; + const buffer = this.bufferTable[idx]; + this.bufferTable[idx] = undefined; + assert(buffer !== undefined); + this.bufferTableFreeId.push(idx); + this.currAllocatedBytes -= buffer.size; + buffer.destroy(); + } + + private deviceCopyToGPU( + from: Pointer, + to: GPUPointer, + toOffset: number, + nbytes: number + ): void { + // Perhaps it would be more useful to use a staging buffer? + const rawBytes = this.memory.loadRawBytes(from, nbytes); + this.device.queue.writeBuffer( + this.gpuBufferFromPtr(to), + toOffset, + rawBytes, + 0, + nbytes + ); + } + + private deviceCopyFromGPU( + from: GPUPointer, + fromOffset: number, + to: Pointer, + nbytes: number + ): void { + // Perhaps it would be more useful to resuse a staging buffer? + const gpuTemp = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + this.gpuBufferFromPtr(from), + fromOffset, + gpuTemp, + 0, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.queue.submit([copyCommands]); + + gpuTemp.mapAsync(GPUMapMode.READ).then(() => { + const data = gpuTemp.getMappedRange(); + this.memory.storeRawBytes(to, new Uint8Array(data)); + gpuTemp.destroy(); + }); + } + + private deviceCopyWithinGPU( + from: GPUPointer, + fromOffset: number, + to: Pointer, + toOffset: number, + nbytes: number + ): void { + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + this.gpuBufferFromPtr(from), + fromOffset, + this.gpuBufferFromPtr(to), + toOffset, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.queue.submit([copyCommands]); + } + + private gpuBufferFromPtr(ptr: GPUPointer): GPUBuffer { + const buffer = this.bufferTable[ptr]; + assert(buffer !== undefined); + return buffer; + } + + private attachToBufferTable(buffer: GPUBuffer): GPUPointer { + if (this.bufferTableFreeId.length != 0) { + const idx = this.bufferTableFreeId.pop() as number; + this.bufferTable[idx] = buffer; + return idx; + } else { + const idx = this.bufferTable.length; + this.bufferTable.push(buffer); + return idx; + } + } +} diff --git a/packages/headless/src/worker/llm.ts b/packages/headless/src/worker/llm.ts new file mode 100644 index 0000000..180e640 --- /dev/null +++ b/packages/headless/src/worker/llm.ts @@ -0,0 +1,337 @@ +import { v4 as uuidv4 } from "uuid"; +import { Conversation } from "../types/chat"; +import { GenerateTextCallback, GenerateTextRequest } from "../types/worker_message"; +import { detectGPUDevice, instantiate } from "./lib/tvm"; +import { InitProgressCallback } from "./lib/tvm/runtime"; +import { Config } from "./worker"; + +export class LLMInstance { + config: Config; + tvm: any; + tokenizer: any; + model: any; + spp: any; + processing: boolean; + + constructor(config: Config, sentencePieceProcessor: any) { + this.config = config; + this.tvm = undefined; + this.tokenizer = undefined; + this.model = undefined; + this.spp = sentencePieceProcessor; + this.processing = false; + } + + isInitialized() { + return this.model != undefined; + } + + async init(cb: InitProgressCallback) { + if (this.model) { + return; + } + const wasmSource = await (await fetch(this.config.wasmUrl)).arrayBuffer(); + this.tvm = await instantiate( + new Uint8Array(wasmSource), + //@ts-ignore + new EmccWASI(), + console.log, + ); + try { + const output = await detectGPUDevice(); + if (output !== undefined) { + this.tvm.initWebGPU(output.device); + } else { + throw Error("This browser env do not support WebGPU"); + } + } catch (err: any) { + throw Error("Find an error initializing WebGPU: " + err.toString()); + } + this.tvm.registerInitProgressCallback(cb); + await this.tvm.fetchNDArrayCache(this.config.cacheUrl, this.tvm.webgpu()); + + this.tokenizer = await this.spp()(this.config.tokenizerUrl); + this.model = this.tvm.withNewScope(() => { + return new LLMInstanceScope( + this.tvm, + this.tokenizer, + this.config.maxWindowSize + ); + }); + return this.model.init(); + } + + async generate(request: GenerateTextRequest, cb: GenerateTextCallback) { + if (this.processing) { + return; + } + this.processing = true; + await this.model.generate(request, cb); + this.processing = false; + } +} + +export class LLMInstanceScope { + tvm: any; + tokenizer: any; + maxWindowSize: number; + device: any; + vm: any; + encoding: any; + decoding: any; + params: any; + bosTokenId: number; + eosTokenId: number; + fclearKVCaches: any; + kvCache: any; + fcreateCache: any; + logitsOnCPU: any; + kvCacheLength: number; + lastMessageId: string; + + constructor(tvm: any, tokenizer: any, maxWindowSize = 2048) { + this.tvm = tvm; + this.tokenizer = tokenizer; + + this.bosTokenId = 1; + this.eosTokenId = 2; + + this.maxWindowSize = maxWindowSize; + + this.device = this.tvm.webgpu(); + + this.vm = this.tvm.detachFromCurrentScope( + this.tvm.createVirtualMachine(this.device) + ); + this.encoding = this.tvm.detachFromCurrentScope( + this.vm.getFunction("encoding") + ); + this.decoding = this.tvm.detachFromCurrentScope( + this.vm.getFunction("decoding") + ); + this.params = this.tvm.detachFromCurrentScope( + this.tvm.getParamsFromCache("param", this.tvm.cacheMetadata.ParamSize) + ); + const fcreateCache = this.vm.getFunction("create_kv_cache"); + this.fclearKVCaches = this.tvm.detachFromCurrentScope( + this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear") + ); + + // use extern config for now + this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache()); + // fill with pad token + this.logitsOnCPU = undefined; + + this.kvCacheLength = 0; + this.lastMessageId = ""; + } + + async init() { + await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule()); + } + + async getTokensFromStart(conversation: Conversation, maxTokens: number) { + this.clearKVCache(); + const tokens = []; + + for (let i = conversation.messages.length - 1; i >= 0; i--) { + const message = conversation.messages[i]; + const text = `${message.role}: ${message.text}\n`; + const messageTokens = await this.tokenizer.encodeIds(text); + if ( + tokens.length + messageTokens.length + maxTokens > + this.maxWindowSize + ) { + break; + } + tokens.unshift(...(await this.tokenizer.encodeIds(text))); + } + tokens.unshift( + ...(await this.tokenizer.encodeIds(conversation.systemPrompt)) + ); + tokens.unshift(this.bosTokenId); + + return tokens; + } + + async getTokens(conversation: Conversation, maxTokens: number) { + // Case 1. Attention Cache is empty, start from beginning + // Case 2. Attention Cache is not empty, but the last message we processed is not in the cache, start from beginning + // Case 3. Attention Cache is not empty, and the last message we processed is in the cache, start from the next message + // Case 4. Attention Cache is not empty, and the last message we processed is in the cache, but the cache is too long, start from beginning + if (this.kvCacheLength == 0) { + // Case 1 + return await this.getTokensFromStart(conversation, maxTokens); + } + + // Calculate the index of the last message we processed + let startMsgIdx = 0; + for (let i = conversation.messages.length - 1; i >= 0; i--) { + if (conversation.messages[i].id == this.lastMessageId) { + startMsgIdx = i + 1; + break; + } + } + + if (startMsgIdx == 0) { + // Case 2 + return await this.getTokensFromStart(conversation, maxTokens); + } + + const tokens = [this.eosTokenId]; + for (let i = startMsgIdx; i < conversation.messages.length; i++) { + const message = conversation.messages[i]; + const text = `${message.role}: ${message.text}`; + const messageTokens = await this.tokenizer.encodeIds(text); + if ( + tokens.length + messageTokens.length + maxTokens > + this.maxWindowSize + ) { + // Case 4 + return await this.getTokensFromStart(conversation, maxTokens); + } + tokens.push(...(await this.tokenizer.encodeIds(text))); + } + + // Case 3 + return tokens; + } + + async generate(request: GenerateTextRequest, cb: GenerateTextCallback) { + const { conversation, maxTokens, assistantRoleName, stopTexts } = request; + const tokens = await this.getTokens(conversation, maxTokens); + tokens.push(...(await this.tokenizer.encodeIds(`${assistantRoleName}:`))); + console.log("debug: ", await this.tokenizer.decodeIds(tokens)); + + const inputTokenLength = tokens.length; + let outputText = ""; + let tstart = 0, + tend = 0, step = 0; + + const id = uuidv4(); + for (; step < maxTokens; step++) { + this.tvm.beginScope(); + tstart = performance.now(); + var input; + if (step == 0) { + input = this.tvm.empty([1, tokens.length], "int32", this.device); + input.copyFrom(tokens); + } else { + input = this.tvm.empty([1, 1], "int32", this.device); + input.copyFrom(tokens.slice(tokens.length - 1)); + } + const logits = this.tvm.detachFromCurrentScope( + this.forward(input, this.kvCacheLength + inputTokenLength + step) + ); + this.tvm.endScope(); + const nextToken = await this.sampleTokenFromLogits(logits); + logits.dispose(); + + tokens.push(nextToken); + const outputTokens = tokens.slice(inputTokenLength); + outputText = this.tokenizer.decodeIds(outputTokens); + tend = performance.now(); + if (nextToken == this.eosTokenId) break; + const stopPos = outputText.lastIndexOf(""); + if (stopPos != -1) { + outputText = outputText.substring(0, stopPos); + break; + } + let stop = false; + for (let i = 0; i < stopTexts.length; i++) { + if (outputText.endsWith(stopTexts[i])) { + outputText = outputText.substring( + 0, + outputText.length - stopTexts[i].length + ); + stop = true; + break; + } + } + if (stop) break; + if (step != 0) { + cb({ + requestId: id, + step: step, + outputText, + stats: { + totalDecodingSeconds: (tend - tstart) / 1000, + totalDecodedTokens: tokens.length - inputTokenLength, + totalEncodedTokens: inputTokenLength, + }, + isFinished: false, + }); + } + } + this.kvCacheLength += tokens.length - 1; + this.lastMessageId = id; + + cb({ + requestId: id, + outputText, + step: step, + stats: { + totalDecodingSeconds: (tend - tstart) / 1000, + totalDecodedTokens: tokens.length - inputTokenLength, + totalEncodedTokens: inputTokenLength, + }, + isFinished: true, + }); + } + + dispose() { + // note: tvm instance is not owned by this class + this.params.dispose(); + this.decoding.dispose(); + this.encoding.dispose(); + this.vm.dispose(); + this.kvCache.dispose(); + this.fclearKVCaches.dispose(); + if (this.logitsOnCPU != undefined) { + this.logitsOnCPU.dispose(); + } + } + + clearKVCache() { + this.fclearKVCaches(this.kvCache); + this.kvCacheLength = 0; + this.lastMessageId = ""; + } + + forward(inputs: any, curPos: number) { + this.tvm.beginScope(); + var retValue; + const seqLenShape = this.tvm.makeShapeTuple([curPos]); + if (inputs.shape[1] > 1) { + retValue = this.encoding(inputs, seqLenShape, this.kvCache, this.params); + } else { + retValue = this.decoding(inputs, seqLenShape, this.kvCache, this.params); + } + const logits = this.tvm.detachFromCurrentScope(retValue.get(0)); + this.tvm.endScope(); + this.tvm.attachToCurrentScope(logits); + return logits; + } + + // NOTE: caller must call device.sync() + updateLogitsOnCPU(logits: any) { + if (this.logitsOnCPU == undefined) { + this.logitsOnCPU = this.tvm.detachFromCurrentScope( + this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu()) + ); + } else { + if (logits.shape[0] != this.logitsOnCPU.shape[0]) { + throw Error("We expect the size of logits to remain unchanged"); + } + } + this.logitsOnCPU.copyFrom(logits); + } + + async sampleTokenFromLogits(logits: any, temperature = 0.8, top_p = 0.95) { + this.tvm.beginScope(); + this.updateLogitsOnCPU(logits); + this.tvm.endScope(); + await this.device.sync(); + return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p); + } +} diff --git a/packages/headless/src/worker/worker.ts b/packages/headless/src/worker/worker.ts new file mode 100644 index 0000000..769ae3b --- /dev/null +++ b/packages/headless/src/worker/worker.ts @@ -0,0 +1,54 @@ +import * as Comlink from "comlink"; +import { GenerateTextCallback, GenerateTextRequest, ModelWorker } from "../types/worker_message"; +import { InitProgressCallback } from '../worker/lib/tvm/runtime'; +import { LLMInstance } from '../worker/llm'; + +declare global { + var importScripts: (...url: string[]) => void; + var sentencepiece: { + sentencePieceProcessor: (url: string) => void; + }; +} + +const config = { + kvConfig: { + numLayers: 64, + shape: [32, 32, 128], + dtype: 'float32', + }, + wasmUrl: 'https://huggingface.co/mrick/react-llm/resolve/main/models/vicuna-7b-v1/vicuna-7b-v1_webgpu.wasm', + cacheUrl: 'https://huggingface.co/mrick/react-llm/resolve/main/models/vicuna-7b-v1/params/', + tokenizerUrl: 'https://huggingface.co/mrick/react-llm/resolve/main/models/vicuna-7b-v1/tokenizer.model', + sentencePieceJsUrl: 'https://cdn.matt-rickard.com/code/sentencepiece.js', + tvmRuntimeJsUrl: 'https://cdn.matt-rickard.com/code/tvmjs_runtime.wasi.js', + maxWindowSize: 2048, +} as Config; + +export type Config = { + kvConfig: { + numLayers: number; + shape: number[]; + dtype: string; + }; + wasmUrl: string; + cacheUrl: string; + tokenizerUrl: string; + sentencePieceJsUrl: string; + tvmRuntimeJsUrl: string; + maxWindowSize: number; +} +const instance = new LLMInstance(config, () => globalThis.sentencepiece.sentencePieceProcessor); +const worker = { + init(callback: Comlink.ProxyOrClone) { + instance.init(callback); + }, + generate(request: GenerateTextRequest, cb: Comlink.ProxyOrClone) { + instance.generate(request, cb); + } +} as ModelWorker; + +importScripts(...[ + config.sentencePieceJsUrl, config.tvmRuntimeJsUrl +]); + +Comlink.expose(worker); diff --git a/packages/headless/tsconfig.json b/packages/headless/tsconfig.json new file mode 100644 index 0000000..3aeccbc --- /dev/null +++ b/packages/headless/tsconfig.json @@ -0,0 +1,42 @@ +{ + "compilerOptions": { + "target": "es5", + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "noEmit": false, + "composite": true, + "allowJs": true, + "skipLibCheck": true, + "declaration": true, + "declarationDir": "./dist/types", + "strict": true, + "forceConsistentCasingInFileNames": true, + "esModuleInterop": true, + "module": "esnext", + "moduleResolution": "node", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "react", + "types": [ + "@webgpu/types" + ], + "typeRoots": [ + "./node_modules/@types" + ], + "paths": { + "@/*": [ + "./src/*" + ], + }, + }, + "include": [ + "src/**/*.ts", + "src/**/*.tsx", + ], + "exclude": [ + "node_modules" + ] +} \ No newline at end of file diff --git a/packages/retro-ui/.eslintrc.json b/packages/retro-ui/.eslintrc.json new file mode 100644 index 0000000..bffb357 --- /dev/null +++ b/packages/retro-ui/.eslintrc.json @@ -0,0 +1,3 @@ +{ + "extends": "next/core-web-vitals" +} diff --git a/packages/retro-ui/.gitignore b/packages/retro-ui/.gitignore new file mode 100644 index 0000000..8f322f0 --- /dev/null +++ b/packages/retro-ui/.gitignore @@ -0,0 +1,35 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env*.local + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts diff --git a/packages/retro-ui/LICENSE b/packages/retro-ui/LICENSE new file mode 100644 index 0000000..052f695 --- /dev/null +++ b/packages/retro-ui/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Matt Rickard + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/packages/retro-ui/README.md b/packages/retro-ui/README.md new file mode 100644 index 0000000..f4da3c4 --- /dev/null +++ b/packages/retro-ui/README.md @@ -0,0 +1,34 @@ +This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app). + +## Getting Started + +First, run the development server: + +```bash +npm run dev +# or +yarn dev +# or +pnpm dev +``` + +Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. + +You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file. + +This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font. + +## Learn More + +To learn more about Next.js, take a look at the following resources: + +- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API. +- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial. + +You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome! + +## Deploy on Vercel + +The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js. + +Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details. diff --git a/packages/retro-ui/next.config.js b/packages/retro-ui/next.config.js new file mode 100644 index 0000000..d95f1f8 --- /dev/null +++ b/packages/retro-ui/next.config.js @@ -0,0 +1,28 @@ +/** @type {import('next').NextConfig} */ +const nextConfig = { + transpilePackages: ["react95"], + webpack(config, options) { + const { isServer } = options; + config.module.rules.push({ + test: /\.(ogg|mp3|wav|mpe?g)$/i, + exclude: config.exclude, + use: [ + { + loader: require.resolve("url-loader"), + options: { + limit: config.inlineImageLimit, + fallback: require.resolve("file-loader"), + publicPath: `${config.assetPrefix}/_next/static/images/`, + outputPath: `${isServer ? "../" : ""}static/images/`, + name: "[name]-[hash].[ext]", + esModule: config.esModule || false, + }, + }, + ], + }); + + return config; + }, +}; + +module.exports = nextConfig; diff --git a/packages/retro-ui/package.json b/packages/retro-ui/package.json new file mode 100644 index 0000000..db3c8cb --- /dev/null +++ b/packages/retro-ui/package.json @@ -0,0 +1,39 @@ +{ + "name": "retro-ui", + "version": "0.0.1", + "private": true, + "scripts": { + "dev": "next dev", + "build": "next build", + "start": "next start", + "lint": "next lint" + }, + "keywords": [ + "chatgpt", + "llm", + "headless", + "react" + ], + "dependencies": { + "@react-llm/headless": "workspace:*", + "@types/node": "20.1.3", + "@types/react": "18.2.6", + "@types/react-dom": "18.2.4", + "eslint": "8.40.0", + "eslint-config-next": "13.4.2", + "next": "13.4.2", + "react": "18.2.0", + "react-dom": "18.2.0", + "react95": "^4.0.0", + "typescript": "5.0.4", + "use-sound": "^4.0.1" + }, + "devDependencies": { + "@types/styled-components": "^5.1.26", + "autoprefixer": "10.4.14", + "file-loader": "^6.2.0", + "postcss": "^8.4.23", + "tailwindcss": "^3.3.2", + "url-loader": "^4.1.1" + } +} diff --git a/packages/retro-ui/postcss.config.js b/packages/retro-ui/postcss.config.js new file mode 100644 index 0000000..4df9d3b --- /dev/null +++ b/packages/retro-ui/postcss.config.js @@ -0,0 +1,6 @@ +module.exports = { + plugins: { + autoprefixer: {}, + tailwindcss: {}, + }, +}; diff --git a/packages/retro-ui/public/buddy88.gif b/packages/retro-ui/public/buddy88.gif new file mode 100644 index 0000000..fed6cd6 Binary files /dev/null and b/packages/retro-ui/public/buddy88.gif differ diff --git a/packages/retro-ui/public/favicon.ico b/packages/retro-ui/public/favicon.ico new file mode 100644 index 0000000..7df2ee9 Binary files /dev/null and b/packages/retro-ui/public/favicon.ico differ diff --git a/packages/retro-ui/public/sounds/imrcv.wav b/packages/retro-ui/public/sounds/imrcv.wav new file mode 100644 index 0000000..20cf23d Binary files /dev/null and b/packages/retro-ui/public/sounds/imrcv.wav differ diff --git a/packages/retro-ui/public/sounds/imsend.wav b/packages/retro-ui/public/sounds/imsend.wav new file mode 100644 index 0000000..a24b5e3 Binary files /dev/null and b/packages/retro-ui/public/sounds/imsend.wav differ diff --git a/packages/retro-ui/public/xp.jpeg b/packages/retro-ui/public/xp.jpeg new file mode 100644 index 0000000..3d4049e Binary files /dev/null and b/packages/retro-ui/public/xp.jpeg differ diff --git a/packages/retro-ui/src/app/layout.tsx b/packages/retro-ui/src/app/layout.tsx new file mode 100644 index 0000000..f3bef18 --- /dev/null +++ b/packages/retro-ui/src/app/layout.tsx @@ -0,0 +1,25 @@ +import React from "react"; +import "../styles/globals.css"; + +export const metadata = { + title: "Local LLM Chat", + description: "Chat with an in-browser Vicuna LLM model", +}; + +export default function RootLayout({ + children, +}: { + children: React.ReactNode; +}) { + return ( + + + {children} + + + ); +} diff --git a/packages/retro-ui/src/app/page.jsx b/packages/retro-ui/src/app/page.jsx new file mode 100644 index 0000000..ec1aebf --- /dev/null +++ b/packages/retro-ui/src/app/page.jsx @@ -0,0 +1,11 @@ +"use client"; +import Chat from "@/components/Chat"; +import { ModelProvider } from "@react-llm/headless"; + +export default function Home() { + return ( + + + + ); +} diff --git a/packages/retro-ui/src/assets/fonts/ms_sans_serif.woff2 b/packages/retro-ui/src/assets/fonts/ms_sans_serif.woff2 new file mode 100644 index 0000000..83ea806 Binary files /dev/null and b/packages/retro-ui/src/assets/fonts/ms_sans_serif.woff2 differ diff --git a/packages/retro-ui/src/components/Chat.tsx b/packages/retro-ui/src/components/Chat.tsx new file mode 100644 index 0000000..f6b0cfc --- /dev/null +++ b/packages/retro-ui/src/components/Chat.tsx @@ -0,0 +1,143 @@ +"use client"; +import localFont from "next/font/local"; +import { useState } from "react"; +import { Anchor, AppBar, Button } from "react95"; +import highContrast from "react95/dist/themes/highContrast"; +import matrix from "react95/dist/themes/matrix"; +import millenium from "react95/dist/themes/millenium"; +import modernDark from "react95/dist/themes/modernDark"; +import original from "react95/dist/themes/original"; +import rose from "react95/dist/themes/rose"; +import { ThemeProvider, type CSSProperties } from "styled-components"; +import ChatWindow from "./ChatWindow"; +import ConversationList from "./ConversationList"; +import Options from "./Options"; +const myFont = localFont({ src: "../assets/fonts/ms_sans_serif.woff2" }); + +export const themeList = [ + { + value: original, + label: "Original", + }, + { + value: highContrast, + label: "High Contrast", + }, + { + value: modernDark, + label: "Modern Dark", + }, + { + value: matrix, + label: "Matrix", + }, + { + value: millenium, + label: "Millenium", + }, + { + value: rose, + label: "Rose", + }, +]; + +export default function Chat() { + const [screenName, setScreenName] = useState("endlessbox5"); + const [stopStrings, setStopStrings] = useState(["user:", "assistant:"]); + const [maxTokens, setMaxTokens] = useState(100); + const [soundLevel, setSoundLevel] = useState(0.2); + const [showConversationList, setShowConversationList] = useState(false); + const [showOptions, setShowOptions] = useState(false); + const [theme, setTheme] = useState({ + value: original, + label: "Original", + }); + + return ( +
+ + {showConversationList && ( +
+ +
+ )} + {showOptions && ( +
+ +
+ )} +
+
+ +
+
+ +
+
+ +
+
+ +
+
+ +
+
+ +
+
Vicuna 13B
+
+ GitHub +
+
+ + Twitter + +
+
+
+
+
+
+
+ ); +} diff --git a/packages/retro-ui/src/components/ChatWindow.jsx b/packages/retro-ui/src/components/ChatWindow.jsx new file mode 100644 index 0000000..9ae049c --- /dev/null +++ b/packages/retro-ui/src/components/ChatWindow.jsx @@ -0,0 +1,167 @@ +import useLLM from "@react-llm/headless"; +import Image from "next/image"; +import { useCallback, useEffect, useState } from "react"; +import { + Button, + TextInput, + Toolbar, + Window, + WindowContent, + WindowHeader, +} from "react95"; +import Loader from "./Loader"; +import MessageList from "./MessageList"; + +import useSound from "use-sound"; + +function ChatWindow({ + stopStrings, + maxTokens, + screenName = "endlessbox5", + assistantScreenName = "SmartestChild", + soundLevel, +}) { + const { loadingStatus, send, isGenerating, setOnMessage } = useLLM(); + const [userInput, setUserInput] = useState(""); + const [playSend] = useSound("/sounds/imsend.wav", { volume: soundLevel }); + const [playRcv] = useSound("/sounds/imrcv.wav", { volume: soundLevel }); + + useEffect(() => { + const cb = () => (resp) => { + if (resp.step === 1) { + playRcv(); + } + }; + setOnMessage(cb); + }, [setOnMessage, playRcv]); + + const handleChange = (event) => { + setUserInput(event.target.value); + }; + + const isReady = loadingStatus.progress === 1; + + const handleSubmit = useCallback(() => { + if (isGenerating || !isReady) { + return; + } + playSend(); + send(userInput, maxTokens, stopStrings); + setUserInput(""); + }, [ + userInput, + send, + isGenerating, + isReady, + maxTokens, + stopStrings, + playSend, + ]); + + useEffect(() => { + const handleKeyPress = (event) => { + if (event.key === "Enter") { + event.preventDefault(); + handleSubmit(); + } + }; + window.addEventListener("keydown", handleKeyPress); + + return () => { + window.removeEventListener("keydown", handleKeyPress); + }; + }, [handleSubmit]); + + return ( + + + Instant Message with {assistantScreenName} + + + + + + + +
+ + {/* */} +
+ {isReady && ( +
+
+
+ +
+
+
+ {isGenerating && ( + {assistantScreenName} is typing... + )} +
+
+
+ {"buddy +
+
+ +
+
+
+
+ )} + {!isReady && } +
+ + + ); +} + +export default ChatWindow; diff --git a/packages/retro-ui/src/components/ConversationList.jsx b/packages/retro-ui/src/components/ConversationList.jsx new file mode 100644 index 0000000..48c4278 --- /dev/null +++ b/packages/retro-ui/src/components/ConversationList.jsx @@ -0,0 +1,81 @@ +"use client"; +import useLLM from "@react-llm/headless"; + +import { + Button, + GroupBox, + MenuList, + MenuListItem, + ScrollView, + Separator, + TextInput, + Window, + WindowHeader, +} from "react95"; + +import { useState } from "react"; + +const ConversationList = () => { + const { allConversations, createConversation, setConversationId } = useLLM(); + const [systemPrompt, setSystemPrompt] = useState( + "A chat between a curious user and a AI chatbot named SmartestChild on AIM who responds with lowercase, frequent emojis, and 2000s internet abbreviations." + ); + const [title, setTitle] = useState("New Conversation"); + + return ( +
+ + Conversations + + + {allConversations?.map((c) => ( +
+ { + setConversationId(c.id); + }} + > +
+
{c.title}
+
{new Date(c.updatedAt).toLocaleString()}
+ +
+
+
+ ))} +
+
+ + + setTitle(event.target.value)} + /> + + + setSystemPrompt(event.target.value)} + rows={6} + /> + + + +
+
+ ); +}; + +export default ConversationList; diff --git a/packages/retro-ui/src/components/Loader.jsx b/packages/retro-ui/src/components/Loader.jsx new file mode 100644 index 0000000..9e1d74d --- /dev/null +++ b/packages/retro-ui/src/components/Loader.jsx @@ -0,0 +1,63 @@ +import useLLM from "@react-llm/headless"; +import { Button, ProgressBar } from "react95"; + +const Loader = () => { + const { loadingStatus, isReady, init, gpuDevice } = useLLM(); + if (isReady) return null; + if (loadingStatus.progress === 1) return null; + + if (gpuDevice.unsupportedReason) { + return ( +
+

Sorry, unsupported!

+

Reason: {gpuDevice.unsupportedReason}

+

+ react-llm runs models in + the browser with WebGPU and only works in Google Chrome v113 and above + on Desktop with supported GPUs. +

+
+ ); + } + + if (loadingStatus.progress == 0) { + return ( +
+
+
+ This will download the model and may take a few minutes. After the + first time, it will be cached. +
+ + +
+
+ ); + } + + return ( +
+ Loading {loadingStatus.progress * 100}% + +
+ ); +}; + +export default Loader; diff --git a/packages/retro-ui/src/components/MessageList.jsx b/packages/retro-ui/src/components/MessageList.jsx new file mode 100644 index 0000000..5123d54 --- /dev/null +++ b/packages/retro-ui/src/components/MessageList.jsx @@ -0,0 +1,52 @@ +import useLLM from "@react-llm/headless"; +import { useEffect, useRef } from "react"; +import { Frame, ScrollView } from "react95"; + +function MessageList({ + screenName = "endlessbox5", + assistantScreenName = "SmartestChild", +}) { + const scrollRef = useRef(null); + const { conversation, userRoleName } = useLLM(); + const messages = conversation?.messages || []; + + const scrollToBottom = () => { + if (scrollRef.current) { + scrollRef.current.scrollIntoView(); + } + }; + + useEffect(() => { + scrollToBottom(); + }, [conversation, messages.length]); + + return ( + + + {conversation?.messages.map((m) => ( +
+
+ + {m.role === userRoleName ? screenName : assistantScreenName} + + : {m.text} +
+
+ ))} +
+ +
+ ); +} + +export default MessageList; diff --git a/packages/retro-ui/src/components/Options.jsx b/packages/retro-ui/src/components/Options.jsx new file mode 100644 index 0000000..f117639 --- /dev/null +++ b/packages/retro-ui/src/components/Options.jsx @@ -0,0 +1,286 @@ +import useLLM from "@react-llm/headless"; +import { useEffect, useState } from "react"; +import { + Anchor, + Button, + GroupBox, + NumberInput, + Select, + Slider, + Tab, + Tabs, + TextInput, + Window, + WindowContent, + WindowHeader, +} from "react95"; + +import { themeList } from "./Chat"; + +const Options = ({ + screenName, + setScreenName, + stopStrings, + setStopStrings, + maxTokens, + setMaxTokens, + theme, + setTheme, + soundLevel, + setSoundLevel, +}) => { + const { conversation } = useLLM(); + const [activeTab, setActiveTab] = useState(0); + + return ( +
+ + Options + + setActiveTab(value)}> + About + Conversation + Settings + + {activeTab === 0 && } + {activeTab === 1 && ( + + )} + {activeTab === 2 && ( + + )} + {activeTab === 3 && } + + +
+ ); +}; + +const StatsTab = () => { + const { gpuDevice } = useLLM(); + return ( +
+ {gpuDevice.checked && !gpuDevice.unsupportedReason && ( + +
Vendor={gpuDevice.adapterInfo.vendor}
+
Architecture={gpuDevice.adapterInfo.architecture}
+
Device={gpuDevice.adapterInfo.device}
+
Description={gpuDevice.adapterInfo.description}
+
+ maxBufferSize= + {gpuDevice.adapter.limits.maxBufferSize / (1024 * 1024)} MB +
+
+ )} + +
+ ); +}; + +const ConversationTab = ({ screenName, setScreenName, conversation }) => { + const { deleteMessages, setConversationTitle } = useLLM(); + + const [title, setTitle] = useState(conversation?.title); + + useEffect(() => { + setTitle(conversation?.title); + }, [conversation?.title]); + + return ( +
+ +
+ setTitle(event.target.value)} + /> + +
+
+ +
{conversation?.id}
+
+ +
+ {conversation?.systemPrompt} +
+
+
Screen Name
+ setScreenName(e.target.value)} + placeholder="Screen Name" + /> + +
+ ); +}; + +const SettingsTab = ({ + stopStrings, + setStopStrings, + maxTokens, + setMaxTokens, + theme, + setTheme, + soundLevel, + setSoundLevel, +}) => { + const { + init, + deleteConversation, + conversation, + userRoleName, + setUserRoleName, + assistantRoleName, + setAssistantRoleName, + } = useLLM(); + return ( +
+ + + + { + if (typeof value === "number") setMaxTokens(value); + }} + /> + + + setStopStrings(e.target.value.split(","))} + placeholder="Stop Strings" + /> + + +
+
+
Bot
+ setAssistantRoleName(e.target.value)} + placeholder="Assistant Role Name" + /> +
+
+
User
+ setUserRoleName(e.target.value)} + placeholder="User Role Name" + /> +
+
+
+ + ({ value: i / 10 }))} + onChange={setSoundLevel} + /> + + +