Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🧹 Cleanup Neural Network task-dependent logic #95

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 20 additions & 64 deletions src/NeuralNetwork/NeuralNetwork.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import * as tf from "@tensorflow/tfjs";
import axios from "axios";
import callCallback from "../utils/callcallback";
import handleArguments from "../utils/handleArguments";
import { saveBlob } from "../utils/io";
import { randomGaussian } from "../utils/random";

Expand All @@ -11,22 +9,11 @@ class NeuralNetwork {
this.isTrained = false;
this.isCompiled = false;
this.isLayered = false;
// the model
/**
* @type {tf.Sequential | null} - the TensorFlow model
*/
this.model = null;

// methods
this.init = this.init.bind(this);
this.createModel = this.createModel.bind(this);
this.addLayer = this.addLayer.bind(this);
this.compile = this.compile.bind(this);
this.setOptimizerFunction = this.setOptimizerFunction.bind(this);
this.train = this.train.bind(this);
this.trainInternal = this.trainInternal.bind(this);
this.predict = this.predict.bind(this);
this.classify = this.classify.bind(this);
this.save = this.save.bind(this);
this.load = this.load.bind(this);

// initialize
this.init();
}
Expand Down Expand Up @@ -57,11 +44,11 @@ class NeuralNetwork {
/**
* add layer to the model
* if the model has 2 or more layers switch the isLayered flag
* @param {*} _layerOptions
* @param {tf.layers.Layer} layer
* @void
*/
addLayer(_layerOptions) {
const LAYER_OPTIONS = _layerOptions || {};
this.model.add(LAYER_OPTIONS);
addLayer(layer) {
this.model.add(layer);

// check if it has at least an input and output layer
if (this.model.layers.length >= 2) {
Expand All @@ -71,38 +58,19 @@ class NeuralNetwork {

/**
* Compile the model
* if the model is compiled, set the isCompiled flag to true
* @param {*} _modelOptions
* once the model is compiled, set the isCompiled flag to true
* @param {tf.ModelCompileArgs} compileOptions
*/
compile(_modelOptions) {
this.model.compile(_modelOptions);
compile(compileOptions) {
this.model.compile(compileOptions);
this.isCompiled = true;
}

/**
* Set the optimizer function given the learning rate
* as a parameter
* @param {*} learningRate
* @param {*} optimizer
*/
setOptimizerFunction(learningRate, optimizer) {
return optimizer.call(this, learningRate);
}

/**
* Calls the trainInternal() and calls the callback when finished
* @param {*} _options
* @param {*} _cb
*/
train(_options, _cb) {
return callCallback(this.trainInternal(_options), _cb);
}

/**
* Train the model
* @param {*} _options
* @param {tf.ModelFitArgs & { inputs: tf.Tensor, outputs: tf.Tensor, whileTraining: Array }} _options
*/
async trainInternal(_options) {
async train(_options) {
const TRAINING_OPTIONS = _options;

const xs = TRAINING_OPTIONS.inputs;
Expand Down Expand Up @@ -178,15 +146,12 @@ class NeuralNetwork {
// are the same as .predict()

/**
* save the model
* @param {*} nameOrCb
* @param {*} cb
* save the model.json and the weights.bin files
* @param {string} modelName
* @return {Promise<void>}
*/
async save(nameOrCb, cb) {
const { string, callback } = handleArguments(nameOrCb, cb);
const modelName = string || "model";

this.model.save(
async save(modelName = "model") {
await this.model.save(
tf.io.withSaveHandler(async (data) => {
this.weightsManifest = {
modelTopology: data.modelTopology,
Expand All @@ -208,19 +173,15 @@ class NeuralNetwork {
`${modelName}.json`,
"text/plain"
);
if (callback) {
callback();
}
})
);
}

/**
* loads the model and weights
* @param {*} filesOrPath
* @param {*} callback
* @param {string | FileList | Object} filesOrPath
*/
async load(filesOrPath = null, callback) {
async load(filesOrPath) {
if (filesOrPath instanceof FileList) {
const files = await Promise.all(
Array.from(filesOrPath).map(async (file) => {
Expand Down Expand Up @@ -277,11 +238,6 @@ class NeuralNetwork {
this.isCompiled = true;
this.isLayered = true;
this.isTrained = true;

if (callback) {
callback();
}
return this.model;
}

/**
Expand Down
Loading