Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract common logic for detectStart and detectStop into a helper class #77

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 10 additions & 98 deletions src/BodyPose/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ BodyPose
Ported from pose-detection at Tensorflow.js
*/

import * as tf from "@tensorflow/tfjs";
import * as poseDetection from "@tensorflow-models/pose-detection";
import * as tf from "@tensorflow/tfjs";
import ImageDetector from "../ImageDetector";
import callCallback from "../utils/callcallback";
import handleArguments from "../utils/handleArguments";
import { mediaReady } from "../utils/imageUtilities";
Expand Down Expand Up @@ -55,13 +56,6 @@ class BodyPose {
this.model = null;
this.config = options;
this.runtimeConfig = {};
this.detectMedia = null;
this.detectCallback = null;

// flags for detectStart() and detectStop()
this.detecting = false; // true when detection loop is running
this.signalStop = false; // Signal to stop the loop
this.prevCall = ""; // Track previous call to detectStart() or detectStop()

this.ready = callCallback(this.loadModel(), callback);
}
Expand Down Expand Up @@ -129,105 +123,23 @@ class BodyPose {
this.model = await poseDetection.createDetector(pipeline, modelConfig);

// for compatibility with p5's preload()
if (this.p5PreLoadExists) window._decrementPreload();
if (this.p5PreLoadExists()) window._decrementPreload();

return this;
}

/**
* A callback function that handles the pose detection results.
* @callback gotPoses
* @param {Array} results - An array of objects containing poses.
*/

/**
* Asynchronously outputs a single pose prediction result when called.
* @param {*} media - An HMTL or p5.js image, video, or canvas element to run the prediction on.
* @param {gotPoses} callback - A callback function to handle the predictions.
* @param {*} media - An HTML or p5.js image, video, or canvas element to run the prediction on.
* @returns {Promise<Array>} an array of poses.
*/
async detect(...inputs) {
//Parse out the input parameters
const argumentObject = handleArguments(...inputs);
argumentObject.require(
"image",
"An html or p5.js image, video, or canvas element argument is required for detect()."
);
const { image, callback } = argumentObject;

await mediaReady(image, false);
async detect(media) {
await mediaReady(media, false);
const predictions = await this.model.estimatePoses(
image,
media,
this.runtimeConfig
);
let result = predictions;
result = this.addKeypoints(result);
if (typeof callback === "function") callback(result);
return result;
}

/**
* Repeatedly outputs pose predictions through a callback function.
* Calls the internal detectLoop() function.
* @param {*} media - An HMTL or p5.js image, video, or canvas element to run the prediction on.
* @param {gotPoses} callback - A callback function to handle the predictions.
* @returns {Promise<Array>} an array of predictions.
*/
detectStart(...inputs) {
// Parse out the input parameters
const argumentObject = handleArguments(...inputs);
argumentObject.require(
"image",
"An html or p5.js image, video, or canvas element argument is required for detectStart()."
);
argumentObject.require(
"callback",
"A callback function argument is required for detectStart()."
);
this.detectMedia = argumentObject.image;
this.detectCallback = argumentObject.callback;

this.signalStop = false;
if (!this.detecting) {
this.detecting = true;
this.detectLoop();
}
if (this.prevCall === "start") {
console.warn(
"detectStart() was called more than once without calling detectStop(). The lastest detectStart() call will be used and the previous calls will be ignored."
);
}
this.prevCall = "start";
}

/**
* Internal function that calls estimatePoses in a loop
* Can be started by detectStart() and terminated by detectStop()
* @private
*/
async detectLoop() {
await mediaReady(this.detectMedia, false);
while (!this.signalStop) {
const predictions = await this.model.estimatePoses(
this.detectMedia,
this.runtimeConfig
);
let result = predictions;
result = this.addKeypoints(result);
this.detectCallback(result);
// wait for the frame to update
await tf.nextFrame();
}
this.detecting = false;
this.signalStop = false;
}

/**
* Stops the detection loop before next detection loop runs.
*/
detectStop() {
if (this.detecting) this.signalStop = true;
this.prevCall = "stop";
return this.addKeypoints(predictions);
}

/**
Expand Down Expand Up @@ -268,12 +180,12 @@ class BodyPose {

/**
* Factory function that returns a BodyPose instance.
* @returns {BodyPose} A BodyPose instance.
* @returns {ImageDetector} A BodyPose instance.
*/
const bodyPose = (...inputs) => {
const { string, options = {}, callback } = handleArguments(...inputs);
const instance = new BodyPose(string, options, callback);
return instance;
return new ImageDetector(instance);
};

export default bodyPose;
111 changes: 9 additions & 102 deletions src/BodySegmentation/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import * as tf from "@tensorflow/tfjs";
import * as tfBodySegmentation from "@tensorflow-models/body-segmentation";
import ImageDetector from "../ImageDetector";
import callCallback from "../utils/callcallback";
import handleArguments from "../utils/handleArguments";
import BODYPIX_PALETTE from "./BODYPIX_PALETTE";
Expand All @@ -18,7 +19,7 @@ import { mediaReady } from "../utils/imageUtilities";
class BodySegmentation {
/**
* Create BodyPix.
* @param {HTMLVideoElement} [video] - An HTMLVideoElement.
* @param {string} modelName
* @param {object} [options] - An object with options.
* @param {function} [callback] - A callback to be called when the model is ready.
*/
Expand All @@ -27,12 +28,10 @@ class BodySegmentation {
if (this.p5PreLoadExists()) window._incrementPreload();

this.modelName = modelName;
this.video = video;
this.model = null;
this.config = options;
this.runtimeConfig = {};
this.detectMedia = null;
this.detectCallback = null;

this.ready = callCallback(this.loadModel(), callback);
}

Expand Down Expand Up @@ -117,22 +116,17 @@ class BodySegmentation {
);

// for compatibility with p5's preload()
if (this.p5PreLoadExists) window._decrementPreload();
if (this.p5PreLoadExists()) window._decrementPreload();

return this;
}

/**
* Calls segmentPeople in a loop.
* Can be started by detectStart() and terminated by detectStop().
* @private
*/
async detect(...inputs) {
const argumentObject = handleArguments(...inputs);
argumentObject.require(
"image",
"An html or p5.js image, video, or canvas element argument is required for detectStart()."
);
const { image, callback } = argumentObject;
async detect(image) {

await mediaReady(image, false);

Expand Down Expand Up @@ -165,99 +159,12 @@ class BodySegmentation {
}
result.mask = this.generateP5Image(result.maskImageData);

if (callback) callback(result);
return result;
}
/**
* Repeatedly outputs hand predictions through a callback function.
* @param {*} [media] - An HMTL or p5.js image, video, or canvas element to run the prediction on.
* @param {gotHands} [callback] - A callback to handle the hand detection results.
*/
detectStart(...inputs) {
// Parse out the input parameters
const argumentObject = handleArguments(...inputs);
argumentObject.require(
"image",
"An html or p5.js image, video, or canvas element argument is required for detectStart()."
);
argumentObject.require(
"callback",
"A callback function argument is required for detectStart()."
);
this.detectMedia = argumentObject.image;
this.detectCallback = argumentObject.callback;

this.signalStop = false;
if (!this.detecting) {
this.detecting = true;
this.detectLoop();
}
if (this.prevCall === "start") {
console.warn(
"detectStart() was called more than once without calling detectStop(). Only the latest detectStart() call will take effect."
);
}
this.prevCall = "start";
}

/**
* Stops the detection loop before next detection loop runs.
*/
detectStop() {
if (this.detecting) this.signalStop = true;
this.prevCall = "stop";
}

/**
* Calls segmentPeople in a loop.
* Can be started by detectStart() and terminated by detectStop().
* @private
*/
async detectLoop() {
await mediaReady(this.detectMedia, false);
while (!this.signalStop) {
const segmentation = await this.model.segmentPeople(
this.detectMedia,
this.runtimeConfig
);

const result = {};
switch (this.runtimeConfig.maskType) {
case "background":
result.maskImageData = await tfBodySegmentation.toBinaryMask(
segmentation,
{ r: 0, g: 0, b: 0, a: 255 },
{ r: 0, g: 0, b: 0, a: 0 }
);
break;
case "person":
result.maskImageData = await tfBodySegmentation.toBinaryMask(
segmentation
);
break;
case "parts":
result.maskImageData = await tfBodySegmentation.toColoredMask(
segmentation,
tfBodySegmentation.bodyPixMaskValueToRainbowColor,
{ r: 255, g: 255, b: 255, a: 255 }
);
result.bodyParts = BODYPIX_PALETTE;
}
result.mask = this.generateP5Image(result.maskImageData);

this.detectCallback(result);
await tf.nextFrame();
}

this.detecting = false;
this.signalStop = false;
}

/**
* Generate a p5 image from the image data
* @param imageData - a ImageData object
* @param width - the width of the p5 image
* @param height - the height of the p5 image
* @param {ImageData} imageData - a ImageData object
* @return a p5.Image object
*/
generateP5Image(imageData) {
Expand Down Expand Up @@ -288,12 +195,12 @@ class BodySegmentation {

/**
* Factory function that returns a Facemesh instance
* @returns {Object} A new bodySegmentation instance
* @returns {ImageDetector} A new bodySegmentation instance
*/
const bodySegmentation = (...inputs) => {
const { string, options = {}, callback } = handleArguments(...inputs);
const instance = new BodySegmentation(string, options, callback);
return instance;
return new ImageDetector(instance);
};

export default bodySegmentation;
Loading