Skip to content

Commit

Permalink
wip test
Browse files Browse the repository at this point in the history
  • Loading branch information
a-tokyo committed Dec 8, 2024
1 parent 7b2ca9c commit bf2b747
Show file tree
Hide file tree
Showing 12 changed files with 1,428 additions and 3,071 deletions.
13 changes: 8 additions & 5 deletions jest.config.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
/** @type {import('ts-jest').JestConfigWithTsJest} **/

module.exports = {
testEnvironment: 'jsdom',
testEnvironment: 'node',
transform: {
'^.+.(t|j)sx?$': ['ts-jest', {}],
},
testPathIgnorePatterns: ['./node_modules/', './dist/'],
modulePathIgnorePatterns: ['./dist/'],
collectCoverage: true,
collectCoverageFrom: [
'src/**/*.{js,ts}',
'src/**/*.{js,}',
'!**/node_modules/**',
'!**/__flow__/**',
],
coverageDirectory: './coverage',
coverageReporters: ['lcov', 'json', 'text'],
Expand All @@ -18,7 +24,4 @@ module.exports = {
statements: 77,
},
},
moduleNameMapper: {
'\\.(gif|ttf|eot|svg)$': '<rootDir>/jest/__mocks__/fileMock.js',
},
};
5 changes: 0 additions & 5 deletions jest/.eslintrc.js

This file was deleted.

1 change: 0 additions & 1 deletion jest/__mocks__/fileMock.js

This file was deleted.

20 changes: 5 additions & 15 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,21 @@
"prepublishOnly": "npm run transpile",
"build": "webpack --mode production",
"deploy": "gh-pages -d demo/dist",
"publish-demo": "npm run build && npm run deploy",
"test": "jest"
},
"peerDependencies": {
"react": ">= 16.8.0",
"react-dom": ">= 16.8.0"
"publish-demo": "npm run build && npm run deploy"
},
"peerDependencies": {},
"devDependencies": {
"@babel/cli": "^7.14.8",
"@babel/core": "^7.14.8",
"@babel/eslint-parser": "^7.19.1",
"@babel/preset-env": "^7.14.8",
"@babel/preset-react": "^7.14.5",
"@babel/preset-typescript": "^7.26.0",
"@testing-library/jest-dom": "^5.14.1",
"@testing-library/react-hooks": "^7.0.1",
"@typescript-eslint/eslint-plugin": "^5.36.1",
"@typescript-eslint/parser": "^5.36.1",
"babel-eslint": "^10.1.0",
"babel-loader": "^8.2.2",
"css-loader": "^6.2.0",
"enzyme": "^3.11.0",
"enzyme-adapter-react-16": "^1.15.6",
"enzyme-to-json": "^3.6.2",
"eslint": "^7.31.0",
"eslint-config-airbnb": "^18.2.1",
"eslint-config-prettier": "^8.3.0",
Expand All @@ -62,9 +53,6 @@
"gh-pages": "^3.2.3",
"html-webpack-plugin": "^5.3.2",
"husky": "^4.3.0",
"jest": "^27.0.6",
"jest-cli": "^27.0.6",
"jest-enzyme": "^7.1.2",
"lint-staged": "^11.1.1",
"plato": "^1.7.0",
"prettier": "^2.3.2",
Expand Down Expand Up @@ -108,7 +96,9 @@
}
],
"dependencies": {
"@types/jest": "^29.5.14",
"jest": "^29.7.0",
"openai": "^4.76.0",
"p-map": "^7.0.3"
"ts-jest": "^29.2.5"
}
}
27 changes: 14 additions & 13 deletions src/ZeroShotClassifier/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pMap from 'p-map';
import pMap from '../utils/p-map';

import { createClient, createEmbedding } from '../providers/openai';
import cosineSimilarity from '../utils/cosineSimilarity';
Expand Down Expand Up @@ -72,7 +72,7 @@ class ZeroShotClassifier {
/** Labels to classify against */
private labels: string[];
/** Labels cache */
private labelsCache: Record<string, string>;
private labelsCache: Record<string, number []>;
/** API client */
private client;

Expand All @@ -86,7 +86,7 @@ class ZeroShotClassifier {
/** Labels to classify against, can be provider later via setLabels */
labels: string[]; // labels to classify against
/** Labels cache */
labelsCache?: Record<string, string>; // labelsCache object
labelsCache?: Record<string, number []>; // labelsCache object
}) {
const {
model = 'text-embedding-3-small',
Expand Down Expand Up @@ -157,7 +157,7 @@ class ZeroShotClassifier {
async getEmbedding(
text: string,
type: 'label' | 'data' = 'data',
): Promise<string> {
): Promise<number []> {
// if labelsCache enabled -> get from labelsCache or continue
if (type === 'label' && this.labelsCache[text]) {
return this.labelsCache[text];
Expand Down Expand Up @@ -193,7 +193,6 @@ class ZeroShotClassifier {
config: ClassifyConfig = {},
): Promise<
{
text: string;
label: string;
confidence: number;
}[]
Expand All @@ -209,44 +208,46 @@ class ZeroShotClassifier {
comparingBatchSizeTop = 10,
comparingBatchSizeBottom = 10,
} = config;

// Parallel embedding computation for labels and data
const [labelsEmbeddings, dataEmbeddings] = await Promise.all([
pMap(this.labels, (el) => this.getEmbedding(el, 'label'), {
concurrency: embeddingBatchSizeLabels,
signal: undefined,
}),
pMap(data, (el) => this.getEmbedding(el), {
concurrency: embeddingBatchSizeData,
}),
]);

/** similarity getter function */
const getSimilarity = _getSimilarityFunction(similarity);

return pMap(
const result: { label: string, confidence: number }[] = await pMap(
dataEmbeddings,
async (dataEmbedding) => {
const similarities = await pMap(
labelsEmbeddings,
async (labelEmbedding) => {
async (labelEmbedding): Promise<number> => {
return getSimilarity(dataEmbedding, labelEmbedding);
},
{ concurrency: comparingBatchSizeBottom },
);

// find closest label based on similarity
const bestIndex = similarities.indexOf(
Math[similarity === 'euclidean' ? 'min' : 'max'](...similarities),
);

return {
text: dataEmbedding,
label: this.labels[bestIndex],
confidence: similarities[bestIndex], // Include confidence score
};
},
{ concurrency: comparingBatchSizeTop },
);

return result;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/classify.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import ZeroShotClassifier from './ZeroShotClassifier';
import ZeroShotClassifier from './ZeroShotClassifier/index';

type ClassificationInput = ConstructorParameters<
typeof ZeroShotClassifier
Expand Down
2 changes: 1 addition & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import ZeroShotClassifier from './ZeroShotClassifier';
import ZeroShotClassifier from './ZeroShotClassifier/index';
import classify from './classify';

export { ZeroShotClassifier, classify };
5 changes: 4 additions & 1 deletion src/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import OpenAI from 'openai';
* Create a new OpenAI client
*/
const createClient = (config) =>
// @ts-ignore
new OpenAI({
apiKey:
config.apiKey ||
Expand All @@ -15,7 +16,7 @@ const createClient = (config) =>
* Create embeddings using openAI models
*/
const createEmbedding = async (
client: OpenAI,
client,
config: Parameters<typeof client.embeddings.create>[0],
) => {
const response = await client.embeddings.create(config);
Expand All @@ -24,3 +25,5 @@ const createEmbedding = async (
};

export { createClient, createEmbedding };

export default {};
Loading

0 comments on commit bf2b747

Please sign in to comment.