Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add progressive rendering #5

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,21 @@ To start loading the `.splat` file, call the `load` function.
splat.load();
```

By default, the library will stream in the splat data and render the mesh as it arrives. You can optionally choose to load the entire splat file before rendering by setting the `progressive` option to `false`:

```ts
splat.load({progressive: false});
```

> NOTE: Streaming requires you to set up your environment to support [`SharedArrayBuffer`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer). For more information, see [Performance](#performance).
>
> If `SharedArrayBuffer` is not enabled, the library will load the entire splat file before rendering regardless of what is specified for the `progressive` option.

You may optionally choose to provide a `THREE.LoadingManager` to track the loading progress:

```ts
const loadingManager = new THREE.LoadingManager();
splat.load(loadingManager);
splat.load({loadingManager});
```

### Adding To Scene
Expand Down Expand Up @@ -207,9 +217,28 @@ To aid development, the `MaskMesh` objects are rendered with a wireframe materia

### Performance

For devices that support `SharedArrayBuffer`, the sorting process within this library is significantly optimized. This feature enhances the efficiency of data handling, leading to faster rendering times and smoother user experiences.
For devices that support `SharedArrayBuffer`, the sorting process within this library is significantly optimized.
However, you may also be required to set additional configuration on your server to enable this feature, especially for progressive loading.

This means you must serve the page over HTTPS & provide a valid `Cross-Origin-Opener-Policy` header. For example:

```
Cross-Origin-Opener-Policy: same-origin
Cross-Origin-Embedder-Policy: require-corp
```

Users may see these errors in the console if the above steps are not taken:

```
Uncaught (in promise) DOMException: SharedArrayBuffer will only be available in a secure context.
```
```
DOMException: Failed to execute 'postMessage' on 'DedicatedWorkerGlobalScope': SharedArrayBuffer transfer requires self.crossOriginIsolated.
```

For more information, see [MDN Web Docs](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/SharedArrayBuffer#security_requirements).

To check if your device supports `SharedArrayBuffer`, please refer to <https://caniuse.com/sharedarraybuffer>
Additionally, to check if your environment supports `SharedArrayBuffer` in the first place, please refer to <https://caniuse.com/sharedarraybuffer>.

## How do i get .splat files?

Expand Down
5 changes: 5 additions & 0 deletions cpp-sorter/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ export class WasmSorter {
return transfer(bufferToTransfer, [bufferToTransfer]);
}

public updateGlobalBuffer(globalBuffer: Uint8Array = this.globalBuffer) {
this.globalBuffer = globalBuffer;
this.module?.HEAPU8.set(globalBuffer, this.globalBufferPtr);
}

public returnBuffer(buffer: ArrayBuffer): void {
this.bufferPool.returnBuffer(buffer);
}
Expand Down
24 changes: 12 additions & 12 deletions src/geometry/GaussianSplatGeometry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export class GaussianSplatGeometry extends THREE.InstancedBufferGeometry {
this._sortRunning = false;
}

async load(url: string, loadingManager?: THREE.LoadingManager) {
async load(url: string, {loadingManager, progressive = true}: {loadingManager?: THREE.LoadingManager; progressive?: boolean} = {}) {
if (this.loading) {
console.warn('Geometry is already loading or loaded');
return;
Expand All @@ -65,10 +65,17 @@ export class GaussianSplatGeometry extends THREE.InstancedBufferGeometry {

try {
const loader = new SplatLoader(undefined, loadingManager);
const {data, bytesRead} = await loader.loadAsync(url);
const vertexCount = Math.floor(bytesRead / ROW_LENGTH);
const bufferInfo = trimBuffer(data, this.maxSplats, vertexCount);
this.worker = await new SortWorker(bufferInfo.vertexCount, transfer(bufferInfo.buffer, [bufferInfo.buffer.buffer]));
const {data} =
progressive && typeof SharedArrayBuffer !== 'undefined' && crossOriginIsolated
? await loader.stream(url, undefined, () => {
if (!this.worker) return;

this.worker.updateGlobalBuffer();
})
: await loader.loadAsync(url);
const vertexCount = Math.min(Math.floor(data.length / ROW_LENGTH), this.maxSplats);
const trimmedData = data.subarray(0, vertexCount * ROW_LENGTH);
this.worker = await new SortWorker(vertexCount, trimmedData);
await this.worker.load();
this.vertexCount = vertexCount;
this.initAttributes();
Expand Down Expand Up @@ -125,10 +132,3 @@ export class GaussianSplatGeometry extends THREE.InstancedBufferGeometry {
}

const ROW_LENGTH = 3 * 4 + 3 * 4 + 4 + 4;

function trimBuffer(_buffer: Uint8Array, _maxSplats: number, _vertexCount: number): {buffer: Uint8Array; vertexCount: number} {
const actualVertexCount = Math.min(_vertexCount, _maxSplats);
const actualBufferSize = ROW_LENGTH * actualVertexCount;
const buffer = _buffer.slice(0, actualBufferSize);
return {buffer, vertexCount: actualVertexCount};
}
142 changes: 136 additions & 6 deletions src/loaders/SplatLoader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@ export class SplatLoader extends THREE.Loader {
}

public load(url: string, onLoad?: (data: Uint8Array) => void, onProgress?: (event: ProgressEvent) => void, onError?: (event: ErrorEvent) => void): void {
const path = this.path === '' ? url : this.path + url;

fetch(path, {
fetch(this._getAbsoluteURL(url), {
mode: 'cors',
credentials: this.withCredentials ? 'include' : 'omit',
})
.then(req => {
if (req.status !== 200) {
if (onError) {
onError(new ErrorEvent('NetworkError', {message: `${req.status} Unable to load ${req.url}`}));
onError(this._getErrorEventForNon200Response(req));
}
this.manager.itemError(url);
return;
Expand All @@ -27,7 +25,7 @@ export class SplatLoader extends THREE.Loader {
return;
}

const data = new Uint8Array(buffer);
const data = new Uint8Array(buffer, 0, this._getRowQuantisedByteLength(buffer.byteLength));
this._processData(data, data.length, true);

if (onLoad) {
Expand All @@ -38,7 +36,7 @@ export class SplatLoader extends THREE.Loader {
})
.catch(error => {
if (onError) {
onError(new ErrorEvent('NetworkError', {message: error.message}));
onError(this._getErrorEventForFetchError(error));
}
this.manager.itemError(url);
});
Expand All @@ -59,7 +57,139 @@ export class SplatLoader extends THREE.Loader {
});
}

public stream(url: string, onLoad?: (data: Uint8Array) => void, onProgress?: (event: ProgressEvent) => void, onError?: (event: ErrorEvent) => void): Promise<{data: Uint8Array; bytesRead: number}> {
return new Promise(resolve => {
fetch(this._getAbsoluteURL(url), {
mode: 'cors',
credentials: this.withCredentials ? 'include' : 'omit',
})
.then(req => {
if (req.status !== 200) {
if (onError) {
onError(this._getErrorEventForNon200Response(req));
}
this.manager.itemError(url);
return;
}

const {headers, body} = req;

const contentLength = Number(headers.get('Content-Length'));
if (!Number.isFinite(contentLength)) {
if (onError) {
onError(new ErrorEvent('NetworkError', {message: 'Cannot stream response without `Content-Length` header'}));
}
this.manager.itemError(url);
return;
}

if (!body) {
if (onError) {
onError(new ErrorEvent('NetworkError', {message: 'Empty response body'}));
}
this.manager.itemError(url);
return;
}

const reader = body.getReader();
const buffer = new SharedArrayBuffer(contentLength);
const out = {
data: new Uint8Array(buffer),
bytesRead: 0,
};

const _onProgress = (loaded: number, total: number) => {
if (onProgress) {
onProgress(new ProgressEvent('progress', {loaded, total}));
}
};

resolve(out);
_onProgress(0, contentLength);

let incompleteRowLength = 0;
const incompleteRow = new Uint8Array(ROW_LENGTH);

const processStream = ({done, value: currBytes}: ReadableStreamReadResult<Uint8Array>) => {
if (done) {
if (incompleteRowLength > 0) {
// TODO: warn the user about trailing/incomplete data
}
if (onLoad) {
onLoad(out.data);
}
this.manager.itemEnd(url);
return;
}

if (incompleteRowLength > 0) {
// write the previous incomplete row to the buffer
for (let i = 0; i < incompleteRowLength; i++) {
out.data[out.bytesRead + i] = incompleteRow[i];
}
out.bytesRead += incompleteRowLength;
incompleteRowLength = 0;
// save a write here by always zeroing out the rest of the row during write
}

// get the length of the complete rows
const currCompleteRowsByteLength = this._getRowQuantisedByteLength(out.bytesRead + currBytes.length) - out.bytesRead;
const currRemainingByteLength = currBytes.length - currCompleteRowsByteLength;

if (currRemainingByteLength > 0) {
// store the next incomplete row to be written to the next time processStream is called
for (let i = 0; i < currRemainingByteLength; i++) {
incompleteRow[i] = currBytes[currCompleteRowsByteLength + i];
}
incompleteRow.fill(0, currRemainingByteLength);
incompleteRowLength = currRemainingByteLength;
}

// get view of only the complete rows
const currRowBytes = currBytes.subarray(0, currCompleteRowsByteLength);

// write the complete rows to the buffer
out.data.set(currRowBytes, out.bytesRead);
out.bytesRead += currCompleteRowsByteLength;

this._processData(out.data, out.bytesRead);
_onProgress(out.bytesRead, contentLength);

reader.read().then(processStream);
};
reader.read().then(processStream);
})
.catch(error => {
if (onError) {
onError(this._getErrorEventForFetchError(error));
}
this.manager.itemError(url);
});

this.manager.itemStart(url);
});
}

private _processData(data: Uint8Array, bytesRead: number, isComplete = false): void {
this.processDataCallback?.(data, bytesRead, isComplete);
}

private _getAbsoluteURL(url: string): string {
return this.path + url;
}

private _getErrorEventForNon200Response(req: Response): ErrorEvent {
return new ErrorEvent('NetworkError', {message: `${req.status} Unable to load ${req.url}`});
}

private _getErrorEventForFetchError(error: Error): ErrorEvent {
return new ErrorEvent('NetworkError', {message: error.message});
}

private _getRowQuantisedByteLength(rowLength: number): number {
return rowLength - (rowLength % ROW_LENGTH);
}
}

// TODO: find a way to share this constant with GaussianSplatGeometry
const ROW_LENGTH = 3 * 4 + 3 * 4 + 4 + 4;
4 changes: 2 additions & 2 deletions src/mesh/GaussianSplatMesh.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ export class GaussianSplatMesh extends THREE.Mesh<GaussianSplatGeometry, Gaussia
this.rotation.x = Math.PI;
}

public load(loadingManager?: THREE.LoadingManager) {
return this.geometry.load(this.url, loadingManager);
public load(config?: Parameters<GaussianSplatGeometry['load']>[1]) {
return this.geometry.load(this.url, config);
}

private _normal = new THREE.Vector3(0, 0, 1);
Expand Down
2 changes: 1 addition & 1 deletion tests/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {GaussianSplatMesh} from '../src/index';
// import {Test} from '../src/worker_test';
const bonsai = new URL('./bonsai.splat', import.meta.url).href;
const splat = new GaussianSplatMesh(bonsai, 1000000);
splat.load();
splat.load({progressive: true});
scene.add(splat);

renderer.setAnimationLoop(animation);
Expand Down