diff --git a/.clang-format b/.clang-format index c65e7720f..1ca453f84 100644 --- a/.clang-format +++ b/.clang-format @@ -7,3 +7,104 @@ Standard: Cpp11 DerivePointerAlignment: false PointerAlignment: Right --- +Language: Java +JavaImportGroups: [ 'java', 'javax', 'javafx', 'org', 'io', 'com', 'de.gsi' ] +AccessModifierOffset: -4 +AlignAfterOpenBracket: DontAlign +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: DontAlign +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortLambdasOnASingleLine: None +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: Never + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBinaryOperators: All +BreakBeforeBraces: Custom +BreakBeforeInheritanceComma: false +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeComma +BreakAfterJavaFieldAnnotations: true +BreakStringLiterals: true +ColumnLimit: 0 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 8 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - forever # avoids { wrapped to next line + - foreach + - Q_FOREACH + - BOOST_FOREACH 
+IncludeCategories: + - Regex: '^= 1600 || i == (buffer.length - 1)) { + chunkIndex = 0; + streamObj.acceptWaveform(chunk); // feed chunk + if (rcgOjb.isReady(streamObj)) { + rcgOjb.decodeStream(streamObj); + } + String testDate = rcgOjb.getResult(streamObj); + byte[] utf8Data = testDate.getBytes(StandardCharsets.UTF_8); + + if (utf8Data.length > 0) { + System.out.println(Float.valueOf((float) i / 16000) + ":" + new String(utf8Data)); + } + } + } + streamObj.inputFinished(); + while (rcgOjb.isReady(streamObj)) { + rcgOjb.decodeStream(streamObj); + } + + String recText = "stream:" + rcgOjb.getResult(streamObj) + "\n"; + byte[] utf8Data = recText.getBytes(StandardCharsets.UTF_8); + System.out.println(new String(utf8Data)); + rcgOjb.reSet(streamObj); + rcgOjb.releaseStream(streamObj); // release stream + rcgOjb.release(); // release recognizer + + } catch (Exception e) { + System.err.println(e); + e.printStackTrace(); + } + } + + public static void main(String[] args) { + try { + String appDir = System.getProperty("user.dir"); + System.out.println("appdir=" + appDir); + String fileName = appDir + "/test.wav"; + String cfgPath = appDir + "/modelconfig.cfg"; + String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so"; + OnlineRecognizer.setSoPath(soPath); + DecodeFile rcgDemo = new DecodeFile(fileName); + + // ***************** */ + rcgDemo.initModelWithCfg(cfgPath); + rcgDemo.streamExample(); + // **************** */ + rcgDemo.initModelWithCfg(cfgPath); + rcgDemo.simpleExample(); + + } catch (Exception e) { + System.err.println(e); + e.printStackTrace(); + } + } +} diff --git a/java-api-examples/test.wav b/java-api-examples/test.wav new file mode 100644 index 000000000..256e4afd3 Binary files /dev/null and b/java-api-examples/test.wav differ diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java new file mode 100644 index 000000000..5f4b6d16a --- /dev/null +++ 
b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointConfig.java
@@ -0,0 +1,29 @@
+/*
+ * // Copyright 2022-2023 by zhaoming
+ */
+
+package com.k2fsa.sherpa.onnx;
+// Read-only holder for the three endpoint rules; every value is fixed at construction.
+public class EndpointConfig {
+  private final EndpointRule rule1;
+  private final EndpointRule rule2;
+  private final EndpointRule rule3;
+
+  public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) {
+    this.rule1 = rule1;
+    this.rule2 = rule2;
+    this.rule3 = rule3;
+  }
+
+  public EndpointRule getRule1() {
+    return rule1;
+  }
+
+  public EndpointRule getRule2() {
+    return rule2;
+  }
+
+  public EndpointRule getRule3() {
+    return rule3;
+  }
+}
diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java
new file mode 100644
index 000000000..5a1714f64
--- /dev/null
+++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/EndpointRule.java
@@ -0,0 +1,30 @@
+/*
+ * // Copyright 2022-2023 by zhaoming
+ */
+
+package com.k2fsa.sherpa.onnx;
+// One endpoint rule: a boolean flag plus two float thresholds, all fixed at construction.
+public class EndpointRule {
+  private final boolean mustContainNonSilence;
+  private final float minTrailingSilence; // NOTE(review): presumably seconds - confirm
+  private final float minUtteranceLength; // NOTE(review): presumably seconds - confirm
+
+  public EndpointRule(
+      boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) {
+    this.mustContainNonSilence = mustContainNonSilence;
+    this.minTrailingSilence = minTrailingSilence;
+    this.minUtteranceLength = minUtteranceLength;
+  }
+
+  public float getMinTrailingSilence() {
+    return minTrailingSilence;
+  }
+
+  public float getMinUtteranceLength() {
+    return minUtteranceLength;
+  }
+
+  public boolean getMustContainNonSilence() {
+    return mustContainNonSilence;
+  }
+}
diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java
new file mode 100644
index 000000000..069b7897d
--- /dev/null
+++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/FeatureConfig.java
@@ -0,0 +1,23 @@
+/*
+ * // 
Copyright 2022-2023 by zhaoming
+ */
+
+package com.k2fsa.sherpa.onnx;
+// Read-only pair of feature settings: sample rate and feature dimension.
+public class FeatureConfig {
+  private final int sampleRate;
+  private final int featureDim;
+
+  public FeatureConfig(int sampleRate, int featureDim) {
+    this.sampleRate = sampleRate;
+    this.featureDim = featureDim;
+  }
+
+  public int getSampleRate() {
+    return sampleRate;
+  }
+
+  public int getFeatureDim() {
+    return featureDim;
+  }
+}
diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java
new file mode 100644
index 000000000..7716fd5a8
--- /dev/null
+++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizer.java
@@ -0,0 +1,291 @@
+/*
+ * // Copyright 2022-2023 by zhaoming
+ * // the online recognizer for sherpa-onnx, it can load config from a file
+ * // or by argument
+ */
+/*
+usage example:
+
+  String cfgpath=appdir+"/modelconfig.cfg";
+  OnlineRecognizer.setSoPath(soPath); //set so lib path
+
+  // create a recognizer from the model config file
+  OnlineRecognizer rcgOjb = new OnlineRecognizer(cfgpath);
+  OnlineStream streamObj = rcgOjb.createStream(); // create a stream to feed wav data into
+  float[] buffer = rcgOjb.readWavFile(wavfilename); // read data from file
+  streamObj.acceptWaveform(buffer); // feed stream with data
+  streamObj.inputFinished(); // tell engine you are done with all data
+  OnlineStream ssObj[] = new OnlineStream[1];
+  while (rcgOjb.isReady(streamObj)) { // engine is ready for unprocessed data
+    ssObj[0] = streamObj;
+    rcgOjb.decodeStreams(ssObj); // decode for multiple stream
+    // rcgOjb.decodeStream(streamObj); // decode for single stream
+  }
+
+  String recText = "simple:" + rcgOjb.getResult(streamObj) + "\n";
+  byte[] utf8Data = recText.getBytes(StandardCharsets.UTF_8);
+  System.out.println(new String(utf8Data));
+  rcgOjb.reSet(streamObj);
+  rcgOjb.releaseStream(streamObj); // release stream
+  rcgOjb.release(); // release recognizer
+
+*/
+package com.k2fsa.sherpa.onnx;
+
+import java.io.*;
+import java.util.*;
+
+/**
+ * Java-side handle for the native sherpa-onnx streaming recognizer.
+ *
+ * <p>Call {@link #release()} when finished; finalize() calls it too as a fallback.
+ */
+public class OnlineRecognizer {
+  private long ptr = 0; // native recognizer handle; 0 = not created or already released
+
+  private int sampleRate = 16000; // default kept until a constructor stores its own value
+
+  /**
+   * Creates a recognizer from a key=value config file.
+   *
+   * <p>Any exception raised while parsing the values or creating the native recognizer is printed
+   * to stderr and leaves the handle at 0; instance methods then throw "null exception for
+   * recognizer ptr".
+   *
+   * @param modelCfgPath path of the config file
+   */
+  public OnlineRecognizer(String modelCfgPath) {
+    // the placeholder fills the slot kept for android asset_manager ANDROID_API__ >= 9
+    this(new Object(), modelCfgPath);
+  }
+
+  /**
+   * Same as {@link #OnlineRecognizer(String)}; use for android asset_manager ANDROID_API__ >= 9.
+   *
+   * @param assetManager handed to the native side as its asset-manager argument
+   * @param modelCfgPath path of the config file
+   */
+  public OnlineRecognizer(Object assetManager, String modelCfgPath) {
+    Map<String, String> proMap = this.readProperties(modelCfgPath);
+    try {
+      int sampleRate = Integer.parseInt(required(proMap, "sample_rate"));
+      this.sampleRate = sampleRate;
+      EndpointRule rule1 =
+          new EndpointRule(
+              false, Float.parseFloat(required(proMap, "rule1_min_trailing_silence")), 0.0F);
+      EndpointRule rule2 =
+          new EndpointRule(
+              true, Float.parseFloat(required(proMap, "rule2_min_trailing_silence")), 0.0F);
+      EndpointRule rule3 =
+          new EndpointRule(
+              false, 0.0F, Float.parseFloat(required(proMap, "rule3_min_utterance_length")));
+      EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
+      OnlineTransducerModelConfig modelCfg =
+          new OnlineTransducerModelConfig(
+              required(proMap, "encoder"),
+              required(proMap, "decoder"),
+              required(proMap, "joiner"),
+              required(proMap, "tokens"),
+              Integer.parseInt(required(proMap, "num_threads")),
+              false);
+      FeatureConfig featConfig =
+          new FeatureConfig(sampleRate, Integer.parseInt(required(proMap, "feature_dim")));
+      OnlineRecognizerConfig rcgCfg =
+          new OnlineRecognizerConfig(
+              featConfig,
+              modelCfg,
+              endCfg,
+              Boolean.parseBoolean(required(proMap, "enable_endpoint_detection")),
+              required(proMap, "decoding_method"),
+              Integer.parseInt(required(proMap, "max_active_paths")));
+      this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
+    } catch (Exception e) {
+      System.err.println(e);
+    }
+  }
+
+  /**
+   * Creates a recognizer from explicit settings instead of a config file.
+   *
+   * <p>Nothing is caught here, so a failure of the native call propagates to the caller.
+   */
+  public OnlineRecognizer(
+      String tokens,
+      String encoder,
+      String decoder,
+      String joiner,
+      int numThreads,
+      int sampleRate,
+      int featureDim,
+      boolean enableEndpointDetection,
+      float rule1MinTrailingSilence,
+      float rule2MinTrailingSilence,
+      float rule3MinUtteranceLength,
+      String decodingMethod,
+      int maxActivePaths) {
+    this.sampleRate = sampleRate;
+    EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
+    EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
+    EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
+    EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
+    OnlineTransducerModelConfig modelCfg =
+        new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false);
+    FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
+    OnlineRecognizerConfig rcgCfg =
+        new OnlineRecognizerConfig(
+            featConfig,
+            modelCfg,
+            endCfg,
+            enableEndpointDetection,
+            decodingMethod,
+            maxActivePaths);
+    // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
+    this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
+  }
+
+  // returns the trimmed value of a mandatory config key, naming the key when it is absent
+  private static String required(Map<String, String> proMap, String key) {
+    String value = proMap.get(key);
+    if (value == null) throw new IllegalArgumentException("missing config key: " + key);
+    return value.trim();
+  }
+
+  // loads the key=value config file into a map; prints a message and exits if the file is absent
+  private Map<String, String> readProperties(String modelCfgPath) {
+    Properties props = new Properties();
+    Map<String, String> proMap = new HashMap<>();
+    File file = new File(modelCfgPath);
+    if (!file.exists()) {
+      System.out.println("model cfg file not exists!");
+      System.exit(0);
+    }
+    // try-with-resources closes the stream on every path, including a failing load()
+    try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
+      props.load(in);
+      for (String key : props.stringPropertyNames()) {
+        proMap.put(key, props.getProperty(key));
+      }
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+    return proMap;
+  }
+
+  public void decodeStream(OnlineStream s) throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
+    long streamPtr = s.getPtr();
+    if (streamPtr == 0) throw new Exception("null exception for stream ptr");
+    // after samples have been fed to the engine, call decodeStream to let it process them
+    decodeStream(this.ptr, streamPtr);
+  }
+
+  public void decodeStreams(OnlineStream[] ssOjb) throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
+    // decode for multiple streams
+    long[] ss = new long[ssOjb.length];
+    for (int i = 0; i < ssOjb.length; i++) {
+      ss[i] = ssOjb[i].getPtr();
+      if (ss[i] == 0) throw new Exception("null exception for stream ptr");
+    }
+    decodeStreams(this.ptr, ss);
+  }
+
+  public boolean isReady(OnlineStream s) throws Exception {
+    // whether the engine is ready for decode
+    if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
+    long streamPtr = s.getPtr();
+    if (streamPtr == 0) throw new Exception("null exception for stream ptr");
+    return isReady(this.ptr, streamPtr);
+  }
+
+  public String getResult(OnlineStream s) throws Exception {
+    // get text from the engine
+    if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
+    long streamPtr = s.getPtr();
+    if (streamPtr == 0) throw new Exception("null exception for stream ptr");
+    return getResult(this.ptr, streamPtr);
+  }
+
+  public boolean isEndpoint(OnlineStream s) throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
+    long streamPtr = s.getPtr();
+    if (streamPtr == 0) throw new Exception("null exception for stream ptr");
+    return isEndpoint(this.ptr, streamPtr);
+  }
+
+  public void reSet(OnlineStream s) throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
+    long streamPtr = s.getPtr();
+    if (streamPtr == 0) throw new Exception("null exception for stream ptr");
+    reSet(this.ptr, streamPtr);
+  }
+
+  public OnlineStream createStream() throws Exception {
+    // create one stream for data to feed in
+    if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
+    long streamPtr = createStream(this.ptr);
+    OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate);
+    return stream;
+  }
+
+  public float[] readWavFile(String fileName) {
+    // read data from the filename
+    Object[] wavdata = readWave(fileName);
+    Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
+
+    float[] floatData = (float[]) data;
+
+    return floatData;
+  }
+
+  // load the libsherpa-onnx-jni.so lib
+  public static void loadSoLib(String soPath) {
+    // load libsherpa-onnx-jni.so lib from the path
+
+    System.out.println("so lib path=" + soPath + "\n");
+    System.load(soPath.trim());
+  }
+
+  public static void setSoPath(String soPath) {
+    OnlineRecognizer.loadSoLib(soPath);
+    OnlineStream.loadSoLib(soPath);
+  }
+
+  protected void finalize() throws Throwable {
+    release();
+  }
+
+  // recognizer release, you'd better call it manually if not use anymore
+  public void release() {
+    if (this.ptr == 0) return;
+    deleteOnlineRecognizer(this.ptr);
+    this.ptr = 0;
+  }
+
+  // stream release, you'd better call it manually if not use anymore
+  public void releaseStream(OnlineStream s) {
+    s.release();
+  }
+  // JNI interface libsherpa-onnx-jni.so
+
+  private native Object[] readWave(String fileName);
+
+  private native String getResult(long ptr, long streamPtr);
+
+  private native void decodeStream(long ptr, long streamPtr);
+
+  private native void decodeStreams(long ptr, long[] ssPtr);
+
+  private native boolean isReady(long ptr, long streamPtr);
+
+  // first parameter keep for android asset_manager ANDROID_API__ >= 9
+  private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config);
+
+  private native long createStream(long ptr);
+
+  private native void deleteOnlineRecognizer(long ptr);
+
+  private native boolean isEndpoint(long ptr, long streamPtr);
+
+  private native void reSet(long ptr, long streamPtr);
+}
diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java
new file mode 100644
index 000000000..3b8e05ecf
--- /dev/null
+++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineRecognizerConfig.java
@@ -0,0 +1,53 @@
+/*
+ * // Copyright 2022-2023 by zhaoming
+ */
+
+package com.k2fsa.sherpa.onnx;
+// Read-only bundle of the settings passed to the native createOnlineRecognizer call.
+public class OnlineRecognizerConfig {
+  private final FeatureConfig featConfig;
+  private final OnlineTransducerModelConfig modelConfig;
+  private final EndpointConfig endpointConfig;
+  private final boolean enableEndpoint;
+  private final String decodingMethod;
+  private final int maxActivePaths;
+
+  public OnlineRecognizerConfig(
+      FeatureConfig featConfig,
+      OnlineTransducerModelConfig modelConfig,
+      EndpointConfig endpointConfig,
+      boolean enableEndpoint,
+      String decodingMethod,
+      int maxActivePaths) {
+    this.featConfig = featConfig;
+    this.modelConfig = modelConfig;
+    this.endpointConfig = endpointConfig;
+    this.enableEndpoint = enableEndpoint;
+    this.decodingMethod = decodingMethod;
+    this.maxActivePaths = maxActivePaths;
+  }
+
+  public FeatureConfig getFeatConfig() {
+    return featConfig;
+  }
+
+  public OnlineTransducerModelConfig getModelConfig() {
+    return modelConfig;
+  }
+
+  public EndpointConfig getEndpointConfig() {
+    return endpointConfig;
+  }
+
+  public boolean isEnableEndpoint() {
+    return enableEndpoint;
+  }
+
+  public String getDecodingMethod() {
+    return decodingMethod;
+  }
+
+  public int getMaxActivePaths() {
+    return maxActivePaths;
+  }
+}
diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java
new file mode 100644
index 000000000..557b4d8dc
--- /dev/null
+++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineStream.java
@@ -0,0 +1,88 @@
+/*
+ * // Copyright 2022-2023 by zhaoming
+ */
+// Stream is used for feeding data to the asr engine
+package com.k2fsa.sherpa.onnx;
+
+import java.io.*;
+import java.util.*;
+
+public class OnlineStream {
+  private long ptr = 0; // this is the stream ptr
+
+  private int sampleRate = 16000;
+  // assign ptr to this stream in construction
+  public OnlineStream(long ptr, int sampleRate) {
+    this.ptr = ptr;
+    this.sampleRate = sampleRate;
+  }
+
+  public long getPtr() {
+    return ptr;
+  }
+
+  public void acceptWaveform(float[] samples) throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for stream ptr");
+
+    // feed wave data to asr engine
+    acceptWaveform(this.ptr, this.sampleRate, samples);
+  }
+
+  public void inputFinished() {
+    // a released stream must not reach the native calls; unchecked so the signature is unchanged
+    if (this.ptr == 0) throw new IllegalStateException("null exception for stream ptr");
+    // add some tail padding
+    int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds of zeros at this stream's rate
+    float[] tailPaddings = new float[padLen]; // default value is 0
+    acceptWaveform(this.ptr, this.sampleRate, tailPaddings);
+
+    // tell the engine all data are fed
+    inputFinished(this.ptr);
+  }
+
+  public static void loadSoLib(String soPath) {
+    // load .so lib from the path
+    System.load(soPath.trim()); // ("sherpa-onnx-jni-java");
+  }
+
+  public void release() {
+    // stream object must be release after used
+    if (this.ptr == 0) return;
+    deleteStream(this.ptr);
+    this.ptr = 0;
+  }
+
+  protected void finalize() throws Throwable {
+    release();
+  }
+
+  public boolean isLastFrame() throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for stream ptr");
+    return isLastFrame(this.ptr);
+  }
+
+  public void reSet() throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for stream ptr");
+    reSet(this.ptr);
+  }
+
+  public int featureDim() throws Exception {
+    if (this.ptr == 0) throw new Exception("null exception for stream ptr");
+    return featureDim(this.ptr);
+  }
+
+  // JNI interface libsherpa-onnx-jni.so
+  private native void acceptWaveform(long ptr, int sampleRate, float[] samples);
+
+  private native void inputFinished(long ptr);
+
+  private native void deleteStream(long ptr);
+
+  private native int numFramesReady(long ptr);
+
+  private native boolean isLastFrame(long ptr);
+
+  private native void reSet(long ptr);
+
+  private native int featureDim(long ptr);
+}
diff --git a/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java
new file mode 100644
index 000000000..1e45e3717
--- /dev/null
+++ b/sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx/OnlineTransducerModelConfig.java
@@ -0,0 +1,48 @@
+/*
+ * // Copyright 2022-2023 by zhaoming
+ */
+
+package com.k2fsa.sherpa.onnx;
+// Read-only model settings: encoder, decoder, joiner, tokens, thread count and debug flag.
+public class OnlineTransducerModelConfig {
+  private final String encoder;
+  private final String decoder;
+  private final String joiner;
+  private final String tokens;
+  private final int numThreads;
+  private final boolean debug;
+
+  public OnlineTransducerModelConfig(
+      String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) {
+    this.encoder = encoder;
+    this.decoder = decoder;
+    this.joiner = joiner;
+    this.tokens = tokens;
+    this.numThreads = numThreads;
+    this.debug = debug;
+  }
+
+  public String getEncoder() {
+    return encoder;
+  }
+
+  public String getDecoder() {
+    return decoder;
+  }
+
+  public String getJoiner() {
+    return joiner;
+  }
+
+  public String getTokens() {
+    return tokens;
+  }
+
+  public int getNumThreads() {
+    return numThreads;
+  }
+
+  public boolean getDebug() {
+    return debug;
+  }
+}
diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc
index bb81ec583..8a861e8f8 100644
--- a/sherpa-onnx/jni/jni.cc
+++ b/sherpa-onnx/jni/jni.cc
@@ -2,6 +2,7 @@
 //
 // Copyright (c) 2022-2023 Xiaomi Corporation
 // 2022 Pingfeng Luo
+// 2023 Zhaoming
 
 // TODO(fangjun): Add documentation to functions/methods in this file
 // and also show how to use them with kotlin, possibly with java.
@@ -12,7 +13,6 @@
 
 #include <strstream>
 #include <utility>
-
 #if __ANDROID_API__ >= 9
 #include "android/asset_manager.h"
 #include "android/asset_manager_jni.h"
@@ -207,7 +207,6 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
     SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
   }
 #endif
-
   auto config = sherpa_onnx::GetConfig(env, _config);
   SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
   auto model = new sherpa_onnx::SherpaOnnx(
@@ -301,7 +300,7 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
     SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
     exit(-1);
   }
-
+  SHERPA_ONNX_LOGE("Reading %s", p_filename);
   std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename);
   std::istrstream is(buffer.data(), buffer.size());
 
@@ -332,3 +331,198 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
 
   return obj_arr;
 }
+
+// ******wrapper for OnlineRecognizer*******
+// Native objects cross the JNI boundary as jlong handles: a pointer created
+// here is returned to Java and cast back on every later call. None of the
+// functions below checks a handle for 0, so Java must never pass one.
+
+// wav reader for java interface; forwards to the WaveReader.Companion entry
+// point, passing nullptr for its two middle arguments
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL
+Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_readWave(JNIEnv *env,
+                                                     jobject /*obj*/,
+                                                     jstring filename) {
+  return Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
+      env, nullptr, nullptr, filename);
+}
+
+// Creates a native recognizer from the Java-side config object and returns it
+// as a handle; deleteOnlineRecognizer is the matching destroy call.
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createOnlineRecognizer(
+    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
+#if __ANDROID_API__ >= 9
+  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
+  if (!mgr) {
+    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
+  }
+#endif
+  sherpa_onnx::OnlineRecognizerConfig config =
+      sherpa_onnx::GetConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
+  auto p_recognizer = new sherpa_onnx::OnlineRecognizer(
+#if __ANDROID_API__ >= 9
+      mgr,
+#endif
+      config);
+  return reinterpret_cast<jlong>(p_recognizer);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL
+Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_deleteOnlineRecognizer(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  delete reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+}
+
+// The stream leaves its unique_ptr here; OnlineStream_deleteStream frees it.
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jlong JNICALL
+Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env,
+                                                         jobject /*obj*/,
+                                                         jlong ptr) {
+  std::unique_ptr<sherpa_onnx::OnlineStream> s =
+      reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr)->CreateStream();
+  sherpa_onnx::OnlineStream *p_stream = s.release();
+  return reinterpret_cast<jlong>(p_stream);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
+  sherpa_onnx::OnlineRecognizer *model =
+      reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  return model->IsReady(s);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStream(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
+  sherpa_onnx::OnlineRecognizer *model =
+      reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  model->DecodeStream(s);
+}
+
+// Decodes several streams in one call. The parameters mirror the Java
+// declaration decodeStreams(long ptr, long[] ssPtr) one for one.
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL
+Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decodeStreams(JNIEnv *env,
+                                                          jobject /*obj*/,
+                                                          jlong ptr,
+                                                          jlongArray ss_ptr) {
+  sherpa_onnx::OnlineRecognizer *model =
+      reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  jlong *p = env->GetLongArrayElements(ss_ptr, nullptr);
+  jsize n = env->GetArrayLength(ss_ptr);
+  std::vector<sherpa_onnx::OnlineStream *> p_ss(n);
+  for (int32_t i = 0; i != n; ++i) {
+    p_ss[i] = reinterpret_cast<sherpa_onnx::OnlineStream *>(p[i]);
+  }
+
+  model->DecodeStreams(p_ss.data(), n);
+  // JNI_ABORT: the handles were only read, so nothing is copied back
+  env->ReleaseLongArrayElements(ss_ptr, p, JNI_ABORT);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
+  sherpa_onnx::OnlineRecognizer *model =
+      reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
+  return env->NewStringUTF(result.ToString().c_str());
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL
+Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
+  sherpa_onnx::OnlineRecognizer *model =
+      reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  return model->IsEndpoint(s);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reSet(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong s_ptr) {
+  sherpa_onnx::OnlineRecognizer *model =
+      reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  model->Reset(s);
+}
+
+// *********for OnlineStream *********
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform(
+    JNIEnv *env, jobject /*obj*/, jlong s_ptr, jint sample_rate,
+    jfloatArray waveform) {
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  jfloat *p = env->GetFloatArrayElements(waveform, nullptr);
+  jsize n = env->GetArrayLength(waveform);
+  s->AcceptWaveform(sample_rate, p, n);
+  env->ReleaseFloatArrayElements(waveform, p, JNI_ABORT);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished(
+    JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  s->InputFinished();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_deleteStream(
+    JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
+  delete reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_numFramesReady(
+    JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  return s->NumFramesReady();
+}
+
+// The Java declaration isLastFrame(long ptr) supplies no frame index, so the
+// query is made for the newest ready frame.
+// NOTE(review): assumes IsLastFrame(NumFramesReady() - 1) is what the
+// argument-free Java method is meant to report - confirm.
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_isLastFrame(
+    JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  return s->IsLastFrame(s->NumFramesReady() - 1);
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_reSet(
+    JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  s->Reset();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_featureDim(
+    JNIEnv *env, jobject /*obj*/, jlong s_ptr) {
+  sherpa_onnx::OnlineStream *s =
+      reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
+  return s->FeatureDim();
+}