diff --git a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java new file mode 100644 index 00000000000..fc04c36b1ae --- /dev/null +++ b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/EmbeddingFunction.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.inlong.sort.function; + +import org.apache.inlong.sort.function.embedding.EmbeddingInput; +import org.apache.inlong.sort.function.embedding.LanguageModel; + +import com.google.common.base.Strings; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.http.HttpHeaders; +import org.apache.http.HttpResponse; +import org.apache.http.HttpStatus; +import org.apache.http.client.HttpClient; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.util.EntityUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Embedding function for LLM applications. + * */ +public class EmbeddingFunction extends ScalarFunction { + + public static final Logger logger = LoggerFactory.getLogger(EmbeddingFunction.class); + public static final String DEFAULT_EMBEDDING_FUNCTION_NAME = "EMBEDDING"; + + private final ObjectMapper mapper = new ObjectMapper(); + public static final int DEFAULT_CONNECT_TIMEOUT = 30000; + public static final int DEFAULT_SOCKET_TIMEOUT = 30000; + public static final String DEFAULT_MODEL = LanguageModel.BBAI_ZH.getModel(); + private transient HttpClient httpClient; + + /** + * Embedding a LLM document(a String object for now) via http protocol + * @param url the service url for embedding service + * @param input the source data for embedding + * @param model the language model supported in the embedding service + * */ + public String eval(String url, String input, String model) { + // url and input is not null + if (Strings.isNullOrEmpty(url) || Strings.isNullOrEmpty(input)) { + logger.error("Failed to embedding, both url and input can't be empty or null, url: {}, input: {}", + url, input); + return null; + } + + if (Strings.isNullOrEmpty(model)) { 
+ model = DEFAULT_MODEL; + logger.info("model is null, use default model: {}", model); + } + + if (!LanguageModel.isLanguageModelSupported(model)) { + logger.error("Failed to embedding, language model {} not supported(only {} are supported right now)", + model, LanguageModel.getAllSupportedLanguageModels()); + return null; + } + + // initialize httpClient + if (httpClient == null) { + RequestConfig requestConfig = RequestConfig.custom() + .setConnectTimeout(DEFAULT_CONNECT_TIMEOUT) + .setSocketTimeout(DEFAULT_SOCKET_TIMEOUT) + .build(); + httpClient = HttpClientBuilder.create().setDefaultRequestConfig(requestConfig).build(); + } + + try { + HttpPost httpPost = new HttpPost(url); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + EmbeddingInput embeddingInput = new EmbeddingInput(input, model); + String encodedContents = mapper.writeValueAsString(embeddingInput); + httpPost.setEntity(new StringEntity(encodedContents)); + HttpResponse response = httpClient.execute(httpPost); + + String returnStr = EntityUtils.toString(response.getEntity()); + int returnCode = response.getStatusLine().getStatusCode(); + if (Strings.isNullOrEmpty(returnStr) || HttpStatus.SC_OK != returnCode) { + throw new Exception("Failed to embedding, result: " + returnStr + ", code: " + returnCode); + } + return returnStr; + } catch (Exception e) { + logger.error("Failed to embedding, url: {}, input: {}", url, input, e); + return null; + } + } +} \ No newline at end of file diff --git a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java new file mode 100644 index 00000000000..f040db31d70 --- /dev/null +++ b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/function/embedding/EmbeddingInput.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
/**
 * Request payload of the embedding function: the document to embed together with the name of
 * the language model that should embed it.
 *
 * <p>The bean exposes exactly two properties, {@code input} and {@code model}, through
 * conventional getters and setters.
 */
public class EmbeddingInput implements Serializable {

    /** Explicit version id so the serialized form does not depend on compiler-generated values. */
    private static final long serialVersionUID = 1L;

    private String input;
    private String model;

    /**
     * Creates a payload.
     *
     * @param input the source data to embed
     * @param model the name of the language model to use
     */
    public EmbeddingInput(String input, String model) {
        this.input = input;
        this.model = model;
    }

    /**
     * @param input the source data to embed
     */
    public void setInput(String input) {
        this.input = input;
    }

    /**
     * @return the source data to embed
     */
    public String getInput() {
        return input;
    }

    /**
     * @param model the name of the language model to use
     */
    public void setModel(String model) {
        this.model = model;
    }

    /**
     * @return the name of the language model to use
     */
    public String getModel() {
        return model;
    }
}
/**
 * Language models supported for embedding.
 */
public enum LanguageModel {

    /**
     * BAAI bge-large model for Chinese.
     */
    BBAI_ZH("BAAI/bge-large-zh-v1.5"),
    /**
     * BAAI bge-large model for English.
     */
    BBAI_EN("BAAI/bge-large-en"),
    /**
     * intfloat multilingual e5-large model.
     */
    INTFLOAT_MULTI("intfloat/multilingual-e5-large");

    /** Model identifier; final so that a constant can never be re-pointed at another model. */
    private final String model;

    LanguageModel(String s) {
        this.model = s;
    }

    /**
     * @return the identifier of this language model
     */
    public String getModel() {
        return this.model;
    }

    /**
     * Checks whether a model identifier denotes one of the supported language models.
     * The comparison ignores case.
     *
     * @param s the model identifier to check; may be null
     * @return {@code true} if {@code s} matches a supported model, {@code false} otherwise
     *         (including for null or empty input)
     */
    public static boolean isLanguageModelSupported(String s) {
        for (LanguageModel lm : LanguageModel.values()) {
            // equalsIgnoreCase is null-safe on its argument, and no identifier is empty, so
            // null and "" fall through to the final "return false" without a separate guard
            if (lm.model.equalsIgnoreCase(s)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Lists the identifiers of all supported language models.
     *
     * @return the model identifiers in declaration order, separated by {@code ","}
     */
    public static String getAllSupportedLanguageModels() {
        StringBuilder joined = new StringBuilder();
        for (LanguageModel lm : LanguageModel.values()) {
            if (joined.length() > 0) {
                joined.append(',');
            }
            joined.append(lm.model);
        }
        return joined.toString();
    }
}
b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java index 62369e1d54c..76045495074 100644 --- a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java +++ b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/FlinkSqlParser.java @@ -23,6 +23,7 @@ import org.apache.inlong.sort.formats.common.FormatInfo; import org.apache.inlong.sort.formats.common.MapFormatInfo; import org.apache.inlong.sort.formats.common.RowFormatInfo; +import org.apache.inlong.sort.function.EmbeddingFunction; import org.apache.inlong.sort.function.EncryptFunction; import org.apache.inlong.sort.function.JsonGetterFunction; import org.apache.inlong.sort.function.RegexpReplaceFirstFunction; @@ -73,6 +74,7 @@ import java.util.stream.Stream; import static org.apache.inlong.common.util.MaskDataUtils.maskSensitiveMessage; +import static org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME; /** * Flink sql parse handler @@ -122,6 +124,7 @@ private void registerUDF() { tableEnv.createTemporarySystemFunction("REGEXP_REPLACE", RegexpReplaceFunction.class); tableEnv.createTemporarySystemFunction("ENCRYPT", EncryptFunction.class); tableEnv.createTemporarySystemFunction("JSON_GETTER", JsonGetterFunction.class); + tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, EmbeddingFunction.class); } /** diff --git a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java index 11dd43347cf..1e254dabba9 100644 --- a/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java +++ b/inlong-sort/sort-core/src/main/java/org/apache/inlong/sort/parser/impl/NativeFlinkSqlParser.java @@ -17,6 +17,7 @@ package org.apache.inlong.sort.parser.impl; +import org.apache.inlong.sort.function.EmbeddingFunction; import 
org.apache.inlong.sort.function.EncryptFunction; import org.apache.inlong.sort.function.JsonGetterFunction; import org.apache.inlong.sort.function.RegexpReplaceFirstFunction; @@ -34,6 +35,8 @@ import java.util.List; import java.util.Locale; +import static org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME; + /** * parse flink sql script file * script file include CREATE TABLE statement @@ -70,6 +73,7 @@ private void registerUDF() { tableEnv.createTemporarySystemFunction("REGEXP_REPLACE", RegexpReplaceFunction.class); tableEnv.createTemporarySystemFunction("ENCRYPT", EncryptFunction.class); tableEnv.createTemporarySystemFunction("JSON_GETTER", JsonGetterFunction.class); + tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, EmbeddingFunction.class); } /** diff --git a/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java b/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java new file mode 100644 index 00000000000..0c14ca9483b --- /dev/null +++ b/inlong-sort/sort-core/src/test/java/org/apache/inlong/sort/function/EmbeddingFunctionTest.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.inlong.sort.function; + +import org.apache.inlong.sort.function.embedding.EmbeddingInput; +import org.apache.inlong.sort.function.embedding.LanguageModel; + +import com.sun.net.httpserver.HttpServer; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.EnvironmentSettings; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; +import org.junit.Assert; +import org.junit.Test; + +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +import static org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME; + +public class EmbeddingFunctionTest extends AbstractTestBase { + + @Test + public void testMapper() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + EmbeddingInput embeddingInput = new EmbeddingInput("Input-Test", "Model-Test"); + String encodedContents = mapper.writeValueAsString(embeddingInput); + String expect = "{\"input\":\"Input-Test\",\"model\":\"Model-Test\"}"; + Assert.assertEquals(encodedContents, expect); + } + + @Test + public void testLanguageModel() { + String supportedLMs = LanguageModel.getAllSupportedLanguageModels(); + Assert.assertNotNull(supportedLMs); + String[] 
supportLMArray = supportedLMs.split(","); + Assert.assertEquals(supportLMArray.length, LanguageModel.values().length); + + Assert.assertTrue(LanguageModel.isLanguageModelSupported("BAAI/bge-large-zh-v1.5")); + Assert.assertTrue(LanguageModel.isLanguageModelSupported("BAAI/bge-large-en")); + Assert.assertTrue(LanguageModel.isLanguageModelSupported("intfloat/multilingual-e5-large")); + Assert.assertFalse(LanguageModel.isLanguageModelSupported("fake/fake-language")); + } + + /** + * Test for embedding function + * + * @throws Exception The exception may throw when test Embedding function + */ + @Test + public void testEmbeddingFunction() throws Exception { + EnvironmentSettings settings = EnvironmentSettings + .newInstance() + .inStreamingMode() + .build(); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + env.enableCheckpointing(10000); + StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env, settings); + + // step 1. Register custom function of Embedding + tableEnv.createTemporaryFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, EmbeddingFunction.class); + + List udfNames = Arrays.asList(tableEnv.listUserDefinedFunctions()); + Assert.assertTrue(udfNames.contains(DEFAULT_EMBEDDING_FUNCTION_NAME.toLowerCase(Locale.ROOT))); + + // step 2. Generate test data and convert to DataStream + int numOfMessages = 100; + List sourceDateList = new ArrayList<>(); + String msgPrefix = "Data for embedding-"; + for (int i = 0; i < numOfMessages; i++) { + sourceDateList.add(msgPrefix + i); + } + + List data = new ArrayList<>(); + sourceDateList.forEach(s -> data.add(Row.of(s))); + TypeInformation[] types = {BasicTypeInfo.STRING_TYPE_INFO}; + String[] names = {"f1"}; + RowTypeInfo typeInfo = new RowTypeInfo(types, names); + DataStream dataStream = env.fromCollection(data).returns(typeInfo); + + // step 3. 
start a web server to mock embedding service + String embeddingResult = "{\"result\": \"Result data for embedding\"}"; + HttpServer httpServer = HttpServer.create(new InetSocketAddress(8899), 0); // or use InetSocketAddress(0) for + // ephemeral port + httpServer.createContext("/get_embedding", exchange -> { + byte[] response = embeddingResult.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, response.length); + exchange.getResponseBody().write(response); + exchange.close(); + }); + httpServer.start(); + + // step 4. Convert from DataStream to Table and execute the Embedding function + Table tempView = tableEnv.fromDataStream(dataStream).as("f1"); + tableEnv.createTemporaryView("temp_view", tempView); + Table outputTable = tableEnv.sqlQuery( + "SELECT " + + "f1," + + "EMBEDDING('http://localhost:8899/get_embedding', f1, 'BAAI/bge-large-en') " + + "from temp_view"); + + // step 5. Get function execution result and parse it + DataStream resultSet = tableEnv.toAppendStream(outputTable, Row.class); + List resultF0 = new ArrayList<>(); + List resultF1 = new ArrayList<>(); + for (CloseableIterator it = resultSet.executeAndCollect(); it.hasNext();) { + Row row = it.next(); + if (row != null) { + resultF0.add(row.getField(0).toString()); + resultF1.add(row.getField(1).toString()); + } + } + Assert.assertEquals(resultF0.size(), numOfMessages); + Assert.assertEquals(resultF1.size(), numOfMessages); + Assert.assertEquals(resultF0, sourceDateList); + for (String res : resultF1) { + Assert.assertEquals(res, embeddingResult); + } + + httpServer.stop(0); + } +}