Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INLONG-9473][Sort] Support transform of embedding for LLM applications #9474

Merged
merged 3 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.inlong.sort.function;

import org.apache.inlong.sort.function.embedding.EmbeddingInput;
import org.apache.inlong.sort.function.embedding.LanguageModel;

import com.google.common.base.Strings;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Embedding function for LLM applications.
* */
public class EmbeddingFunction extends ScalarFunction {

public static final Logger logger = LoggerFactory.getLogger(EmbeddingFunction.class);
public static final String DEFAULT_EMBEDDING_FUNCTION_NAME = "EMBEDDING";

private final ObjectMapper mapper = new ObjectMapper();
public static final int DEFAULT_CONNECT_TIMEOUT = 30000;
public static final int DEFAULT_SOCKET_TIMEOUT = 30000;
public static final String DEFAULT_MODEL = LanguageModel.BBAI_ZH.getModel();
private transient HttpClient httpClient;

/**
* Embedding a LLM document(a String object for now) via http protocol
* @param url the service url for embedding service
* @param input the source data for embedding
* @param model the language model supported in the embedding service
* */
public String eval(String url, String input, String model) {
// url and input is not null
if (Strings.isNullOrEmpty(url) || Strings.isNullOrEmpty(input)) {
logger.error("Failed to embedding, both url and input can't be empty or null, url: {}, input: {}",
url, input);
return null;
}

if (Strings.isNullOrEmpty(model)) {
model = DEFAULT_MODEL;
logger.info("model is null, use default model: {}", model);
}

if (!LanguageModel.isLanguageModelSupported(model)) {
logger.error("Failed to embedding, language model {} not supported(only {} are supported right now)",
model, LanguageModel.getAllSupportedLanguageModels());
return null;
}

// initialize httpClient
if (httpClient == null) {
RequestConfig requestConfig = RequestConfig.custom()
.setConnectTimeout(DEFAULT_CONNECT_TIMEOUT)
.setSocketTimeout(DEFAULT_SOCKET_TIMEOUT)
.build();
httpClient = HttpClientBuilder.create().setDefaultRequestConfig(requestConfig).build();
}

try {
HttpPost httpPost = new HttpPost(url);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, "application/json");
EmbeddingInput embeddingInput = new EmbeddingInput(input, model);
String encodedContents = mapper.writeValueAsString(embeddingInput);
httpPost.setEntity(new StringEntity(encodedContents));
HttpResponse response = httpClient.execute(httpPost);

String returnStr = EntityUtils.toString(response.getEntity());
int returnCode = response.getStatusLine().getStatusCode();
if (Strings.isNullOrEmpty(returnStr) || HttpStatus.SC_OK != returnCode) {
throw new Exception("Failed to embedding, result: " + returnStr + ", code: " + returnCode);
}
return returnStr;
} catch (Exception e) {
logger.error("Failed to embedding, url: {}, input: {}", url, input, e);
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.inlong.sort.function.embedding;

import java.io.Serializable;

/**
* Class representing the input of embedding function
*/
public class EmbeddingInput implements Serializable {

private String input;
private String model;

public void setInput(String input) {
this.input = input;
}

public String getInput() {
return input;
}

public void setModel(String model) {
this.model = model;
}

public String getModel() {
return model;
}

public EmbeddingInput(String input, String model) {
this.input = input;
this.model = model;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.inlong.sort.function.embedding;

import com.google.common.base.Strings;

/**
* Supported language model for embedding.
*/
public enum LanguageModel {

/**
* Language model for BBAI zh, chinese
* */
BBAI_ZH("BAAI/bge-large-zh-v1.5"),
/**
* Language model for BBAI en, english
* */
BBAI_EN("BAAI/bge-large-en"),
/**
* Language model for intfloat multi-language
* */
INTFLOAT_MULTI("intfloat/multilingual-e5-large");
String model;

LanguageModel(String s) {
this.model = s;
}

public String getModel() {
return this.model;
}

public static boolean isLanguageModelSupported(String s) {
if (Strings.isNullOrEmpty(s)) {
return false;
}
for (LanguageModel lm : LanguageModel.values()) {
if (s.equalsIgnoreCase(lm.getModel())) {
return true;
}
}
return false;
}

public static String getAllSupportedLanguageModels() {
if (LanguageModel.values().length == 0) {
return null;
}
StringBuilder supportedLMBuilder = new StringBuilder();
for (LanguageModel lm : LanguageModel.values()) {
supportedLMBuilder.append(lm.getModel()).append(",");
}
return supportedLMBuilder.substring(0, supportedLMBuilder.length() - 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.inlong.sort.formats.common.FormatInfo;
import org.apache.inlong.sort.formats.common.MapFormatInfo;
import org.apache.inlong.sort.formats.common.RowFormatInfo;
import org.apache.inlong.sort.function.EmbeddingFunction;
import org.apache.inlong.sort.function.EncryptFunction;
import org.apache.inlong.sort.function.JsonGetterFunction;
import org.apache.inlong.sort.function.RegexpReplaceFirstFunction;
Expand Down Expand Up @@ -73,6 +74,7 @@
import java.util.stream.Stream;

import static org.apache.inlong.common.util.MaskDataUtils.maskSensitiveMessage;
import static org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;

/**
* Flink sql parse handler
Expand Down Expand Up @@ -122,6 +124,7 @@ private void registerUDF() {
tableEnv.createTemporarySystemFunction("REGEXP_REPLACE", RegexpReplaceFunction.class);
tableEnv.createTemporarySystemFunction("ENCRYPT", EncryptFunction.class);
tableEnv.createTemporarySystemFunction("JSON_GETTER", JsonGetterFunction.class);
tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, EmbeddingFunction.class);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.inlong.sort.parser.impl;

import org.apache.inlong.sort.function.EmbeddingFunction;
import org.apache.inlong.sort.function.EncryptFunction;
import org.apache.inlong.sort.function.JsonGetterFunction;
import org.apache.inlong.sort.function.RegexpReplaceFirstFunction;
Expand All @@ -34,6 +35,8 @@
import java.util.List;
import java.util.Locale;

import static org.apache.inlong.sort.function.EmbeddingFunction.DEFAULT_EMBEDDING_FUNCTION_NAME;

/**
* parse flink sql script file
* script file include CREATE TABLE statement
Expand Down Expand Up @@ -70,6 +73,7 @@ private void registerUDF() {
tableEnv.createTemporarySystemFunction("REGEXP_REPLACE", RegexpReplaceFunction.class);
tableEnv.createTemporarySystemFunction("ENCRYPT", EncryptFunction.class);
tableEnv.createTemporarySystemFunction("JSON_GETTER", JsonGetterFunction.class);
tableEnv.createTemporarySystemFunction(DEFAULT_EMBEDDING_FUNCTION_NAME, EmbeddingFunction.class);
}

/**
Expand Down
Loading