diff --git a/UDF/tg_ExprFunctions.hpp b/UDF/tg_ExprFunctions.hpp
index 79c330e9..13f7637e 100644
--- a/UDF/tg_ExprFunctions.hpp
+++ b/UDF/tg_ExprFunctions.hpp
@@ -440,6 +440,35 @@ namespace UDIMPL {
 
   /* =========== END APPROXIMATE NEAREST NEIGHBORS ============= */
 
+  /* ============== START Milvus ===== */
+  /**
+   * Searches a Milvus collection with the given query vector and returns the
+   * vertex ids stored with the hits.
+   *
+   * @param milvus_host          Milvus host name or http(s) URL.
+   * @param milvus_port          Port the search request is sent to.
+   * @param collection_name      Collection to search.
+   * @param vector_field_name    Forwarded to tg_MilvusUtil::search.
+   * @param vertex_id_field_name Output field whose value is returned for each hit.
+   * @param query_vector_str     Comma-separated floats, e.g. "0.1,0.2,0.3".
+   * @param metric_type          Forwarded to tg_MilvusUtil::search.
+   * @param top_k                Maximum number of hits to request.
+   * @return Whatever tg_MilvusUtil::search returns for the parsed query vector.
+   */
+  inline ListAccum<std::string> tg_searchInMilvus(
+      const std::string milvus_host, int milvus_port, const std::string& collection_name,
+      const std::string& vector_field_name, const std::string& vertex_id_field_name, const std::string& query_vector_str,
+      const std::string& metric_type, int top_k) {
+
+    tg::tg_MilvusUtil milvus_util(milvus_host, milvus_port);
+
+    // Convert query vector string to std::vector<float>
+    const std::vector<float> query_vector = milvus_util.stringToFloatVector(query_vector_str);
+
+    std::cout << "Beginning the search on: " << collection_name << std::endl;
+    return milvus_util.search(collection_name, vector_field_name, vertex_id_field_name, query_vector, metric_type, top_k);
+  }
+  /* ============== END Milvus ===== */
 }
 
 /****************************************/
diff --git a/UDF/tg_ExprUtil.hpp b/UDF/tg_ExprUtil.hpp
index 21a432bf..8346b673 100644
--- a/UDF/tg_ExprUtil.hpp
+++ b/UDF/tg_ExprUtil.hpp
@@ -638,6 +638,181 @@ namespace tg {
 
   /* ============ END NODE2VEC =============== */
 
+  /* ============== START Milvus =========== */
+  /**
+   * Small client for the Milvus RESTful search endpoint (POST /v1/vector/search),
+   * implemented with libcurl for transport and jsoncpp for the payloads.
+   */
+  class tg_MilvusUtil {
+  public:
+    /**
+     * @param host Bare host name or an http(s) URL; a URL that already mentions
+     *             the port is used without appending it again.
+     * @param port Port the search request is sent to.
+     */
+    tg_MilvusUtil(const std::string& host, int port) : host(host), port(port) {
+      ensureCurlGlobalInit();
+    }
+
+    /**
+     * Splits str on delimiter and converts every field with std::stof.
+     *
+     * Fields that std::stof rejects are reported on stderr and skipped, so the
+     * result can contain fewer values than str has fields.
+     */
+    std::vector<float> stringToFloatVector(const std::string& str, char delimiter = ',') const {
+      std::vector<float> result;
+      std::stringstream ss(str);
+      std::string item;
+
+      while (std::getline(ss, item, delimiter)) {
+        try {
+          result.push_back(std::stof(item));
+        } catch (const std::invalid_argument& ia) {
+          std::cerr << "Invalid argument: " << ia.what() << '\n';
+        } catch (const std::out_of_range& oor) {
+          std::cerr << "Out of Range error: " << oor.what() << '\n';
+        }
+      }
+
+      return result;
+    }
+
+    /**
+     * Posts a search request and collects the vertex ids of the returned hits.
+     *
+     * @param collection_name      Sent as "collectionName".
+     * @param vector_field_name    Accepted for interface compatibility; not sent.
+     * @param vertex_id_field_name Requested as an output field and read from each hit.
+     * @param query_vector         Sent as "vector".
+     * @param metric_type          Accepted for interface compatibility; not sent.
+     * @param top_k                Sent as "limit".
+     * @return One vertex id per object in the response's "data" array; empty when
+     *         the transfer, the HTTP status or the JSON decoding fails (the
+     *         reason is written to stderr).
+     */
+    ListAccum<std::string> search(const std::string& collection_name, const std::string& vector_field_name,
+        const std::string& vertex_id_field_name, const std::vector<float>& query_vector, const std::string& metric_type, int top_k) const {
+      ListAccum<std::string> vertexIdList;
+
+      Json::Value search_body;
+      search_body["collectionName"] = collection_name;
+
+      // Convert query_vector to Json::Value format
+      for (const auto& val : query_vector) {
+        search_body["vector"].append(val);
+      }
+
+      search_body["outputFields"] = Json::arrayValue;
+      search_body["outputFields"].append("pk");
+      search_body["outputFields"].append(vertex_id_field_name);
+      search_body["limit"] = top_k;
+
+      // You may need to adjust 'search_body' to match the exact format expected by your Milvus server version
+
+      // RAII owners: the header list and the easy handle are released on every
+      // exit path, including a jsoncpp exception while decoding the response.
+      const std::unique_ptr<curl_slist, decltype(&curl_slist_free_all)> headers(
+          curl_slist_append(nullptr, "Content-Type: application/json"), &curl_slist_free_all);
+      const std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+      if (!headers || !curl) {
+        std::cerr << "Failed to set up the libcurl request" << std::endl;
+        return vertexIdList;
+      }
+
+      const std::string url = buildSearchUrl();
+      Json::StreamWriterBuilder writerBuilder;
+      const std::string requestBody = Json::writeString(writerBuilder, search_body);
+      std::string readBuffer;
+
+      curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+      curl_easy_setopt(curl.get(), CURLOPT_POST, 1L);
+      curl_easy_setopt(curl.get(), CURLOPT_POSTFIELDS, requestBody.c_str());
+      curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, WriteCallback);
+      curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &readBuffer);
+      curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, headers.get());
+
+      const CURLcode res = curl_easy_perform(curl.get());
+      if (res != CURLE_OK) {
+        std::cerr << "curl_easy_perform() failed: " << curl_easy_strerror(res) << std::endl;
+        return vertexIdList;
+      }
+
+      long http_status = 0;
+      curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_status);
+      if (http_status < 200 || http_status >= 300) {
+        std::cerr << "Milvus search returned HTTP " << http_status << ": " << readBuffer << std::endl;
+        return vertexIdList;
+      }
+
+      Json::CharReaderBuilder readerBuilder;
+      Json::Value json_response;
+      const std::unique_ptr<Json::CharReader> reader(readerBuilder.newCharReader());
+      std::string parseErrors;
+      if (!reader->parse(readBuffer.data(), readBuffer.data() + readBuffer.size(), &json_response, &parseErrors)) {
+        // If parsing was unsuccessful, print the errors encountered
+        std::cerr << "Failed to parse JSON: " << parseErrors << std::endl;
+        return vertexIdList;
+      }
+      std::cout << "JSON successfully parsed" << std::endl;
+
+      // Check the shape first: the reply may be an error document without a
+      // "data" array, and keyed access requires the value to be a JSON object.
+      if (!json_response.isObject() || !json_response["data"].isArray()) {
+        std::cerr << "Milvus response has no \"data\" array: " << readBuffer << std::endl;
+        return vertexIdList;
+      }
+
+      for (const auto& item : json_response["data"]) {
+        if (!item.isObject()) {
+          continue;
+        }
+        const std::string pk = item["pk"].asString();
+        const std::string vertex_id_str = item[vertex_id_field_name].asString();
+        std::cout << "Vector ID: " << pk << "\tVertex ID: " << vertex_id_str << std::endl;
+        vertexIdList += vertex_id_str;
+      }
+
+      return vertexIdList;
+    }
+
+  private:
+    std::string host;
+    int port;
+
+    // libcurl's global state belongs to the whole process, so it is initialised
+    // once (function-local statics are initialised thread-safely) and never torn
+    // down by an individual client; other libcurl users stay unaffected.
+    static void ensureCurlGlobalInit() {
+      static const CURLcode init_result = curl_global_init(CURL_GLOBAL_ALL);
+      (void)init_result;
+    }
+
+    // Builds the endpoint URL, adding the scheme and port only when host lacks them.
+    std::string buildSearchUrl() const {
+      const std::string search_path = "/v1/vector/search";
+      const std::string port_str = std::to_string(port);
+
+      if (host.compare(0, 4, "http") != 0) {
+        return "http://" + host + ":" + port_str + search_path;
+      }
+      if (host.find(':') != std::string::npos && host.find(port_str) != std::string::npos) {
+        return host + search_path;
+      }
+      return host + ":" + port_str + search_path;
+    }
+
+    // libcurl write callback: appends the received bytes to the std::string passed
+    // via CURLOPT_WRITEDATA. Uses the exact signature libcurl documents.
+    static size_t WriteCallback(char* contents, size_t size, size_t nmemb, void* userp) {
+      const size_t total = size * nmemb;
+      static_cast<std::string*>(userp)->append(contents, total);
+      return total;
+    }
+  };
+
+  /* ============== END Milvus =========== */
+
   /* ============== START A STAR =========== */
 
   inline float rad(float d) {
diff --git a/tools/scripts/bash_functions b/tools/scripts/bash_functions
index b2d060bc..a55b0c07 100644
--- a/tools/scripts/bash_functions
+++ b/tools/scripts/bash_functions
@@ -151,7 +151,8 @@ function decrypt () {
   else
     QNAME=$1
   fi
-  codegen_path="$(getAppRoot)/dev/gdk/gsql/.tmp/codeGen"
+  version=$(basename "$(getAppRoot)")
+  codegen_path="$(gadmin config get System.DataRoot)/gsql/${version}/.tmp/codeGen"
   if [[ x$QNAME != x ]]
   then
     FILES=$(ls ${codegen_path}/*$QNAME.cpp 2>/dev/null)
@@ -181,7 +182,8 @@ function _complete_decrypt {
   fi
 
   local cur=${COMP_WORDS[COMP_CWORD]}
-  local QUERIES=$(ls $(getAppRoot)/dev/gdk/gsql/.tmp/codeGen/${cur}*.cpp 2>/dev/null)
+  local version=$(basename "$(getAppRoot)")
+  local QUERIES=$(ls "$(gadmin config get System.DataRoot)/gsql/${version}/.tmp/codeGen/${cur}"*.cpp 2>/dev/null)
 
   for QUERY in $QUERIES; do
     local filename=$(basename $QUERY)