diff --git a/.gitignore b/.gitignore index 1523dee..a43f896 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,6 @@ tags # Persistent undo [._]*.un~ + +src/bin +src/bazel** diff --git a/src/.bazelversion b/src/.bazelversion new file mode 100644 index 0000000..815da58 --- /dev/null +++ b/src/.bazelversion @@ -0,0 +1 @@ +7.4.1 diff --git a/src/BUILD b/src/BUILD new file mode 100644 index 0000000..b28b04f --- /dev/null +++ b/src/BUILD @@ -0,0 +1,3 @@ + + + diff --git a/src/MODULE.bazel b/src/MODULE.bazel new file mode 100644 index 0000000..00bb183 --- /dev/null +++ b/src/MODULE.bazel @@ -0,0 +1,6 @@ +############################################################################### +# Bazel now uses Bzlmod by default to manage external dependencies. +# Please consider migrating your external dependencies from WORKSPACE to MODULE.bazel. +# +# For more details, please check https://github.com/bazelbuild/bazel/issues/18958 +############################################################################### diff --git a/src/MODULE.bazel.lock b/src/MODULE.bazel.lock new file mode 100644 index 0000000..d62a47c --- /dev/null +++ b/src/MODULE.bazel.lock @@ -0,0 +1,110 @@ +{ + "lockFileVersion": 11, + "registryFileHashes": { + "https://bcr.bazel.build/bazel_registry.json": "8a28e4aff06ee60aed2a8c281907fb8bcbf3b753c91fb5a5c57da3215d5b3497", + "https://bcr.bazel.build/modules/abseil-cpp/20210324.2/MODULE.bazel": "7cd0312e064fde87c8d1cd79ba06c876bd23630c83466e9500321be55c96ace2", + "https://bcr.bazel.build/modules/abseil-cpp/20211102.0/MODULE.bazel": "70390338f7a5106231d20620712f7cccb659cd0e9d073d1991c038eb9fc57589", + "https://bcr.bazel.build/modules/abseil-cpp/20211102.0/source.json": "7e3a9adf473e9af076ae485ed649d5641ad50ec5c11718103f34de03170d94ad", + "https://bcr.bazel.build/modules/apple_support/1.5.0/MODULE.bazel": "50341a62efbc483e8a2a6aec30994a58749bd7b885e18dd96aa8c33031e558ef", + "https://bcr.bazel.build/modules/apple_support/1.5.0/source.json": "eb98a7627c0bc486b57f598ad8da50f6625d974c8f723e9ea71bd39f709c9862", + "https://bcr.bazel.build/modules/bazel_features/1.11.0/MODULE.bazel": "f9382337dd5a474c3b7d334c2f83e50b6eaedc284253334cf823044a26de03e8", + "https://bcr.bazel.build/modules/bazel_features/1.11.0/source.json": "c9320aa53cd1c441d24bd6b716da087ad7e4ff0d9742a9884587596edfe53015", + "https://bcr.bazel.build/modules/bazel_skylib/1.0.3/MODULE.bazel": "bcb0fd896384802d1ad283b4e4eb4d718eebd8cb820b0a2c3a347fb971afd9d8", + "https://bcr.bazel.build/modules/bazel_skylib/1.2.1/MODULE.bazel": "f35baf9da0efe45fa3da1696ae906eea3d615ad41e2e3def4aeb4e8bc0ef9a7a", + "https://bcr.bazel.build/modules/bazel_skylib/1.3.0/MODULE.bazel": "20228b92868bf5cfc41bda7afc8a8ba2a543201851de39d990ec957b513579c5", + "https://bcr.bazel.build/modules/bazel_skylib/1.6.1/MODULE.bazel": "8fdee2dbaace6c252131c00e1de4b165dc65af02ea278476187765e1a617b917", + "https://bcr.bazel.build/modules/bazel_skylib/1.6.1/source.json": "082ed5f9837901fada8c68c2f3ddc958bb22b6d654f71dd73f3df30d45d4b749", + "https://bcr.bazel.build/modules/buildozer/7.1.2/MODULE.bazel": "2e8dd40ede9c454042645fd8d8d0cd1527966aa5c919de86661e62953cd73d84", + "https://bcr.bazel.build/modules/buildozer/7.1.2/source.json": "c9028a501d2db85793a6996205c8de120944f50a0d570438fcae0457a5f9d1f8", + "https://bcr.bazel.build/modules/googletest/1.11.0/MODULE.bazel": "3a83f095183f66345ca86aa13c58b59f9f94a2f81999c093d4eeaa2d262d12f4", + "https://bcr.bazel.build/modules/googletest/1.11.0/source.json": "c73d9ef4268c91bd0c1cd88f1f9dfa08e814b1dbe89b5f594a9f08ba0244d206", + "https://bcr.bazel.build/modules/platforms/0.0.4/MODULE.bazel": "9b328e31ee156f53f3c416a64f8491f7eb731742655a47c9eec4703a71644aee", + "https://bcr.bazel.build/modules/platforms/0.0.5/MODULE.bazel": "5733b54ea419d5eaf7997054bb55f6a1d0b5ff8aedf0176fef9eea44f3acda37", + "https://bcr.bazel.build/modules/platforms/0.0.6/MODULE.bazel": "ad6eeef431dc52aefd2d77ed20a4b353f8ebf0f4ecdd26a807d2da5aa8cd0615", + "https://bcr.bazel.build/modules/platforms/0.0.7/MODULE.bazel": "72fd4a0ede9ee5c021f6a8dd92b503e089f46c227ba2813ff183b71616034814", + "https://bcr.bazel.build/modules/platforms/0.0.9/MODULE.bazel": "4a87a60c927b56ddd67db50c89acaa62f4ce2a1d2149ccb63ffd871d5ce29ebc", + "https://bcr.bazel.build/modules/platforms/0.0.9/source.json": "cd74d854bf16a9e002fb2ca7b1a421f4403cda29f824a765acd3a8c56f8d43e6", + "https://bcr.bazel.build/modules/protobuf/21.7/MODULE.bazel": "a5a29bb89544f9b97edce05642fac225a808b5b7be74038ea3640fae2f8e66a7", + "https://bcr.bazel.build/modules/protobuf/21.7/source.json": "bbe500720421e582ff2d18b0802464205138c06056f443184de39fbb8187b09b", + "https://bcr.bazel.build/modules/protobuf/3.19.0/MODULE.bazel": "6b5fbb433f760a99a22b18b6850ed5784ef0e9928a72668b66e4d7ccd47db9b0", + "https://bcr.bazel.build/modules/protobuf/3.19.6/MODULE.bazel": "9233edc5e1f2ee276a60de3eaa47ac4132302ef9643238f23128fea53ea12858", + "https://bcr.bazel.build/modules/rules_cc/0.0.1/MODULE.bazel": "cb2aa0747f84c6c3a78dad4e2049c154f08ab9d166b1273835a8174940365647", + "https://bcr.bazel.build/modules/rules_cc/0.0.2/MODULE.bazel": "6915987c90970493ab97393024c156ea8fb9f3bea953b2f3ec05c34f19b5695c", + "https://bcr.bazel.build/modules/rules_cc/0.0.8/MODULE.bazel": "964c85c82cfeb6f3855e6a07054fdb159aced38e99a5eecf7bce9d53990afa3e", + "https://bcr.bazel.build/modules/rules_cc/0.0.9/MODULE.bazel": "836e76439f354b89afe6a911a7adf59a6b2518fafb174483ad78a2a2fde7b1c5", + "https://bcr.bazel.build/modules/rules_cc/0.0.9/source.json": "1f1ba6fea244b616de4a554a0f4983c91a9301640c8fe0dd1d410254115c8430", + "https://bcr.bazel.build/modules/rules_java/4.0.0/MODULE.bazel": "5a78a7ae82cd1a33cef56dc578c7d2a46ed0dca12643ee45edbb8417899e6f74", + "https://bcr.bazel.build/modules/rules_java/7.6.5/MODULE.bazel": "481164be5e02e4cab6e77a36927683263be56b7e36fef918b458d7a8a1ebadb1", + "https://bcr.bazel.build/modules/rules_java/7.6.5/source.json": "a805b889531d1690e3c72a7a7e47a870d00323186a9904b36af83aa3d053ee8d", + "https://bcr.bazel.build/modules/rules_jvm_external/4.4.2/MODULE.bazel": "a56b85e418c83eb1839819f0b515c431010160383306d13ec21959ac412d2fe7", + "https://bcr.bazel.build/modules/rules_jvm_external/4.4.2/source.json": "a075731e1b46bc8425098512d038d416e966ab19684a10a34f4741295642fc35", + "https://bcr.bazel.build/modules/rules_license/0.0.3/MODULE.bazel": "627e9ab0247f7d1e05736b59dbb1b6871373de5ad31c3011880b4133cafd4bd0", + "https://bcr.bazel.build/modules/rules_license/0.0.7/MODULE.bazel": "088fbeb0b6a419005b89cf93fe62d9517c0a2b8bb56af3244af65ecfe37e7d5d", + "https://bcr.bazel.build/modules/rules_license/0.0.7/source.json": "355cc5737a0f294e560d52b1b7a6492d4fff2caf0bef1a315df5a298fca2d34a", + "https://bcr.bazel.build/modules/rules_pkg/0.7.0/MODULE.bazel": "df99f03fc7934a4737122518bb87e667e62d780b610910f0447665a7e2be62dc", + "https://bcr.bazel.build/modules/rules_pkg/0.7.0/source.json": "c2557066e0c0342223ba592510ad3d812d4963b9024831f7f66fd0584dd8c66c", + "https://bcr.bazel.build/modules/rules_proto/4.0.0/MODULE.bazel": "a7a7b6ce9bee418c1a760b3d84f83a299ad6952f9903c67f19e4edd964894e06", + "https://bcr.bazel.build/modules/rules_proto/5.3.0-21.7/MODULE.bazel": "e8dff86b0971688790ae75528fe1813f71809b5afd57facb44dad9e8eca631b7", + "https://bcr.bazel.build/modules/rules_proto/5.3.0-21.7/source.json": "d57902c052424dfda0e71646cb12668d39c4620ee0544294d9d941e7d12bc3a9", + "https://bcr.bazel.build/modules/rules_python/0.10.2/MODULE.bazel": "cc82bc96f2997baa545ab3ce73f196d040ffb8756fd2d66125a530031cd90e5f", + "https://bcr.bazel.build/modules/rules_python/0.22.1/MODULE.bazel": "26114f0c0b5e93018c0c066d6673f1a2c3737c7e90af95eff30cfee38d0bbac7", + "https://bcr.bazel.build/modules/rules_python/0.22.1/source.json": "57226905e783bae7c37c2dd662be078728e48fa28ee4324a7eabcafb5a43d014", + "https://bcr.bazel.build/modules/rules_python/0.4.0/MODULE.bazel": "9208ee05fd48bf09ac60ed269791cf17fb343db56c8226a720fbb1cdf467166c", + "https://bcr.bazel.build/modules/stardoc/0.5.1/MODULE.bazel": "1a05d92974d0c122f5ccf09291442580317cdd859f07a8655f1db9a60374f9f8", + "https://bcr.bazel.build/modules/stardoc/0.5.1/source.json": "a96f95e02123320aa015b956f29c00cb818fa891ef823d55148e1a362caacf29", + "https://bcr.bazel.build/modules/upb/0.0.0-20220923-a547704/MODULE.bazel": "7298990c00040a0e2f121f6c32544bab27d4452f80d9ce51349b1a28f3005c43", + "https://bcr.bazel.build/modules/upb/0.0.0-20220923-a547704/source.json": "f1ef7d3f9e0e26d4b23d1c39b5f5de71f584dd7d1b4ef83d9bbba6ec7a6a6459", + "https://bcr.bazel.build/modules/zlib/1.2.11/MODULE.bazel": "07b389abc85fdbca459b69e2ec656ae5622873af3f845e1c9d80fe179f3effa0", + "https://bcr.bazel.build/modules/zlib/1.2.12/MODULE.bazel": "3b1a8834ada2a883674be8cbd36ede1b6ec481477ada359cd2d3ddc562340b27", + "https://bcr.bazel.build/modules/zlib/1.3.1.bcr.3/MODULE.bazel": "af322bc08976524477c79d1e45e241b6efbeb918c497e8840b8ab116802dda79", + "https://bcr.bazel.build/modules/zlib/1.3.1.bcr.3/source.json": "2be409ac3c7601245958cd4fcdff4288be79ed23bd690b4b951f500d54ee6e7d" + }, + "selectedYankedVersions": {}, + "moduleExtensions": { + "@@apple_support~//crosstool:setup.bzl%apple_cc_configure_extension": { + "general": { + "bzlTransitiveDigest": "PjIds3feoYE8SGbbIq2SFTZy3zmxeO2tQevJZNDo7iY=", + "usagesDigest": "+hz7IHWN6A1oVJJWNDB6yZRG+RYhF76wAYItpAeIUIg=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "local_config_apple_cc_toolchains": { + "bzlFile": "@@apple_support~//crosstool:setup.bzl", + "ruleClassName": "_apple_cc_autoconf_toolchains", + "attributes": {} + }, + "local_config_apple_cc": { + "bzlFile": "@@apple_support~//crosstool:setup.bzl", + "ruleClassName": "_apple_cc_autoconf", + "attributes": {} + } + }, + "recordedRepoMappingEntries": [ + [ + "apple_support~", + "bazel_tools", + "bazel_tools" + ] + ] + } + }, + "@@platforms//host:extension.bzl%host_platform": { + "general": { + "bzlTransitiveDigest": "xelQcPZH8+tmuOHVjL9vDxMnnQNMlwj0SlvgoqBkm4U=", + "usagesDigest": "pCYpDQmqMbmiiPI1p2Kd3VLm5T48rRAht5WdW0X2GlA=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "host_platform": { + "bzlFile": "@@platforms//host:extension.bzl", + "ruleClassName": "host_platform_repo", + "attributes": {} + } + }, + "recordedRepoMappingEntries": [] + } + } + } +} diff --git a/src/WORKSPACE b/src/WORKSPACE new file mode 100644 index 0000000..3e79a40 --- /dev/null +++ b/src/WORKSPACE @@ -0,0 +1,42 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "com_github_grpc_grpc", + strip_prefix = "grpc-tags-v1.63.0", + urls = ["https://github.com/grpc/grpc/archive/tags/v1.63.0.tar.gz"], +) + +http_archive( + name = "com_github_singnet_das_proto", + strip_prefix = "das-proto-0.1.13", + urls = ["https://github.com/singnet/das-proto/archive/refs/tags/0.1.13.tar.gz"], +) + +http_archive( + name = "com_github_singnet_das_node", + strip_prefix = "das-node-ab-test-1", + urls = ["https://github.com/singnet/das-node/archive/refs/tags/ab-test-1.tar.gz"], +) + +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +grpc_deps() +load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") +grpc_extra_deps() + +http_archive( + name = "com_github_google_googletest", + strip_prefix = "googletest-1.14.0", + urls = ["https://github.com/google/googletest/archive/refs/tags/v1.14.0.tar.gz"] +) + +new_local_repository( + name = "mbedcrypto", + path = "/opt/3rd-party/mbedcrypto", + build_file_content = '\ +cc_library(\ + name = "lib",\ + srcs = ["libmbedcrypto.a"],\ + visibility = ["//visibility:public"],\ +)\ +' +) diff --git a/src/ab.sh b/src/ab.sh new file mode 100755 index 0000000..a49ab89 --- /dev/null +++ b/src/ab.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +rm -f ../bin/* && ../scripts/bazel_build.sh && ../bin/attention_broker 37007 diff --git a/src/assets/3rd-party.tgz b/src/assets/3rd-party.tgz new file mode 100644 index 0000000..417b10a Binary files /dev/null and b/src/assets/3rd-party.tgz differ diff --git a/src/assets/hiredis-cluster.tgz b/src/assets/hiredis-cluster.tgz new file mode 100644 index 0000000..a0f289d Binary files /dev/null and b/src/assets/hiredis-cluster.tgz differ diff --git a/src/assets/mongo-cxx-driver-r3.11.0.tar.gz b/src/assets/mongo-cxx-driver-r3.11.0.tar.gz new file mode 100644 index 0000000..5713bd7 Binary files /dev/null and b/src/assets/mongo-cxx-driver-r3.11.0.tar.gz differ diff --git a/src/cpp/BUILD b/src/cpp/BUILD new file mode 100644 index 0000000..f5b0d85 --- /dev/null +++ b/src/cpp/BUILD @@ -0,0 +1,62 @@ +cc_binary( + name = "attention_broker_service", + srcs = [], + defines = ["BAZEL_BUILD"], + deps = [ + "//cpp/main:attention_broker_main_lib", + "//cpp/utils:utils_lib", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_binary( + name = "query_broker", + srcs = [], + defines = ["BAZEL_BUILD"], + deps = [ + "//cpp/main:query_engine_main_lib", + "//cpp/utils:utils_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_binary( + name = "query", + srcs = [], + defines = ["BAZEL_BUILD"], + deps = [ + "//cpp/main:query_client_main_lib", + "//cpp/utils:utils_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_binary( + name = "link_creation_engine", + srcs = [], + defines = ["BAZEL_BUILD"], + deps = [ + "//cpp/main:link_creation_engine_main_lib", + "//cpp/utils:utils_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_binary( + name = "word_query", + srcs = [], + defines = ["BAZEL_BUILD"], + deps = [ + "//cpp/main:word_query_main_lib", + "//cpp/utils:utils_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) diff --git a/src/cpp/attention_broker/AttentionBrokerServer.cc b/src/cpp/attention_broker/AttentionBrokerServer.cc new file mode 100644 index 0000000..9193c6b --- /dev/null +++ b/src/cpp/attention_broker/AttentionBrokerServer.cc @@ -0,0 +1,140 @@ +#include "RequestSelector.h" +#include "AttentionBrokerServer.h" + +using namespace attention_broker_server; + +const double AttentionBrokerServer::RENT_RATE; +const double AttentionBrokerServer::SPREADING_RATE_LOWERBOUND; +const double AttentionBrokerServer::SPREADING_RATE_UPPERBOUND; + +// -------------------------------------------------------------------------------- +// Public methods + +AttentionBrokerServer::AttentionBrokerServer() { + this->global_context = "global"; + this->stimulus_requests = new SharedQueue(); + this->correlation_requests = new SharedQueue(); + this->worker_threads = new WorkerThreads(stimulus_requests, correlation_requests); + HebbianNetwork *network = new HebbianNetwork(); + this->hebbian_network[this->global_context] = network; + this->updater = HebbianNetworkUpdater::factory(HebbianNetworkUpdaterType::EXACT_COUNT); + this->stimulus_spreader = StimulusSpreader::factory(StimulusSpreaderType::TOKEN); + +} + +AttentionBrokerServer::~AttentionBrokerServer() { + graceful_shutdown(); + delete this->worker_threads; + delete this->stimulus_requests; + delete this->correlation_requests; + delete this->updater; + delete this->stimulus_spreader; + for (auto pair:this->hebbian_network) { + delete pair.second; + } +} + +void AttentionBrokerServer::graceful_shutdown() { + this->rpc_api_enabled = false; + this->worker_threads->graceful_stop(); +} + +// RPC API + +Status AttentionBrokerServer::ping(ServerContext* grpc_context, const dasproto::Empty *request, dasproto::Ack* reply) { + reply->set_msg("PING"); + if (rpc_api_enabled) { + return Status::OK; + } else{ + return Status::CANCELLED; + } +} + +Status AttentionBrokerServer::stimulate(ServerContext* grpc_context, const dasproto::HandleCount *request, dasproto::Ack* reply) { +#ifdef DEBUG + cout << "AttentionBrokerServer::stimulate() BEGIN" << endl; + cout << "Context: " << request->context() << endl; +#endif + if (request->map_size() > 0) { + HebbianNetwork *network = select_hebbian_network(request->context()); + ((dasproto::HandleCount *) request)->set_hebbian_network((long) network); + //this->stimulus_requests->enqueue((void *) request); + this->stimulus_spreader->spread_stimuli(request); + } + reply->set_msg("STIMULATE"); +#ifdef DEBUG + cout << "AttentionBrokerServer::stimulate() END" << endl; +#endif + if (rpc_api_enabled) { + return Status::OK; + } else{ + return Status::CANCELLED; + } +} + +Status AttentionBrokerServer::correlate(ServerContext* grpc_context, const dasproto::HandleList *request, dasproto::Ack* reply) { +#ifdef DEBUG + cout << "AttentionBrokerServer::correlate() BEGIN" << endl; + cout << "Context: " << request->context() << endl; +#endif + if (request->list_size() > 0) { + HebbianNetwork *network = select_hebbian_network(request->context()); + ((dasproto::HandleList *) request)->set_hebbian_network((long) network); + //this->correlation_requests->enqueue((void *) request); + this->updater->correlation(request); + } + reply->set_msg("CORRELATE"); +#ifdef DEBUG + cout << "AttentionBrokerServer::correlate() END" << endl; +#endif + if (rpc_api_enabled) { + return Status::OK; + } else { + return Status::CANCELLED; + } +} + +Status AttentionBrokerServer::get_importance(ServerContext *grpc_context, const dasproto::HandleList *request, dasproto::ImportanceList *reply) { +#ifdef DEBUG + cout << "AttentionBrokerServer::get_importance() BEGIN" << endl; + cout << "Context: " << request->context() << endl; +#endif + if (this->rpc_api_enabled) { + int num_handles = request->list_size(); + if (num_handles > 0) { + HebbianNetwork *network = select_hebbian_network(request->context()); + for (int i = 0; i < num_handles; i++) { + float importance = network->get_node_importance(request->list(i)); + reply->add_list(importance); + } + } +#ifdef DEBUG + cout << "AttentionBrokerServer::get_importance() END" << endl; +#endif + return Status::OK; + } else { + return Status::CANCELLED; + } +} + +// -------------------------------------------------------------------------------- +// Private methods +// + +HebbianNetwork *AttentionBrokerServer::select_hebbian_network(const string &context) { + HebbianNetwork *network; + if ((context != "") && (this->hebbian_network.find(context) != this->hebbian_network.end())) { + network = this->hebbian_network[context]; + } + if (context == "") { + network = this->hebbian_network[this->global_context]; + } else { + if (this->hebbian_network.find(context) == this->hebbian_network.end()) { + network = new HebbianNetwork(); + this->hebbian_network[context] = network; + } else { + network = this->hebbian_network[context]; + } + } + return network; +} diff --git a/src/cpp/attention_broker/AttentionBrokerServer.h b/src/cpp/attention_broker/AttentionBrokerServer.h new file mode 100644 index 0000000..8269559 --- /dev/null +++ b/src/cpp/attention_broker/AttentionBrokerServer.h @@ -0,0 +1,150 @@ +#ifndef _ATTENTION_BROKER_SERVER_ATTENTIONBROKERSERVER_H +#define _ATTENTION_BROKER_SERVER_ATTENTIONBROKERSERVER_H + +#define DEBUG + +#include +#include +#include "attention_broker.grpc.pb.h" +#include "SharedQueue.h" +#include "WorkerThreads.h" +#include "HebbianNetwork.h" +#include "HebbianNetworkUpdater.h" +#include "StimulusSpreader.h" + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; +using dasproto::AttentionBroker; + +namespace attention_broker_server { + +/** + * GRPC server which actually listens to a PORT. + * + * This class implements the GRPC server which listens to a PORT and answer + * the public RPC API defined in the protobuf. + * + * Some parameters related to the stimulus spreading algorithms are also + * stored in this class as static fields. Detailed description of them are + * provided in StimulsSpreader. They are defined here because we may want + * to allow them to be modified by some homeostasis process running independently + * from StimulusSpreading. + */ +class AttentionBrokerServer final: public AttentionBroker::Service { + + public: + + /** + * Basic no-parameters constructor. + * Creates and initializes two request queues. One for stimuli spreading and + * another one for atom correlation. Worker threads to process such queues are + * also created inside a WorkerThread object. + * + * Different contexts are represented using different HebianNetwork objects. New + * contexts are created by caller's request but a default GLOBAL context is created + * here to be used whenever the caller don't specify a context. + */ + AttentionBrokerServer(); + + /** + * Destructor. + * + * Gracefully shutdown the GRPC server by stopping accepting new requests (any new request + * received after starting a shutdown process is denied and an error is returned to the + * caller) and waiting for all requests currently in the queues to be processed. Once all + * queues are empty, all worker threads are stopped and the queues and all other state structures + * are destroyed. + */ + ~AttentionBrokerServer(); + + static const unsigned int WORKER_THREADS_COUNT = 10; /// Number of working threads. + + + // Stimuli spreading parameters + string global_context; + constexpr const static double RENT_RATE = 0.50; /// double in [0..1] range. + constexpr const static double SPREADING_RATE_LOWERBOUND = 0.01; /// double in [0..1] range. + constexpr const static double SPREADING_RATE_UPPERBOUND = 0.10; /// double in [0..1] range. + + // RPC API + + /** + * Used by caller to check if AttentionBroker is running. + * + * @param grpc_context GRPC context object. + * @param request Empty request. + * @param reply The message which will be send back to the caller with a simple ACK. + * + * @return GRPC status OK if request were properly processed or CANCELLED otherwise. + */ + Status ping(ServerContext *grpc_context, const dasproto::Empty *request, dasproto::Ack *reply) override; + + /** + * Spread stimuli according to the passed request. + * + * Boost importance of passed atoms and run one cycle of stimuli spreading. The algorithm is explained + * in StimulusSpreader. + * + * @param grpc_context GRPC context object. + * @param request The request contains a list of handles of the atoms which should have the boost in + * importance as well as an associated integer indicating the proportion in which the boost should + * happen related to the other atoms in the same request. + * @param reply The message which will be send back to the caller with a simple ACK. + * + * @return GRPC status OK if request were properly processed or CANCELLED otherwise. + */ + Status stimulate(ServerContext *grpc_context, const dasproto::HandleCount *request, dasproto::Ack *reply) override; + + /** + * Correlates atoms passed in the request. + * + * @param grpc_context GRPC context object. + * @param request The request contains a list of handles of the atoms which should be correlated. + * @param reply The message which will be send back to the caller with a simple ACK. + * + * @return GRPC status OK if request were properly processed or CANCELLED otherwise. + */ + Status correlate(ServerContext* grpc_context, const dasproto::HandleList* request, dasproto::Ack* reply) override; + + /** + * Return importance of atoms passed in the request. + * + * + * @param grpc_context GRPC context object. + * @param request The request contains a list of handles of the atoms whose importrance are to be returned. + * @param reply A list with importance of atoms IN THE SAME ORDER they appear in the request. + * + * @return GRPC status OK if request were properly processed or CANCELLED otherwise. + */ + Status get_importance(ServerContext *grpc_context, const dasproto::HandleList *request, dasproto::ImportanceList *reply) override; + + // Other public methods + + /** + * Gracefully stop this GRPC server. + * + * Gracefully shutdown the GRPC server by stopping accepting new requests (any new request + * received after starting a shutdown process is denied and an error is returned to the + * caller) and waiting for all requests currently in the queues to be processed. + */ + void graceful_shutdown(); /// Gracefully stop this GRPC server. + + private: + + bool rpc_api_enabled = true; + SharedQueue *stimulus_requests; + SharedQueue *correlation_requests; + WorkerThreads *worker_threads; + unordered_map hebbian_network; + HebbianNetworkUpdater *updater; + StimulusSpreader *stimulus_spreader; + + + HebbianNetwork *select_hebbian_network(const string &context); +}; + +} // namespace attention_broker_server + +#endif // _ATTENTION_BROKER_SERVER_ATTENTIONBROKERSERVER_H diff --git a/src/cpp/attention_broker/BUILD b/src/cpp/attention_broker/BUILD new file mode 100644 index 0000000..a7aee45 --- /dev/null +++ b/src/cpp/attention_broker/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "attention_broker_server_lib", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + includes = ["."], + deps = [ + "//cpp/utils:utils_lib", + "//cpp/hasher:hasher_lib", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + ], +) diff --git a/src/cpp/attention_broker/HandleTrie.cc b/src/cpp/attention_broker/HandleTrie.cc new file mode 100644 index 0000000..b0ca670 --- /dev/null +++ b/src/cpp/attention_broker/HandleTrie.cc @@ -0,0 +1,242 @@ +#include "Utils.h" +#include "expression_hasher.h" +#include "HandleTrie.h" +#include +#include + +using namespace attention_broker_server; +using namespace commons; + +HandleTrie::TrieValue::TrieValue() { +} + +HandleTrie::TrieValue::~TrieValue() { +} + +HandleTrie::TrieNode::TrieNode() { + children = new TrieNode*[TRIE_ALPHABET_SIZE]; + for (unsigned int i = 0; i < TRIE_ALPHABET_SIZE; i++) { + children[i] = NULL; + } + this->value = NULL; + this->suffix_start = 0; +} + +HandleTrie::TrieNode::~TrieNode() { + for (unsigned int i = 0; i < TRIE_ALPHABET_SIZE; i++) { + delete children[i]; + } + delete [] children; + delete value; +} + +string HandleTrie::TrieValue::to_string() { + return ""; +} + +bool HandleTrie::TLB_INITIALIZED = false; +unsigned char HandleTrie::TLB[256]; + +// -------------------------------------------------------------------------------- +// Public methods + + +string HandleTrie::TrieNode::to_string() { + string answer; + if (suffix_start == 0) { + answer = "''"; + } else { + int n = suffix.size() - suffix_start; + answer = suffix.substr(0, suffix_start) + "." + suffix.substr(suffix_start, n); + } + answer += " ["; + for (unsigned int i = 0; i < TRIE_ALPHABET_SIZE; i++) { + if (children[i] != NULL) { + answer += ("*"); + } + } + answer += "] "; + if (value != NULL) { + answer += value->to_string(); + } + return answer; +} + +HandleTrie::HandleTrie(unsigned int key_size) { + if (key_size == 0 || key_size > 255) { + Utils::error("Invalid key size: " + to_string(key_size)); + } + this->key_size = key_size; + if (! HandleTrie::TLB_INITIALIZED) { + HandleTrie::TLB_INIT(); + } + root = new TrieNode(); +} + +HandleTrie::~HandleTrie() { + delete root; +} + +HandleTrie::TrieValue *HandleTrie::insert(const string &key, TrieValue *value) { + + if (key.size() != key_size) { + Utils::error("Invalid key size: " + to_string(key.size()) + " != " + to_string(key_size)); + } + + TrieNode *tree_cursor = root; + TrieNode *parent = root; + TrieNode *child; + TrieNode *split; + unsigned char key_cursor = 0; + tree_cursor->trie_node_mutex.lock(); + while (true) { + unsigned char c = TLB[(unsigned char) key[key_cursor]]; + if (tree_cursor->children[c] == NULL) { + if (tree_cursor->suffix_start > 0) { + unsigned char c_key_pred = TLB[(unsigned char) key[key_cursor -1]]; + if (key[key_cursor] == tree_cursor->suffix[key_cursor]) { + child = new TrieNode(); + child->trie_node_mutex.lock(); + child->children[c] = tree_cursor; + tree_cursor->suffix_start++; + parent->children[c_key_pred] = child; + parent->trie_node_mutex.unlock(); + parent = child; + key_cursor++; + } else { + child = new TrieNode(); + child->suffix = key; + child->suffix_start = key_cursor + 1; + child->value = value; + unsigned char c_tree_cursor = TLB[(unsigned char) tree_cursor->suffix[tree_cursor->suffix_start]]; + tree_cursor->suffix_start++; + split = new TrieNode(); + split->children[c] = child; + split->children[c_tree_cursor] = tree_cursor; + parent->children[c_key_pred] = split; + parent->trie_node_mutex.unlock(); + if (tree_cursor != parent) { + tree_cursor->trie_node_mutex.unlock(); + } + return child->value; + } + } else { + child = new TrieNode(); + child->suffix = key; + child->suffix_start = key_cursor + 1; + child->value = value; + tree_cursor->children[c] = child; + parent->trie_node_mutex.unlock(); + if (tree_cursor != parent) { + tree_cursor->trie_node_mutex.unlock(); + } + return child->value; + } + } else { + if (tree_cursor != parent) { + parent->trie_node_mutex.unlock(); + parent = tree_cursor; + } + tree_cursor = tree_cursor->children[c]; + tree_cursor->trie_node_mutex.lock(); + if (tree_cursor->suffix_start > 0) { + bool match = true; + unsigned int n = key.size(); + for (unsigned int i = key_cursor; i < n; i++) { + if (key[i] != tree_cursor->suffix[i]) { + match = false; + break; + } + } + if (match) { + tree_cursor->value->merge(value); + delete value; + if (tree_cursor != parent) { + parent->trie_node_mutex.unlock(); + } + tree_cursor->trie_node_mutex.unlock(); + return tree_cursor->value; + } + } + key_cursor++; + } + } +} + +HandleTrie::TrieValue *HandleTrie::lookup(const string &key) { + + if (key.size() != key_size) { + Utils::error("Invalid key size: " + to_string(key.size()) + " != " + to_string(key_size)); + } + + TrieNode *tree_cursor = root; + TrieValue *value; + unsigned char key_cursor = 0; + tree_cursor->trie_node_mutex.lock(); + while (tree_cursor != NULL) { + if (tree_cursor->suffix_start > 0) { + bool match = true; + unsigned int n = key.size(); + for (unsigned int i = key_cursor; i < n; i++) { + if (key[i] != tree_cursor->suffix[i]) { + match = false; + break; + } + } + if (match) { + value = tree_cursor->value; + } else { + value = NULL; + } + tree_cursor->trie_node_mutex.unlock(); + return value; + } else { + unsigned char c = TLB[(unsigned char) key[key_cursor]]; + TrieNode *child = tree_cursor->children[c]; + tree_cursor->trie_node_mutex.unlock(); + tree_cursor = child; + key_cursor++; + if (tree_cursor != NULL) { + tree_cursor->trie_node_mutex.lock(); + } + } + } + return NULL; +} + +void HandleTrie::traverse(bool keep_root_locked, bool (*visit_function)(TrieNode *node, void *data), void *data) { + + stack node_stack; + TrieNode *cursor; + node_stack.push(root); + + while (! node_stack.empty()) { + cursor = node_stack.top(); + node_stack.pop(); + cursor->trie_node_mutex.lock(); + if (cursor->suffix_start > 0) { + if (visit_function(cursor, data)) { + if (keep_root_locked && (root != cursor)) { + root->trie_node_mutex.unlock(); + } + cursor->trie_node_mutex.unlock(); + return; + } + } else { + for (unsigned int i = TRIE_ALPHABET_SIZE - 1; ; i--) { + if (cursor->children[i] != NULL) { + node_stack.push(cursor->children[i]); + } + if (i == 0) { + break; + } + } + } + if ((! keep_root_locked) || (cursor != root)) { + cursor->trie_node_mutex.unlock(); + } + } + if (keep_root_locked) { + root->trie_node_mutex.unlock(); + } +} diff --git a/src/cpp/attention_broker/HandleTrie.h b/src/cpp/attention_broker/HandleTrie.h new file mode 100644 index 0000000..34e2c97 --- /dev/null +++ b/src/cpp/attention_broker/HandleTrie.h @@ -0,0 +1,119 @@ +#ifndef _ATTENTION_BROKER_SERVER_HANDLETRIE_H +#define _ATTENTION_BROKER_SERVER_HANDLETRIE_H + +#include +#include + +#define TRIE_ALPHABET_SIZE ((unsigned int) 16) + +using namespace std; + +namespace attention_broker_server { + +/** + * Data abstraction implementing a map handle->value using a trie (prefix tree). + * + * This data structure is basicaly like a hashmap mapping from handles to objects of + * type HandleTrie::TrieValue. + * + * When a (key, value) pair is inserted and key is already present, the method merge() + * in value is called passing the newly inserted value. + */ +class HandleTrie { + +public: + + /** + * Virtual basic class to be extended by objects used as "values". + */ + class TrieValue { + protected: + TrieValue(); /// basic empty constructor. + public: + virtual ~TrieValue(); /// Destructor. + virtual void merge(TrieValue *other) = 0; /// Called when a repeated handle is inserted. + virtual string to_string(); /// Returns a string representation of the value object. + + }; + + /** + * A node in the prefix tree used to store keys. + */ + class TrieNode { + public: + TrieNode(); /// Basic empty constructor. + ~TrieNode(); /// Destructor. + + TrieNode **children; /// Array with children of this node. + TrieValue *value; /// Value attached to this node or NULL if none. + string suffix; /// The key (handle) attached to this node (leafs) or NULL if none (internal nodes). + unsigned char suffix_start; /// The point in the suffix from which this node (leaf) differs from its siblings. + mutex trie_node_mutex; + + string to_string(); /// Returns a string representation of this node. + }; + + HandleTrie(unsigned int key_size); /// Basic constructor. + ~HandleTrie(); /// Destructor. + + /** + * Insert a new key in this HandleTrie or merge its value if the key is already present. + * + * @param key Handle being inserted. + * @param value HandleTrie::TrieValue object being inserted. + * + * @return The resulting HandleTrie::TrieValue object after insertion (and eventually the merge) is processed. + */ + TrieValue *insert(const string &key, TrieValue *value); + + /** + * Lookup for a given handle. + * + * @param key Handle being searched. + * + * @return The HandleTrie::TrieValue object attached to the passed key or NULL if none. + */ + TrieValue *lookup(const string &key); + + /** + * Traverse all keys (in-order) calling the passed visit_function once per stored value. + * + * @param keep_root_locked Keep root HandleTrie::TrieNode locked during the whole traversing + * releasing the lock when it ends. + * @param visit_function Function to be called when each value is visited. + * @param data Additional information passed to visit_function or NULL if none. + */ + void traverse(bool keep_root_locked, bool (*visit_function)(TrieNode *node, void *data), void *data); + + TrieNode *root; + +private: + + static unsigned char TLB[256]; + static bool TLB_INITIALIZED; + static void TLB_INIT() { + TLB[(unsigned char) '0'] = 0; + TLB[(unsigned char) '1'] = 1; + TLB[(unsigned char) '2'] = 2; + TLB[(unsigned char) '3'] = 3; + TLB[(unsigned char) '4'] = 4; + TLB[(unsigned char) '5'] = 5; + TLB[(unsigned char) '6'] = 6; + TLB[(unsigned char) '7'] = 7; + TLB[(unsigned char) '8'] = 8; + TLB[(unsigned char) '9'] = 9; + TLB[(unsigned char) 'a'] = TLB[(unsigned char) 'A'] = 10; + TLB[(unsigned char) 'b'] = TLB[(unsigned char) 'B'] = 11; + TLB[(unsigned char) 'c'] = TLB[(unsigned char) 'C'] = 12; + TLB[(unsigned char) 'd'] = TLB[(unsigned char) 'D'] = 13; + TLB[(unsigned char) 'e'] = TLB[(unsigned char) 'E'] = 14; + TLB[(unsigned char) 'f'] = TLB[(unsigned char) 'F'] = 15; + TLB_INITIALIZED = true; + } + + unsigned int key_size; +}; + +} // namespace attention_broker_server + +#endif // _ATTENTION_BROKER_SERVER_HANDLETRIE_H diff --git a/src/cpp/attention_broker/HebbianNetwork.cc b/src/cpp/attention_broker/HebbianNetwork.cc new file mode 100644 index 0000000..62042f1 --- /dev/null +++ b/src/cpp/attention_broker/HebbianNetwork.cc @@ -0,0 +1,120 @@ +#include +#include +#include "HebbianNetwork.h" +#include "Utils.h" +#include "expression_hasher.h" + +using namespace attention_broker_server; + +// -------------------------------------------------------------------------------- +// Public methods + +HebbianNetwork::HebbianNetwork() { + nodes = new HandleTrie(HANDLE_HASH_SIZE - 1); + largest_arity = 0; + tokens_mutex.lock(); + tokens_to_distribute = 1.0; + tokens_mutex.unlock(); +} + +HebbianNetwork::~HebbianNetwork() { + delete nodes; +} + +string HebbianNetwork::Node::to_string() { + return "(" + std::to_string(count) + ", " + std::to_string(importance) + ", " + std::to_string(arity) + ")"; +} + +string HebbianNetwork::Edge::to_string() { + return "(" + std::to_string(count) + ")"; +} + +HebbianNetwork::Node *HebbianNetwork::add_node(string handle) { + return (Node *) nodes->insert(handle, new HebbianNetwork::Node()); +} + +HebbianNetwork::Edge *HebbianNetwork::add_asymmetric_edge(string handle1, string handle2, Node *node1, Node *node2) { + if (node1 == NULL) { + node1 = (Node *) nodes->lookup(handle1); + } + Edge *edge = (Edge *) node1->neighbors->insert(handle2, new HebbianNetwork::Edge()); + if (edge->count == 1) { + // First time this edge is added + edge->node1 = node1; + edge->node2 = node2; + node1->arity += 1; + largest_arity_mutex.lock(); + if (node1->arity > largest_arity) { + largest_arity = node1->arity; + } + largest_arity_mutex.unlock(); + } + return edge; +} + +void HebbianNetwork::add_symmetric_edge(string handle1, string handle2, Node *node1, Node *node2) { + add_asymmetric_edge(handle1, handle2, node1, node2); + add_asymmetric_edge(handle2, handle1, node2, node1); +} + +HebbianNetwork::Node *HebbianNetwork::lookup_node(string handle) { + return (Node *) nodes->lookup(handle); +} + +unsigned int HebbianNetwork::get_node_count(string handle) { + Node *node = (Node *) nodes->lookup(handle); + if (node == NULL) { + return 0; + } else { + return node->count; + } +} + +ImportanceType HebbianNetwork::get_node_importance(string handle) { + Node *node = (Node *) nodes->lookup(handle); + if (node == NULL) { + return 0; + } else { + return node->importance; + } +} + +unsigned int HebbianNetwork::get_asymmetric_edge_count(string handle1, string handle2) { + Node *source = (Node *) nodes->lookup(handle1); + if (source != NULL) { + Edge *edge = (Edge *) source->neighbors->lookup(handle2); + if (edge != NULL) { + return edge->count; + } + } + return 0; +} + +ImportanceType HebbianNetwork::alienate_tokens() { + ImportanceType answer; + tokens_mutex.lock(); + answer = tokens_to_distribute; + tokens_to_distribute = 0.0; + tokens_mutex.unlock(); + return answer; +} + +void HebbianNetwork::visit_nodes( + bool keep_root_locked, + bool (*visit_function)(HandleTrie::TrieNode *node, void *data), + void *data) { + + nodes->traverse(keep_root_locked, visit_function, data); +} + +static inline void release_locks( + HandleTrie::TrieNode *root, + HandleTrie::TrieNode *cursor, + bool keep_root_locked, + bool release_root_after_end) { + + if (keep_root_locked && release_root_after_end && (root != cursor)) { + root->trie_node_mutex.unlock(); + } + cursor->trie_node_mutex.unlock(); +} diff --git a/src/cpp/attention_broker/HebbianNetwork.h b/src/cpp/attention_broker/HebbianNetwork.h new file mode 100644 index 0000000..37f8149 --- /dev/null +++ b/src/cpp/attention_broker/HebbianNetwork.h @@ -0,0 +1,188 @@ +#ifndef _ATTENTION_BROKER_SERVER_HEBBIANNETWORK_H +#define _ATTENTION_BROKER_SERVER_HEBBIANNETWORK_H + +#include +#include +#include +#include "HandleTrie.h" +#include "expression_hasher.h" + +using namespace std; + +namespace attention_broker_server { + +typedef double ImportanceType; + +/** + * Data abstraction of an asymmetric Hebbian Network with only direct hebbian links. + * + * A Hebbian Network, in the context of AttentionBrokerServer, is a directed graph with + * weights in the edges A->B representing the probability of B being present in a DAS + * query answer given that A is present. In other words, provided that the atom A is one of + * the atoms returned in a given query in DAS, the weight of the edge A->B in + * a HebbianNetwork is an estimate of the probability of B being present in the same answer. + * + * HebbianNetwork is asymmetric, meaning that the weight of A->B can be different from the + * weight of B->A. + * + * All edges in HebbianNetwork are "direct" hebbian links meaning that they estimate the + * probability of B BEING present in a answer given that A is. There are no "reverse" + * hebbian links which would mean the probability of B being NOT present in an answer + * given that A is. + * + * HebbianNetwork keeps nodes in a HandleTrie which is basically a data structure to map + * from handle to a value object (mostly like a hashmap but slightly more efficient because + * it makes the assumption that the key is a handle). So in the stored value object it stores + * another HandleTrie to keep track of the neighbors. So we have two types of value objects, + * one to represent Nodes and another one to represent Edges. + */ +class HebbianNetwork { + +public: + + HebbianNetwork(); /// Basic constructor. + ~HebbianNetwork(); /// Destructor. + + unsigned int largest_arity; /// Largest arity among nodes in this network. + mutex largest_arity_mutex; + + // Node and Link don't inherit from a common "Atom" class to avoid having virtual methods, + // which couldn't be properly inlined. + + /** + * Node object used as the value in HandleTrie. + */ + class Node: public HandleTrie::TrieValue { + public: + unsigned int arity; /// Number of neighbors of this Node. + unsigned int count; /// Count for this Node. + ImportanceType importance; /// Importance of this Node. + ImportanceType stimuli_to_spread; /// Amount of importance this node will spread in the next + /// stimuli spreading cycle. + HandleTrie *neighbors; // Neighbors of this Node. + Node() { + arity = 0; + count = 1; + importance = 0.0; + stimuli_to_spread = 0.0; + neighbors = new HandleTrie(HANDLE_HASH_SIZE - 1); + } + inline void merge(HandleTrie::TrieValue *other) { + count += ((Node *) other)->count; + importance += ((Node *) other)->importance; + } + string to_string(); /// String representation of this Node. + }; + + /** + * Edge object used as the value in HandleTrie. + */ + class Edge: public HandleTrie::TrieValue { + public: + unsigned int count; /// Count for this edge. + Node *node1; /// Source Node. + Node *node2; /// Target node. + Edge() { + count = 1; + node1 = node2 = NULL; + } + inline void merge(HandleTrie::TrieValue *other) { + count += ((Edge *) other)->count; + } + string to_string(); /// String representation of this Edge. + }; + + /** + * Adds a new node to this network or increase +1 to its count if it already exists. + * + * @param handle Atom being added. + * + * @return the value object attached to the node being inserted. + */ + Node *add_node(string handle); + + /** + * Adds a new edge handle1->handle2 to this network or increase +1 to its count if it already exists. + * + * @param handle1 Source of the edge. + * @param handle2 Target of the edge. + * + * @return the value object attached to the edge being inserted. + */ + Edge *add_asymmetric_edge(string handle1, string handle2, Node *node1, Node *node2); + + /** + * Adds new edges handle1->handle2 and handle2->handle1 to this network or increase +1 + * to their count if they already exist. + * + * @param handle1 One of the nodes in the edge. + * @param handle2 The other node in the edge. + */ + void add_symmetric_edge(string handle1, string handle2, Node *node1, Node *node2); + + /** + * Lookup and return the value attached to the passed handle. + * + * @param handle Handle of the node being searched. + * + * @return The value object attached to the node. + */ + Node *lookup_node(string handle); + + /** + * Lookup for the passed node and return its count. + * + * @param handle Handle of the node being searched. + * + * @return The count of the passed node. + */ + unsigned int get_node_count(string handle); + + /** + * Lookup for the passed node and return its importance. + * + * @param handle Handle of the node being searched. + * + * @return The importance of the passed node. + */ + ImportanceType get_node_importance(string handle); + + /** + * Lookup for the passed edge and return its count. + * + * @param source Source of the edge being searched. + * @param target Target of the edge being searched. + * + * @returnThe count of the passed edge. + */ + unsigned int get_asymmetric_edge_count(string handle1, string handle2); + + /** + * Traverse the node's HandleTrie and call the passed function once for each + * of the visited nodes. + * + * @param keep_root_locked True iff HandleTrie root should be kept locked during + * all the traversal. If false, the root lock is freed just like any other internal + * trie node. + * @param visit_function Function to be called passing each visited node. This function + * @param data Additional data to be passed to the visit_function. + * + * expects the Node and a pointer to eventual data used inside visit_function. + */ + void visit_nodes( + bool keep_root_locked, + bool (*visit_function)(HandleTrie::TrieNode *node, void *data), + void *data); + + ImportanceType alienate_tokens(); + +private: + + HandleTrie *nodes; + ImportanceType tokens_to_distribute; + mutex tokens_mutex; +}; + +} // namespace attention_broker_server + +#endif // _ATTENTION_BROKER_SERVER_HEBBIANNETWORK_H diff --git a/src/cpp/attention_broker/HebbianNetworkUpdater.cc b/src/cpp/attention_broker/HebbianNetworkUpdater.cc new file mode 100644 index 0000000..c9a954a --- /dev/null +++ b/src/cpp/attention_broker/HebbianNetworkUpdater.cc @@ -0,0 +1,55 @@ +#include +#include "HebbianNetworkUpdater.h" +#include "HebbianNetwork.h" +#include "Utils.h" + +using namespace attention_broker_server; +using namespace commons; + +HebbianNetworkUpdater::HebbianNetworkUpdater() { +} + +// -------------------------------------------------------------------------------- +// Public methods + +HebbianNetworkUpdater::~HebbianNetworkUpdater() { +} + +HebbianNetworkUpdater *HebbianNetworkUpdater::factory(HebbianNetworkUpdaterType instance_type) { + switch (instance_type) { + case HebbianNetworkUpdaterType:: EXACT_COUNT: { + return new ExactCountHebbianUpdater(); + } + default: { + Utils::error("Invalid HebbianNetworkUpdaterType: " + to_string((int) instance_type)); + return NULL; // to avoid warnings + } + } + +} + +ExactCountHebbianUpdater::ExactCountHebbianUpdater() { +} + +ExactCountHebbianUpdater::~ExactCountHebbianUpdater() { +} + +void ExactCountHebbianUpdater::correlation(const dasproto::HandleList *request) { + HebbianNetwork *network = (HebbianNetwork *) request->hebbian_network(); + if (network != NULL) { + for (const string &s: ((dasproto::HandleList *) request)->list()) { + network->add_node(s); + } + HebbianNetwork::Node *node1; + HebbianNetwork::Node *node2; + for (const string &s1: ((dasproto::HandleList *) request)->list()) { + node1 = network->lookup_node(s1); + for (const string &s2: ((dasproto::HandleList *) request)->list()) { + if (s1.compare(s2) < 0) { + node2 = network->lookup_node(s2); + network->add_symmetric_edge(s1, s2, node1, node2); + } + } + } + } +} diff --git a/src/cpp/attention_broker/HebbianNetworkUpdater.h b/src/cpp/attention_broker/HebbianNetworkUpdater.h new file mode 100644 index 0000000..3862052 --- /dev/null +++ b/src/cpp/attention_broker/HebbianNetworkUpdater.h @@ -0,0 +1,101 @@ +#ifndef _ATTENTION_BROKER_SERVER_HEBBIANNETWORKUPDATER_H +#define _ATTENTION_BROKER_SERVER_HEBBIANNETWORKUPDATER_H + +#include "attention_broker.grpc.pb.h" + +using namespace std; + +namespace attention_broker_server { + +/** + * Algorithm used to update HebbianNetwork weights in "correlate" requests. + */ +enum class HebbianNetworkUpdaterType { + EXACT_COUNT /// Tracks counts of nodes and links computing actual weights on demand. +}; + +/** + * Process correlation requests by changing the weights in the passed HebbianNetwork + * to reflect the evidence provided in the request. + * + * Objects of this class are used by worker threads to process "correlation" requests. + * + * The request have a list of handles of atoms which appeared together in the same query + * answer. So it's a positive evidence of correlation among such atoms. The HebbianNetwork + * weights are changed to reflect this evidence. + * + * This is an abstract class. Concrete subclasses implement different ways of computing + * weights in HebbianNetwork. + */ +class HebbianNetworkUpdater { + +public: + + /** + * Factory method. + * + * Factory method to instantiate concrete subclasses according to the passed parameter. + * + * @param instance_type Type of concrete subclass to be instantiated. + * + * @return An object of the passed type. + */ + static HebbianNetworkUpdater *factory(HebbianNetworkUpdaterType instance_type); + virtual ~HebbianNetworkUpdater(); /// Destructor. + + /** + * Process a correlation evidence. + * + * The evidence is used to update the weights in the HebbianNetwork. The actual way these + * weights are updated depends on the type of the concrete subclass that implements this method. + * + * @param request A list of handles of atoms which appeared in the same query answer. + */ + virtual void correlation(const dasproto::HandleList* request) = 0; + +protected: + + HebbianNetworkUpdater(); /// Basic empty constructor. + +private: + +}; + +/** + * Process correlation requests by changing the weights in the passed HebbianNetwork + * to reflect the evidence provided in the request. + * + * Objects of this class are used by worker threads to process "correlation" requests. + * + * The request have a list of handles of atoms which appeared together in the same query + * answer. So it's a positive evidence of correlation among such atoms. The HebbianNetwork + * weights are changed to reflect this evidence. + * + * This HebbianNetworkUpdater keeps track of actual counts of atoms and symmetric hebbian + * links between them as they appear in correlation evidence (requests). Actual hebbian weights + * between A -> B are calculated on demand by dividing count(A->B) / count(A). + */ +class ExactCountHebbianUpdater: public HebbianNetworkUpdater { + +public: + + ExactCountHebbianUpdater(); /// Basic empty constructor. + ~ExactCountHebbianUpdater(); /// Destructor. + + /** + * Process a correlation evidence. + * + * The evidence is used to update the weights in the HebbianNetwork. + * + * This HebbianNetworkUpdater keeps track of actual counts of atoms and symmetric hebbian + * links between them as they appear in correlation evidence (requests). Actual hebbian weights + * between A -> B are calculated on demand by dividing count(A->B) / count(A). + * + * @param request A list of handles of atoms which appeared in the same query answer. + */ + void correlation(const dasproto::HandleList *request); /// Process a correlation evidence. +}; + +} // namespace attention_broker_server + +#endif // _ATTENTION_BROKER_SERVER_HEBBIANNETWORKUPDATER_H diff --git a/src/cpp/attention_broker/RequestSelector.cc b/src/cpp/attention_broker/RequestSelector.cc new file mode 100644 index 0000000..e589de1 --- /dev/null +++ b/src/cpp/attention_broker/RequestSelector.cc @@ -0,0 +1,61 @@ +#include "RequestSelector.h" +#include "Utils.h" +#include + +using namespace attention_broker_server; + +// -------------------------------------------------------------------------------- +// Public methods + +RequestSelector::RequestSelector( + unsigned int thread_id, + SharedQueue *stimulus, + SharedQueue *correlation) { + + this->thread_id = thread_id; + this->stimulus = stimulus; + this->correlation = correlation; +} + +RequestSelector::~RequestSelector() { +} + +EvenThreadCount::EvenThreadCount( + unsigned int thread_id, + SharedQueue *stimulus, + SharedQueue *correlation) : RequestSelector(thread_id, stimulus, correlation) { + + even_thread_id = ((thread_id % 2) == 0); +} + +EvenThreadCount::~EvenThreadCount() { +} + +RequestSelector *RequestSelector::factory( + SelectorType instance_type, + unsigned int thread_id, + SharedQueue *stimulus, + SharedQueue *correlation) { + + switch (instance_type) { + case SelectorType::EVEN_THREAD_COUNT: { + return new EvenThreadCount(thread_id, stimulus, correlation); + } + default: { + Utils::error("Invalid selector type: " + to_string((int) instance_type)); + return NULL; // to avoid warnings + } + } +} + +pair EvenThreadCount::next() { + pair answer; + if (even_thread_id) { + answer.first = RequestType::STIMULUS; + answer.second = (void *) stimulus->dequeue(); + } else { + answer.first = RequestType::CORRELATION; + answer.second = (void *) correlation->dequeue(); + } + return answer; +} diff --git a/src/cpp/attention_broker/RequestSelector.h b/src/cpp/attention_broker/RequestSelector.h new file mode 100644 index 0000000..85aaa36 --- /dev/null +++ b/src/cpp/attention_broker/RequestSelector.h @@ -0,0 +1,93 @@ +#ifndef _ATTENTION_BROKER_SERVER_REQUESTSELECTOR_H +#define _ATTENTION_BROKER_SERVER_REQUESTSELECTOR_H + +#include "HebbianNetwork.h" +#include "SharedQueue.h" + +namespace attention_broker_server { +using namespace std; +using namespace commons; + +enum class SelectorType { + EVEN_THREAD_COUNT +}; + +enum class RequestType { + STIMULUS, + CORRELATION +}; + +/** + * Abstract class used in WorkerThreads to select the next request to be processed among + * the available request queues. + * + * Concrete subclasses may implement different selection algorithms based in different criteria. + */ +class RequestSelector { + +public: + + virtual ~RequestSelector(); /// Destructor. + + /** + * Factory method. + * + * Factory method to instantiate concrete subclasses according to the passed parameter. + * + * @param instance_type Type of concrete subclass to be instantiated. + * @param thread_id ID of the thread asking for a new request. + * @param stimulus Queue of "stimulate" requests. + * @param correlation Queue of "correlate" requests. + * + * @return An object of the passed type. + */ + static RequestSelector *factory( + SelectorType instance_type, + unsigned int thread_id, + SharedQueue *stimulus, + SharedQueue *correlation); + + /** + * Return the next request to be processed by the caller worker thread. + * + * @return the next request to be processed by the caller worker thread. + */ + virtual pair next() = 0; + +protected: + + RequestSelector(unsigned int thread_id, SharedQueue *stimulus, SharedQueue *correlation); /// Basic constructor. + + unsigned int thread_id; + SharedQueue *stimulus; + SharedQueue *correlation; +}; + +/** + * Concrete implementation of RequestSelector which evenly distribute worker threads among each type of request. + * + * This selector keeps half of the working threads working only in "correlate" requests and the other + * half working only in "stimulate" requests. + */ +class EvenThreadCount : public RequestSelector { + +public: + + ~EvenThreadCount(); /// Destructor. + EvenThreadCount(unsigned int thread_id, SharedQueue *stimulus, SharedQueue *correlation); /// Basic constructor. + + /** + * Return the next request to be processed by the caller worker thread. + * + * @return the next request to be processed by the caller worker thread. + */ + pair next(); + +private: + + bool even_thread_id; +}; + +} // namespace attention_broker_server + +#endif // _ATTENTION_BROKER_SERVER_REQUESTSELECTOR_H diff --git a/src/cpp/attention_broker/StimulusSpreader.cc b/src/cpp/attention_broker/StimulusSpreader.cc new file mode 100644 index 0000000..d269bc8 --- /dev/null +++ b/src/cpp/attention_broker/StimulusSpreader.cc @@ -0,0 +1,150 @@ +#include "expression_hasher.h" +#include "StimulusSpreader.h" +#include "HebbianNetwork.h" +#include "AttentionBrokerServer.h" +#include "Utils.h" +#include +#include + +using namespace attention_broker_server; + +// -------------------------------------------------------------------------------- +// Public constructors and destructors + +StimulusSpreader::~StimulusSpreader() { +} + +StimulusSpreader::StimulusSpreader() { +} + +StimulusSpreader *StimulusSpreader::factory(StimulusSpreaderType instance_type) { + switch (instance_type) { + case StimulusSpreaderType::TOKEN : { + return new TokenSpreader(); + } + default: { + Utils::error("Invalid StimulusSpreaderType: " + to_string((int) instance_type)); + return NULL; // to avoid warnings + } + } + +} + +TokenSpreader:: TokenSpreader() { +} + +TokenSpreader:: ~TokenSpreader() { +} + +// ------------------------------------------------ +// "visit" functions used to traverse network + +typedef TokenSpreader::StimuliData DATA; + +static bool collect_rent(HandleTrie::TrieNode *node, void *data) { + ImportanceType rent = ((DATA *) data)->rent_rate * \ + ((HebbianNetwork::Node *) node->value)->importance; + ((DATA *) data)->total_rent += rent; + ImportanceType wages = 0.0; + ((DATA *) data)->importance_changes->insert( + node->suffix, + new TokenSpreader::ImportanceChanges(rent, wages)); + return false; +} + +static bool consolidate_rent_and_wages(HandleTrie::TrieNode *node, void *data) { + + HebbianNetwork::Node *value = (HebbianNetwork::Node *) node->value; + + TokenSpreader::ImportanceChanges *changes =\ + (TokenSpreader::ImportanceChanges *) ((DATA *) data)->importance_changes->lookup(node->suffix); + value->importance -= changes->rent; + value->importance += changes->wages; + + // Compute amount to be spread + ImportanceType arity_ratio = (double) value->arity / ((DATA *) data)->largest_arity; + ImportanceType spreading_rate = ((DATA *) data)->spreading_rate_lowerbound + \ + (((DATA *) data)->spreading_rate_range_size * \ + arity_ratio); + ImportanceType to_spread = value->importance * spreading_rate; + value->importance -= to_spread; + value->stimuli_to_spread = to_spread; + return false; +} + +static bool sum_weights(HandleTrie::TrieNode *node, void *data) { + HebbianNetwork::Edge *edge = (HebbianNetwork::Edge *) node->value; + double w = (double) edge->count / edge->node1->count; + ((DATA *) data)->sum_weights += w; + return false; +} + +static bool deliver_stimulus(HandleTrie::TrieNode *node, void *data) { + HebbianNetwork::Edge *edge = (HebbianNetwork::Edge *) node->value; + double w = (double) edge->count / edge->node1->count; + ImportanceType stimulus = (w / ((DATA *) data)->sum_weights) * ((DATA *) data)->to_spread; + edge->node2->importance += stimulus; + return false; +} + +static bool consolidate_stimulus(HandleTrie::TrieNode *node, void *data) { + HebbianNetwork::Node *value = (HebbianNetwork::Node *) node->value; + ((DATA *) data)->to_spread = value->stimuli_to_spread; + ((DATA *) data)->sum_weights = 0.0; + value->neighbors->traverse(true, &sum_weights, data); + value->neighbors->traverse(true, &deliver_stimulus, data); + value->stimuli_to_spread = 0.0; + return false; +} + +// ------------------------------------------------ +// Public methods + +void TokenSpreader::distribute_wages( + const dasproto::HandleCount *handle_count, + ImportanceType &total_to_spread, + DATA *data) { + + auto iterator = handle_count->map().find("SUM"); + if (iterator == handle_count->map().end()) { + Utils::error("Missing 'SUM' key in HandleCount request"); + } + unsigned int total_wages = iterator->second; + for (auto pair: handle_count->map()) { + if (pair.first != "SUM") { + double normalized_amount = (((double) pair.second) * total_to_spread) / total_wages; + data->importance_changes->insert(pair.first, new TokenSpreader::ImportanceChanges(0.0, normalized_amount)); + } + } +} + +void TokenSpreader::spread_stimuli(const dasproto::HandleCount *request) { + + HebbianNetwork *network = (HebbianNetwork *) request->hebbian_network(); + if (network == NULL) { + return; + } + + DATA data; + data.importance_changes = new HandleTrie(HANDLE_HASH_SIZE - 1); + data.rent_rate = AttentionBrokerServer::RENT_RATE; + data.spreading_rate_lowerbound = AttentionBrokerServer::SPREADING_RATE_LOWERBOUND; + data.spreading_rate_range_size = \ + AttentionBrokerServer::SPREADING_RATE_UPPERBOUND - AttentionBrokerServer::SPREADING_RATE_LOWERBOUND; + data.largest_arity = network->largest_arity; + data.total_rent = 0.0; + + // Collect rent + network->visit_nodes(true, &collect_rent, (void *) &data); + + // Distribute wages + ImportanceType total_to_spread = network->alienate_tokens(); + total_to_spread += data.total_rent; + distribute_wages(request, total_to_spread, &data); + + // Consolidate changes + network->visit_nodes(true, &consolidate_rent_and_wages, (void *) &data); + + // Spread activation (1 cycle) + network->visit_nodes(true, &consolidate_stimulus, &data); +} diff --git a/src/cpp/attention_broker/StimulusSpreader.h b/src/cpp/attention_broker/StimulusSpreader.h new file mode 100644 index 0000000..5a9ae2e --- /dev/null +++ b/src/cpp/attention_broker/StimulusSpreader.h @@ -0,0 +1,143 @@ +#ifndef _ATTENTION_BROKER_SERVER_STIMULUSSPREADER_H +#define _ATTENTION_BROKER_SERVER_STIMULUSSPREADER_H + +#include "attention_broker.grpc.pb.h" +#include "Utils.h" +#include "HebbianNetwork.h" + +using namespace std; + +namespace attention_broker_server { + +/** + * Algorithm used to update HebbianNetwork weights in "stimulate" requests. + */ +enum class StimulusSpreaderType { + TOKEN /// Consider importance as a fixed amount of tokens distributed among atoms in the HebbianNetwork. +}; + +/** + * Process stimuli spreading requests by boosting importance of the atoms passed in the request + * and running one cycle of stimuli spreading in the Hebbian Network. + * + * Objects of this class are used by worker threads to process "stimulate" requests. + * + * The request have a list of pairs (handle, n) with the handles whose importance should be + * boosted and the relative magnitude of this boost (compared to the other handles in the same + * request). + * + * This is an abstract class. Concrete subclasses implement different ways of spreading stimuli + * in the HebbianNetwork. + * + */ +class StimulusSpreader { + +public: + + /** + * Factory method. + * + * Factory method to instantiate concrete subclasses according to the passed parameter. + * + * @param instance_type Type of concrete subclass to be instantiated. + * + * @return An object of the passed type. + */ + static StimulusSpreader *factory(StimulusSpreaderType instance_type); + virtual ~StimulusSpreader(); /// destructor. + + /** + * Stimulate atoms and run one cycle of stimuli spreading. + * + * Atoms in the passed list have their importance boosted according to the passed counts. Then + * one cycle of stimuli spreading is executed. + * + * @param request A list of handles to be boosted and respective counts which are used to determine + * the magnitude of such boost. The actual way importance is boosted and then spread among HebbianNetwork + * links are delegated to the concrete subclasses. + */ + virtual void spread_stimuli(const dasproto::HandleCount *request) = 0; + +protected: + + StimulusSpreader(); /// Basic empty constructor. + +private: + +}; + +/** + * Process stimuli spreading requests by boosting importance of the atoms passed in the request + * and running one cycle of stimuli spreading in the Hebbian Network. + * + * Objects of this class are used by worker threads to process "stimulate" requests. + * + * The request have a list of pairs (handle, n) with the handles whose importance should be + * boosted and the relative magnitude of this boost (compared to the other handles in the same + * request). + * + * This StimulusSpreader consider a fixed amount of tokens distributed among all atoms in the + * HebbianNetwork. Importance boosts and stimulus spreading are implemented in a way that this + * total amount of tokens remains fixed, unless explicitly requested by caller. + */ +class TokenSpreader: public StimulusSpreader { + +public: + + TokenSpreader(); /// Basic empty constructor. + ~TokenSpreader(); /// Destructor. + + // data structure used as parameter container in "visit" functions + // used in trie traversal + typedef struct { + ImportanceType rent_rate; + ImportanceType total_rent; + HandleTrie *importance_changes; + unsigned int largest_arity; + ImportanceType spreading_rate_lowerbound; + ImportanceType spreading_rate_range_size; + ImportanceType to_spread; + double sum_weights; + } StimuliData; + + // data structure used in a private trie during importance update calculations + class ImportanceChanges: public HandleTrie::TrieValue { + public: + ImportanceType rent; + ImportanceType wages; + ImportanceChanges(ImportanceType r, ImportanceType w) { + rent = r; + wages = w; + } + void merge(TrieValue *other) { + rent += ((ImportanceChanges *) other)->rent; + wages += ((ImportanceChanges *) other)->wages; + } + }; + + /** + * Stimulate atoms and run one cycle of stimuli spreading. + * + * Atoms in the passed list have their importance boosted according to the passed counts. Then + * one cycle of stimuli spreading is executed. + * + * Boosts and stimuli spreading are actually tokens which are collected from all the nodes in the + * HebbianNetwork (as a rent) and redistributed according to the passed * counts (as wages). Once + * rents and wages are consolidated in each node's importance, one cycle of stimuli spreading is run + * when a % of the importance tokens of each node being redistributed to amnongst its neighbors + * according to the weights of the links in the HebbianNetwork. + * + * @param request A list of handles to be boosted and respective counts which are used to determine + * the magnitude of such boost. + */ + void spread_stimuli(const dasproto::HandleCount *request); + + // Used only in "visit" functions during trie traversals. Such functions aren't methods so this method + // must be public. + void distribute_wages(const dasproto::HandleCount *handle_count, ImportanceType &total_to_spread, StimuliData *data); + +}; + +} // namespace attention_broker_server + +#endif // _ATTENTION_BROKER_SERVER_STIMULUSSPREADER_H diff --git a/src/cpp/attention_broker/WorkerThreads.cc b/src/cpp/attention_broker/WorkerThreads.cc new file mode 100644 index 0000000..90020eb --- /dev/null +++ b/src/cpp/attention_broker/WorkerThreads.cc @@ -0,0 +1,86 @@ +#include + +#include "AttentionBrokerServer.h" +#include "attention_broker.grpc.pb.h" +#include "Utils.h" +#include "RequestSelector.h" +#include "WorkerThreads.h" +#include "HebbianNetworkUpdater.h" +#include "StimulusSpreader.h" + +using namespace attention_broker_server; +using namespace std; + +// -------------------------------------------------------------------------------- +// Public methods + +WorkerThreads::WorkerThreads(SharedQueue *stimulus, SharedQueue *correlation) { + stimulus_requests = stimulus; + correlation_requests = correlation; + threads_count = AttentionBrokerServer::WORKER_THREADS_COUNT; + for (unsigned int i = 0; i < threads_count; i++) { + threads.push_back(new thread( + &WorkerThreads::worker_thread, + this, + i, + stimulus_requests, + correlation_requests)); + } +} + +WorkerThreads::~WorkerThreads() { +} + +void WorkerThreads::graceful_stop() { + stop_flag_mutex.lock(); + stop_flag = true; + stop_flag_mutex.unlock(); + for (thread *worker_thread: threads) { + worker_thread->join(); + } +} + +// -------------------------------------------------------------------------------- +// Private methods + +void WorkerThreads::worker_thread( + unsigned int thread_id, + SharedQueue *stimulus_requests, + SharedQueue *correlation_requests) { + + RequestSelector *selector = RequestSelector::factory( + SelectorType::EVEN_THREAD_COUNT, + thread_id, + stimulus_requests, + correlation_requests); + HebbianNetworkUpdater *updater = HebbianNetworkUpdater::factory(HebbianNetworkUpdaterType::EXACT_COUNT); + StimulusSpreader *stimulus_spreader = StimulusSpreader::factory(StimulusSpreaderType::TOKEN); + pair request; + bool stop = false; + while (! stop) { + request = selector->next(); + if (request.second != NULL) { + switch (request.first) { + case RequestType::STIMULUS: { + stimulus_spreader->spread_stimuli((dasproto::HandleCount *) request.second); + break; + } + case RequestType::CORRELATION: { + updater->correlation((dasproto::HandleList *) request.second); + break; + } + default: { + Utils::error("Invalid request type: " + to_string((int) request.first)); + } + } + } else { + this_thread::sleep_for(chrono::milliseconds(100)); + stop_flag_mutex.lock(); + if (stop_flag) { + stop = true; + } + stop_flag_mutex.unlock(); + } + } + delete selector; +} diff --git a/src/cpp/attention_broker/WorkerThreads.h b/src/cpp/attention_broker/WorkerThreads.h new file mode 100644 index 0000000..38c429f --- /dev/null +++ b/src/cpp/attention_broker/WorkerThreads.h @@ -0,0 +1,66 @@ +#ifndef _ATTENTION_BROKER_SERVER_WORKERTHREADS_H +#define _ATTENTION_BROKER_SERVER_WORKERTHREADS_H + +#include +#include +#include +#include +#include + +#include "SharedQueue.h" + +using namespace std; +using namespace commons; + +namespace attention_broker_server { + +/** + * Used in AttentionBrokerServer to keep track of worker threads. + * + * WorkerThreads provides an abstraction to actual threads creation and shutdown. + */ +class WorkerThreads { + +public: + + /** + * Constructor. + * + * Start n worker threads (n is a parameter defined in AttentionBrokerServer) and + * keep then running getting requests from the queues which have been passed as + * parameters. + * + * Working threads can process any type of requests. The policy of which request + * queue a worker thread will read from next is determined by RequestSelector. + */ + WorkerThreads(SharedQueue *stimulus, SharedQueue *correlation); + ~WorkerThreads(); /// Destructor. + + /** + * Gracefully and synchronously stop all threads. + * + * Sets a flag which is check by each thread when the requests queue are empty. It means + * that both requests queues will be processed before the threads actually stop. When + * both requests queues are empty, threads return and are destroyed. This method will wait + * for all threads to finish before returning. + */ + void graceful_stop(); + +private: + + unsigned int threads_count; + vector threads; + bool stop_flag = false; + SharedQueue *stimulus_requests; + SharedQueue *correlation_requests; + mutex stop_flag_mutex; + + void worker_thread( + unsigned int thread_id, + SharedQueue *stimulus_requests, + SharedQueue *correlation_requests); +}; + +} // namespace attention_broker_server + +#endif // _ATTENTION_BROKER_SERVER_WORKERTHREADS_H diff --git a/src/cpp/hasher/BUILD b/src/cpp/hasher/BUILD new file mode 100644 index 0000000..e51113e --- /dev/null +++ b/src/cpp/hasher/BUILD @@ -0,0 +1,14 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "hasher_lib", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + includes = ["."], + copts = [ + "-I/opt/3rd-party/mbedcrypto/include/" + ], + deps = [ + #"@mbedcrypto//:lib", + ] +) diff --git a/src/cpp/hasher/expression_hasher.cc b/src/cpp/hasher/expression_hasher.cc new file mode 100644 index 0000000..782d0c3 --- /dev/null +++ b/src/cpp/hasher/expression_hasher.cc @@ -0,0 +1,78 @@ +#include +#include +#include +#include "mbedtls/md5.h" +#include "expression_hasher.h" + +static unsigned char MD5_BUFFER[16]; +static char HASH[HANDLE_HASH_SIZE]; +static char HASHABLE_STRING[MAX_HASHABLE_STRING_SIZE]; + +char *compute_hash(char *input) { + mbedtls_md5_context context; + mbedtls_md5_init(&context); + mbedtls_md5_starts_ret(&context); + mbedtls_md5_update_ret(&context, (const unsigned char*) input, strlen(input)); + mbedtls_md5_finish_ret(&context, MD5_BUFFER); + mbedtls_md5_free(&context); + for (unsigned int i = 0; i < 16; i++) { + sprintf((char *) ((unsigned long) HASH + 2 * i), "%02x", MD5_BUFFER[i]); + } + HASH[32] = '\0'; + return strdup(HASH); +} + +char *named_type_hash(char *name) { + return compute_hash(name); +} + +char *terminal_hash(char *type, char *name) { + if (strlen(type) + strlen(name) >= MAX_HASHABLE_STRING_SIZE) { + fprintf(stderr, "Invalid (too large) terminal name"); + exit(1); + } + sprintf(HASHABLE_STRING, "%s%c%s", type, JOINING_CHAR, name); + return compute_hash(HASHABLE_STRING); +} + +char *composite_hash(char **elements, unsigned int nelements) { + + unsigned int total_size = 0; + unsigned int element_size[nelements]; + + for (unsigned int i = 0; i < nelements; i++) { + unsigned int size = strlen(elements[i]); + if (size > MAX_LITERAL_OR_SYMBOL_SIZE) { + fprintf(stderr, "Invalid (too large) composite elements"); + exit(1); + } + element_size[i] = size; + total_size += size; + } + if (total_size >= MAX_HASHABLE_STRING_SIZE) { + fprintf(stderr, "Invalid (too large) composite elements"); + exit(1); + } + + unsigned long cursor = 0; + for (unsigned int i = 0; i < nelements; i++) { + if (i == (nelements - 1)) { + strcpy((char *) (HASHABLE_STRING + cursor), elements[i]); + } else { + sprintf((char *) (HASHABLE_STRING + cursor), "%s%c", elements[i], JOINING_CHAR); + cursor += 1; + } + cursor += element_size[i]; + } + + return compute_hash(HASHABLE_STRING); +} + +char *expression_hash(char *type_hash, char **elements, unsigned int nelements) { + char *composite[nelements + 1]; + composite[0] = type_hash; + for (unsigned int i = 0; i < nelements; i++) { + composite[i + 1] = elements[i]; + } + return composite_hash(composite, nelements + 1); +} diff --git a/src/cpp/hasher/expression_hasher.h b/src/cpp/hasher/expression_hasher.h new file mode 100644 index 0000000..78d309f --- /dev/null +++ b/src/cpp/hasher/expression_hasher.h @@ -0,0 +1,15 @@ +#ifndef EXPRESSIONHASHER_H +#define EXPRESSIONHASHER_H + +#define JOINING_CHAR ((char ) ' ') +#define MAX_LITERAL_OR_SYMBOL_SIZE ((size_t) 10000) +#define MAX_HASHABLE_STRING_SIZE ((size_t) 100000) +#define HANDLE_HASH_SIZE ((unsigned int) 33) + +char *compute_hash(char *input); +char *named_type_hash(char *name); +char *terminal_hash(char *type, char *name); +char *expression_hash(char *type_hash, char **elements, unsigned int nelements); +char *composite_hash(char **elements, unsigned int nelements); + +#endif diff --git a/src/cpp/main/BUILD b/src/cpp/main/BUILD new file mode 100644 index 0000000..20ba8e0 --- /dev/null +++ b/src/cpp/main/BUILD @@ -0,0 +1,49 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "attention_broker_main_lib", + srcs = ["attention_broker_main.cc"], + hdrs = glob(["*.h"]), + deps = [ + "//cpp/attention_broker:attention_broker_server_lib", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + ], +) + +cc_library( + name = "query_engine_main_lib", + srcs = ["query_engine_main.cc"], + hdrs = glob(["*.h"]), + deps = [ + "//cpp/query_engine:query_engine_lib", + ], +) + +cc_library( + name = "query_client_main_lib", + srcs = ["query_client_main.cc"], + hdrs = glob(["*.h"]), + deps = [ + "//cpp/query_engine:query_engine_lib", + ], +) + +cc_library( + name = "link_creation_engine_main_lib", + srcs = ["link_creation_engine_main.cc"], + hdrs = glob(["*.h"]), + deps = [ + "//cpp/query_engine:query_engine_lib", + ], +) + +cc_library( + name = "word_query_main_lib", + srcs = ["word_query_main.cc"], + hdrs = glob(["*.h"]), + deps = [ + "//cpp/query_engine:query_engine_lib", + ], +) diff --git a/src/cpp/main/attention_broker_main.cc b/src/cpp/main/attention_broker_main.cc new file mode 100644 index 0000000..a85bd84 --- /dev/null +++ b/src/cpp/main/attention_broker_main.cc @@ -0,0 +1,50 @@ +#include +#include + +#include +#include +#include + +#include + +#include "common.pb.h" +#include "attention_broker.grpc.pb.h" +#include "attention_broker.pb.h" + +#include "AttentionBrokerServer.h" + +//attention_broker_server::AttentionBrokerServer service; + +/* +void ctrl_c_handler(int) { + std::cout << "Stopping AttentionBrokerServer..." << std::endl; + service.graceful_shutdown(); + std::cout << "Done." << std::endl; + exit(0); +} +*/ + +void run_server(unsigned int port) { + attention_broker_server::AttentionBrokerServer service; + std::string server_address = "localhost:" + to_string(port); + //grpc::EnableDefaultHealthCheckService(true); + //grpc::reflection::InitProtoReflectionServerBuilderPlugin(); + ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "AttentionBroker server listening on " << server_address << std::endl; + server->Wait(); +} + +int main(int argc, char* argv[]) { + if (argc != 2) { + cerr << "Attention broker" << endl; + cerr << "Usage: " << argv[0] << " PORT" << endl; + exit(1); + } + unsigned int port = stoi(argv[1]); + //signal(SIGINT, &ctrl_c_handler); + run_server(port); + return 0; +} diff --git a/src/cpp/main/link_creation_engine_main.cc b/src/cpp/main/link_creation_engine_main.cc new file mode 100644 index 0000000..5638695 --- /dev/null +++ b/src/cpp/main/link_creation_engine_main.cc @@ -0,0 +1,335 @@ +#include +#include +#include + +#include + +#include "DASNode.h" +#include "RemoteIterator.h" +#include "QueryAnswer.h" +#include "AtomDBSingleton.h" +#include "AtomDB.h" +#include "Utils.h" + +#define MAX_QUERY_ANSWERS ((unsigned int) 1000) + +using namespace std; + +void ctrl_c_handler(int) { + std::cout << "Stopping link creation engine server..." << std::endl; + std::cout << "Done." << std::endl; + exit(0); +} + +std::vector split(string s, string delimiter) { + std::vector tokens; + size_t pos = 0; + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + tokens.push_back(token); + s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + + return tokens; +} + +double compute_sim1(const vector &tokens1, const vector &tokens2) { + + unsigned int count = 0; + + /* + for (unsigned int i = 0; i < tokens1.size(); i++) { + for (unsigned int j = 0; j < tokens2.size(); j++) { + count++; + break; + } + } + + for (unsigned int i = 0; i < tokens2.size(); i++) { + for (unsigned int j = 0; j < tokens2.size(); j++) { + count++; + break; + } + } + */ + + for (auto token1: tokens1) { + for (auto token2: tokens2) { + if (token1 == token2) { + count++; + break; + } + } + } + + for (auto token2: tokens2) { + for (auto token1: tokens1) { + if (token2 == token1) { + count++; + break; + } + } + } + + return ((1.0) * count) / (tokens1.size() + tokens2.size()); +} + +double compute_sim2(const vector &tokens1, const vector &tokens2) { + + if (tokens1.size() != tokens2.size()) { + return 0.0; + } + unsigned int count = 0; + unsigned int total_length = 0; + for (unsigned int i = 0; i < tokens1.size(); i++) { + for (unsigned int j = 0; j < tokens1[i].length(); j++) { + if (tokens1[i][j] == tokens2[i][j]) { + count++; + } + } + total_length += tokens1[i].length(); + } + return (1.0 * count) / total_length; +} + +string highlight(const vector &tokens1, const vector &tokens2, const set &highlighted) { + //printf("\033[31;1;4mHello\033[0m"); + string answer = ""; + bool token_flag, char_flag, word_flag; + for (unsigned int i = 0; i < tokens1.size(); i++) { + token_flag = (highlighted.find(tokens1[i]) != highlighted.end()); + word_flag = false; + if (highlighted.size() == 0) { + for (auto token: tokens2) { + if (tokens1[i] == token) { + word_flag = true; + break; + } + } + } + for (unsigned int j = 0; j < tokens1[i].length(); j++) { + if (tokens1.size() == tokens2.size()) { + char_flag = (tokens1[i][j] == tokens2[i][j]); + } else { + char_flag = false; + } + char_flag = false; // XXXXX + if (token_flag || char_flag || word_flag) { + answer += "\033["; + if (token_flag) { + answer += "1;4"; + if (char_flag || word_flag) { + answer += ";"; + } + } + if (char_flag) { + answer += "4"; + if (word_flag) { + answer += ";"; + } + } + if (word_flag) { + answer += "7"; + } + answer += "m"; + answer += tokens1[i][j]; + answer += "\033[0m"; + } else { + answer += tokens1[i][j]; + } + } + if (i != (tokens1.size() - 1)) { + answer += " "; + } + } + return answer; +} + +void build_link(const string &link_type_tag, const string str1, const string str2, double threshold, stack &output, const set &highlighted) { + + string sentence1 = str1.substr(1, str1.size() - 2); + string sentence2 = str2.substr(1, str2.size() - 2); + + vector tokens1 = split(sentence1, " "); + vector tokens2 = split(sentence2, " "); + + double v1 = compute_sim1(tokens1, tokens2); + //double v2 = compute_sim2(tokens1, tokens2); + + if (v1 >= threshold) { + //output.push(std::to_string(v1) + ": " + highlight(tokens1, tokens2, highlighted)); + //output.push(std::to_string(v2) + ": " + highlight(tokens2, tokens1, highlighted)); + output.push(highlight(tokens1, tokens2, highlighted)); + output.push(highlight(tokens2, tokens1, highlighted)); + } +} + +string handle_to_atom(const char *handle) { + + shared_ptr db = AtomDBSingleton::get_instance(); + shared_ptr document = db->get_atom_document(handle); + shared_ptr targets = db->query_for_targets((char *) handle); + string answer; + + if (targets != NULL) { + // is link + answer += "<"; + answer += document->get("named_type"); + answer += ": ["; + for (unsigned int i = 0; i < targets->size(); i++) { + answer += handle_to_atom(targets->get_handle(i)); + if (i < (targets->size() - 1)) { + answer += ", "; + } + } + answer += ">"; + } else { + // is node + answer += "("; + answer += document->get("named_type"); + answer += ": "; + answer += document->get("name"); + answer += ")"; + } + + return answer; +} + +void run( + const string &context, + const string &link_type_tag, + const set highlighted) { + + string server_id = "localhost:31700"; + string client_id = "localhost:31701"; + + AtomDBSingleton::init(); + shared_ptr db = AtomDBSingleton::get_instance(); + + string and_operator = "AND"; + string or_operator = "OR"; + string link_template = "LINK_TEMPLATE"; + string link = "LINK"; + string node = "NODE"; + string variable = "VARIABLE"; + string expression = "Expression"; + string symbol = "Symbol"; + string sentence = "Sentence"; + string word = "Word"; + string similarity = "Similarity"; + string contains = "Contains"; + string sentence1 = "sentence1"; + string sentence2 = "sentence2"; + string word1 = "word1"; + string word2 = "word2"; + string tv1 = "tv1"; + + // (Contains (Sentence "aef cbe dfb fbe eca eff bad") (Word "eff")) + + vector query_same_word = { + and_operator, "2", + link_template, expression, "3", + node, symbol, contains, + variable, sentence1, + variable, word1, + link_template, expression, "3", + node, symbol, contains, + variable, sentence2, + variable, word1 + }; + + vector query_same_size { + or_operator, "1", + link_template, expression, "4", + node, symbol, similarity, + variable, sentence1, + variable, sentence2, + variable, tv1 + }; + + DASNode client(client_id, server_id); + QueryAnswer *query_answer; + unsigned int count = 0; + RemoteIterator *response; + + if (link_type_tag == "LINK1") { + response = client.pattern_matcher_query(query_same_word, context); + } else if (link_type_tag == "LINK2") { + response = client.pattern_matcher_query(query_same_size, context, true); + } else { + Utils::error("Invalid link_type_tag: " + link_type_tag); + } + + shared_ptr sentence_document1; + shared_ptr sentence_document2; + shared_ptr sentence_symbol_document1; + shared_ptr sentence_symbol_document2; + stack output; + set already_inserted_links; + while (! response->finished()) { + if ((query_answer = response->pop()) == NULL) { + Utils::sleep(); + } else { + if (! strcmp(query_answer->assignment.get(sentence1.c_str()), query_answer->assignment.get(sentence2.c_str()))) { + continue; + } + //cout << query_answer->to_string() << endl; + //cout << handle_to_atom(query_answer->handles[0]) << endl; + //cout << handle_to_atom(query_answer->handles[1]) << endl; + sentence_document1 = db->get_atom_document(query_answer->assignment.get(sentence1.c_str())); + sentence_document2 = db->get_atom_document(query_answer->assignment.get(sentence2.c_str())); + sentence_symbol_document1 = db->get_atom_document(sentence_document1->get("targets", 1)); + sentence_symbol_document2 = db->get_atom_document(sentence_document2->get("targets", 1)); + string s1 = string(sentence_symbol_document1->get("name")); + string s2 = string(sentence_symbol_document2->get("name")); + if ((already_inserted_links.find(s1 + s2) == already_inserted_links.end()) && + (already_inserted_links.find(s2 + s1) == already_inserted_links.end())) { + build_link(link_type_tag, s1, s2, 0.0, output, highlighted); + already_inserted_links.insert(s1 + s2); + already_inserted_links.insert(s2 + s1); + } + + if (++count == MAX_QUERY_ANSWERS) { + break; + } + } + } + if (count == 0) { + cout << "No match for query" << endl; + } else { + while (! output.empty()) { + cout << output.top() << endl; + output.pop(); + cout << output.top() << endl; + output.pop(); + cout << endl; + } + } + + delete response; +} + +int main(int argc, char* argv[]) { + + if (argc < 3) { + cerr << "Usage: " << argv[0] << " *" << endl; + exit(1); + } + signal(SIGINT, &ctrl_c_handler); + string context = argv[1]; + string link_type_tag = argv[2]; + + if ((link_type_tag != "LINK1") && (link_type_tag != "LINK2")) { + Utils::error("Invalid link_type_tag: " + link_type_tag); + } + + set highlighted; + for (int i = 3; i < argc; i++) { + highlighted.insert(string(argv[i])); + } + + run(context, link_type_tag, highlighted); + return 0; +} diff --git a/src/cpp/main/query_client_main.cc b/src/cpp/main/query_client_main.cc new file mode 100644 index 0000000..878d1fc --- /dev/null +++ b/src/cpp/main/query_client_main.cc @@ -0,0 +1,58 @@ +#include +#include + +#include + +#include "DASNode.h" +#include "RemoteIterator.h" +#include "QueryAnswer.h" +#include "AtomDBSingleton.h" +#include "Utils.h" + +#define MAX_QUERY_ANSWERS ((unsigned int) 1000) + +using namespace std; + +void ctrl_c_handler(int) { + std::cout << "Stopping query engine server..." << std::endl; + std::cout << "Done." << std::endl; + exit(0); +} + +int main(int argc, char* argv[]) { + + if (argc < 4) { + cerr << "Usage: " << argv[0] << " CLIENT_HOST:CLIENT_PORT SERVER_HOST:SERVER_PORT QUERY_TOKEN+ (hosts are supposed to be public IPs or known hostnames)" << endl; + exit(1); + } + + string client_id = string(argv[1]); + string server_id = string(argv[2]); + + signal(SIGINT, &ctrl_c_handler); + vector query; + for (int i = 3; i < argc; i++) { + query.push_back(argv[i]); + } + + DASNode client(client_id, server_id); + QueryAnswer *query_answer; + unsigned int count = 0; + RemoteIterator *response = client.pattern_matcher_query(query); + while (! response->finished()) { + if ((query_answer = response->pop()) == NULL) { + Utils::sleep(); + } else { + cout << query_answer->to_string() << endl; + if (++count == MAX_QUERY_ANSWERS) { + break; + } + } + } + if (count == 0) { + cout << "No match for query" << endl; + } + + delete response; + return 0; +} diff --git a/src/cpp/main/query_engine_main.cc b/src/cpp/main/query_engine_main.cc new file mode 100644 index 0000000..c60f71c --- /dev/null +++ b/src/cpp/main/query_engine_main.cc @@ -0,0 +1,35 @@ +#include +#include + +#include + +#include "AtomDBSingleton.h" +#include "Utils.h" +#include "DASNode.h" + +using namespace std; + +void ctrl_c_handler(int) { + //std::cout << "Stopping query engine server..." << std::endl; + std::cout << "Cleaning GRPC buffers..." << std::endl; + std::cout << "Done." << std::endl; + exit(0); +} + +int main(int argc, char* argv[]) { + + if (argc != 2) { + cerr << "Usage: " << argv[0] << "" << endl; + exit(1); + } + + string server_id = "localhost:" + string(argv[1]); + signal(SIGINT, &ctrl_c_handler); + AtomDBSingleton::init(); + DASNode server(server_id); + cout << "############################# REQUEST QUEUE EMPTY ##################################" << endl; + do { + Utils::sleep(1000); + } while (true); + return 0; +} diff --git a/src/cpp/main/word_query_main.cc b/src/cpp/main/word_query_main.cc new file mode 100644 index 0000000..dd791ac --- /dev/null +++ b/src/cpp/main/word_query_main.cc @@ -0,0 +1,180 @@ +#include +#include + +#include + +#include "DASNode.h" +#include "RemoteIterator.h" +#include "QueryAnswer.h" +#include "AtomDBSingleton.h" +#include "AtomDB.h" +#include "Utils.h" + +#define MAX_QUERY_ANSWERS ((unsigned int) 500) + +using namespace std; + +void ctrl_c_handler(int) { + std::cout << "Stopping link creation engine server..." << std::endl; + std::cout << "Done." << std::endl; + exit(0); +} + +std::vector split(string s, string delimiter) { + std::vector tokens; + size_t pos = 0; + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + tokens.push_back(token); + s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + + return tokens; +} + +string highlight(const string &s, const set &highlighted) { + vector tokens = split(s.substr(1, s.size() - 2), " "); + string answer = ""; + for (unsigned int i = 0; i < tokens.size(); i++) { + if (highlighted.find(tokens[i]) != highlighted.end()) { + //"\033[31;1;4mHello\033[0m" + answer += "\033[1;4m" + tokens[i] + "\033[0m"; + } else { + answer += tokens[i]; + } + if (i != (tokens.size() - 1)) { + answer += " "; + } + } + return answer; +} + + + +string handle_to_atom(const char *handle) { + + shared_ptr db = AtomDBSingleton::get_instance(); + shared_ptr document = db->get_atom_document(handle); + shared_ptr targets = db->query_for_targets((char *) handle); + string answer; + + if (targets != NULL) { + // is link + answer += "<"; + answer += document->get("named_type"); + answer += ": ["; + for (unsigned int i = 0; i < targets->size(); i++) { + answer += handle_to_atom(targets->get_handle(i)); + if (i < (targets->size() - 1)) { + answer += ", "; + } + } + answer += ">"; + } else { + // is node + answer += "("; + answer += document->get("named_type"); + answer += ": "; + answer += document->get("name"); + answer += ")"; + } + + return answer; +} + +void run( + const string &context, + const string &word_tag) { + + string server_id = "localhost:31700"; + string client_id = "localhost:31701"; + + AtomDBSingleton::init(); + shared_ptr db = AtomDBSingleton::get_instance(); + + string and_operator = "AND"; + string link_template = "LINK_TEMPLATE"; + string link = "LINK"; + string node = "NODE"; + string variable = "VARIABLE"; + string expression = "Expression"; + string symbol = "Symbol"; + string sentence = "Sentence"; + string word = "Word"; + string contains = "Contains"; + string sentence1 = "sentence1"; + string sentence2 = "sentence2"; + string word1 = "word1"; + string word2 = "word2"; + + vector query_word = { + link_template, expression, "3", + node, symbol, contains, + variable, sentence1, + link, expression, "2", + node, symbol, word, + node, symbol, "\"" + word_tag + "\"" + }; + + DASNode client(client_id, server_id); + QueryAnswer *query_answer; + unsigned int count = 0; + RemoteIterator *response = client.pattern_matcher_query(query_word, context, true); + shared_ptr sentence_document; + shared_ptr sentence_name_document; + vector sentences; + while (! response->finished()) { + if ((query_answer = response->pop()) == NULL) { + Utils::sleep(); + } else { + //cout << "------------------------------------------" << endl; + //cout << query_answer->to_string() << endl; + const char *handle; + handle = query_answer->assignment.get(sentence1.c_str()); + //cout << string(handle) << endl; + //cout << handle_to_atom(handle) << endl; + sentence_document = db->get_atom_document(handle); + handle = sentence_document->get("targets", 1); + //cout << string(handle) << endl; + //cout << handle_to_atom(handle) << endl; + sentence_name_document = db->get_atom_document(handle); + // cout << string(sentence_name_document->get("name")) << endl; + set to_highlight; + to_highlight.insert(word_tag); + string sentence_name = string(sentence_name_document->get("name")); + string highlighted_sentence_name = highlight(sentence_name, to_highlight); + string w = "\"" + word_tag + "\""; + string line = "(Contains (Sentence " + + highlighted_sentence_name + + ") (Word \"" + + highlight(w, to_highlight) + + "\"))"; + cout << line << endl; + if (++count == MAX_QUERY_ANSWERS) { + break; + } + } + } + if (count == 0) { + cout << "No match for query" << endl; + exit(0); + } + + delete response; +} + +int main(int argc, char* argv[]) { + + if (argc < 3) { + cerr << "Usage: " << argv[0] << " " << endl; + exit(1); + } + signal(SIGINT, &ctrl_c_handler); + string context = argv[1]; + string word_tag = argv[2]; + + run(context, word_tag); + return 0; +} diff --git a/src/cpp/query_engine/AtomDB.cc b/src/cpp/query_engine/AtomDB.cc new file mode 100644 index 0000000..cd62f9b --- /dev/null +++ b/src/cpp/query_engine/AtomDB.cc @@ -0,0 +1,175 @@ +#include +#include +#include +#include +#include "AtomDB.h" +#include "Utils.h" + +#include "AttentionBrokerServer.h" +#include "attention_broker.grpc.pb.h" +#include +#include "attention_broker.pb.h" + +using namespace query_engine; +using namespace commons; + +string AtomDB::WILDCARD; +string AtomDB::REDIS_PATTERNS_PREFIX; +string AtomDB::REDIS_TARGETS_PREFIX; +string AtomDB::MONGODB_DB_NAME; +string AtomDB::MONGODB_COLLECTION_NAME; +string AtomDB::MONGODB_FIELD_NAME[MONGODB_FIELD::size]; + +AtomDB::AtomDB() { + redis_setup(); + mongodb_setup(); + attention_broker_setup(); +} + +AtomDB::~AtomDB() { + if (this->redis_cluster != NULL) { + redisClusterFree(this->redis_cluster); + } + if (this->redis_single != NULL) { + redisFree(this->redis_single); + } + delete this->mongodb_pool; + // delete this->mongodb_client; +} + +void AtomDB::attention_broker_setup() { + + grpc::ClientContext context; + grpc::Status status; + dasproto::Empty empty; + dasproto::Ack ack; + string attention_broker_address = "localhost:37007"; + + auto stub = dasproto::AttentionBroker::NewStub(grpc::CreateChannel( + attention_broker_address, + grpc::InsecureChannelCredentials())); + status = stub->ping(&context, empty, &ack); + if (status.ok()) { + std::cout << "Connected to AttentionBroker at " << attention_broker_address << endl; + } else { + Utils::error("Couldn't connect to AttentionBroker at " + attention_broker_address); + } + if (ack.msg() != "PING") { + Utils::error("Invalid AttentionBroker answer for PING"); + } +} + +void AtomDB::redis_setup() { + + string host = Utils::get_environment("DAS_REDIS_HOSTNAME"); + string port = Utils::get_environment("DAS_REDIS_PORT"); + string address = host + ":" + port; + string cluster = Utils::get_environment("DAS_USE_REDIS_CLUSTER"); + std::transform(cluster.begin(), cluster.end(), cluster.begin(), ::toupper); + this->cluster_flag = (cluster == "TRUE"); + + if (host == "" || port == "") { + Utils::error("You need to set Redis access info as environment variables: DAS_REDIS_HOSTNAME, DAS_REDIS_PORT and DAS_USE_REDIS_CLUSTER"); + } + string cluster_tag = (this->cluster_flag ? "CLUSTER" : "NON-CLUSTER"); + + if (this->cluster_flag) { + this->redis_cluster = redisClusterConnect(address.c_str(), 0); + this->redis_single = NULL; + } else { + this->redis_single = redisConnect(host.c_str(), stoi(port)); + this->redis_cluster = NULL; + } + + if (this->redis_cluster == NULL && this->redis_single == NULL) { + Utils::error("Connection error."); + } else if ((! this->cluster_flag) && this->redis_single->err) { + Utils::error("Redis error: " + string(this->redis_single->errstr)); + } else if (this->cluster_flag && this->redis_cluster->err) { + Utils::error("Redis cluster error: " + string(this->redis_cluster->errstr)); + } else { + cout << "Connected to (" << cluster_tag << ") Redis at " << address << endl; + } +} + +mongocxx::database AtomDB::get_database(){ + auto database = this->mongodb_pool->acquire(); + return database[MONGODB_DB_NAME]; +} + +void AtomDB::mongodb_setup() { + + string host = Utils::get_environment("DAS_MONGODB_HOSTNAME"); + string port = Utils::get_environment("DAS_MONGODB_PORT"); + string user = Utils::get_environment("DAS_MONGODB_USERNAME"); + string password = Utils::get_environment("DAS_MONGODB_PASSWORD"); + if (host == "" || port == "" || user == "" || password == "") { + Utils::error(string("You need to set MongoDB access info as environment variables: ") + \ + "DAS_MONGODB_HOSTNAME, DAS_MONGODB_PORT, DAS_MONGODB_USERNAME and DAS_MONGODB_PASSWORD"); + } + string address = host + ":" + port; + string url = "mongodb://" + user + ":" + password + "@" + address; + + try { + mongocxx::instance instance; + auto uri = mongocxx::uri{url}; + this->mongodb_pool = new mongocxx::pool(uri); + this->mongodb = get_database(); + + // this->mongodb_client = new mongocxx::client(uri); + // this->mongodb = (*this->mongodb_client)[MONGODB_DB_NAME]; + const auto ping_cmd = bsoncxx::builder::basic::make_document(bsoncxx::builder::basic::kvp("ping", 1)); + this->mongodb.run_command(ping_cmd.view()); + this->mongodb_collection = this->mongodb[MONGODB_COLLECTION_NAME]; + //auto atom_count = this->mongodb_collection.count_documents({}); + //std::cout << "Connected to MongoDB at " << address << " Atom count: " << atom_count << endl; + std::cout << "Connected to MongoDB at " << address << endl; + } catch (const std::exception& e) { + Utils::error(e.what()); + } +} + +shared_ptr AtomDB::query_for_pattern(shared_ptr pattern_handle) { + redisReply *reply = (redisReply *) redisCommand(this->redis_single, "SMEMBERS %s:%s", REDIS_PATTERNS_PREFIX.c_str(), pattern_handle.get()); + if (reply == NULL) { + Utils::error("Redis error"); + } + if (reply->type != REDIS_REPLY_SET && reply->type != REDIS_REPLY_ARRAY) { + Utils::error("Invalid Redis response: " + std::to_string(reply->type)); + } + // NOTE: Intentionally, we aren't destroying 'reply' objects.'reply' objects are destroyed in ~RedisSet(). + return shared_ptr(new atomdb_api_types::RedisSet(reply)); +} + +shared_ptr AtomDB::query_for_targets(shared_ptr link_handle) { + return query_for_targets(link_handle.get()); +} + +shared_ptr AtomDB::query_for_targets(char *link_handle_ptr) { + redisReply *reply = (redisReply *) redisCommand(this->redis_single, "GET %s:%s", REDIS_TARGETS_PREFIX.c_str(), link_handle_ptr); + /* + if (reply == NULL) { + Utils::error("Redis error"); + } + */ + if ((reply == NULL) || (reply->type == REDIS_REPLY_NIL)) { + return shared_ptr(NULL); + } + if (reply->type != REDIS_REPLY_STRING) { + Utils::error("Invalid Redis response: " + std::to_string(reply->type) + + " != " + std::to_string(REDIS_REPLY_STRING)); + } + // NOTE: Intentionally, we aren't destroying 'reply' objects.'reply' objects are destroyed in ~RedisSet(). + return shared_ptr(new atomdb_api_types::RedisStringBundle(reply)); +} + +shared_ptr AtomDB::get_atom_document(const char *handle) { + this->mongodb_mutex.lock(); + auto mongodb_collection = get_database()[MONGODB_COLLECTION_NAME]; + auto reply = mongodb_collection.find_one( + bsoncxx::v_noabi::builder::basic::make_document( + bsoncxx::v_noabi::builder::basic::kvp(MONGODB_FIELD_NAME[MONGODB_FIELD::ID], handle))); + //cout << bsoncxx::to_json(*reply) << endl; // Note to reviewer: please let this dead code here + this->mongodb_mutex.unlock(); + return shared_ptr(new atomdb_api_types::MongodbDocument(reply)); +} diff --git a/src/cpp/query_engine/AtomDB.h b/src/cpp/query_engine/AtomDB.h new file mode 100644 index 0000000..9a62b72 --- /dev/null +++ b/src/cpp/query_engine/AtomDB.h @@ -0,0 +1,87 @@ +#ifndef _QUERY_ENGINE_ATOMDB_H +#define _QUERY_ENGINE_ATOMDB_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include "AtomDBAPITypes.h" + +using namespace std; + +namespace query_engine { + + +enum MONGODB_FIELD { + ID = 0, + size +}; + +// ------------------------------------------------------------------------------------------------- +// NOTE TO REVIEWER: +// +// This class will be replaced/integrated by/with classes already implemented in das-atom-db. +// +// However, that classes will need to be revisited in order to allow the methods implemented here +// because although the design of such methods is nasty, they have the string advantage of +// allowing the reuse of structures allocated by the DBMS (Redis an MongoDB) withpout the need +// of re-allocation of other dataclasses. Although this nasty behavior may not be desirable +// outside the DAS bounds, it's quite appealing inside the query engine (and perhaps other +// parts of internal stuff). +// +// I think it's pointless to make any further documentation while we don't make this integrfation. +// ------------------------------------------------------------------------------------------------- + +class AtomDB { + +public: + + AtomDB(); + ~AtomDB(); + + static string WILDCARD; + static string REDIS_PATTERNS_PREFIX; + static string REDIS_TARGETS_PREFIX; + static string MONGODB_DB_NAME; + static string MONGODB_COLLECTION_NAME; + static string MONGODB_FIELD_NAME[MONGODB_FIELD::size]; + + static void initialize_statics() { + WILDCARD = "*"; + REDIS_PATTERNS_PREFIX = "patterns"; + REDIS_TARGETS_PREFIX = "outgoing_set"; + MONGODB_DB_NAME = "das"; + MONGODB_COLLECTION_NAME = "atoms"; + MONGODB_FIELD_NAME[MONGODB_FIELD::ID] = "_id"; + } + + shared_ptr query_for_pattern(shared_ptr pattern_handle); + shared_ptr query_for_targets(shared_ptr link_handle); + shared_ptr query_for_targets(char *link_handle_ptr); + shared_ptr get_atom_document(const char *handle); + +private: + + bool cluster_flag; + redisClusterContext *redis_cluster; + redisContext *redis_single; + mongocxx::client *mongodb_client; + mongocxx::database mongodb; + mongocxx::v_noabi::collection mongodb_collection; + mutex mongodb_mutex; + mongocxx::pool *mongodb_pool; + + mongocxx::database get_database(); + + void redis_setup(); + void mongodb_setup(); + void attention_broker_setup(); +}; + +} // namespace query_engine + +#endif // _QUERY_ENGINE_ATOMDB_H diff --git a/src/cpp/query_engine/AtomDBAPITypes.cc b/src/cpp/query_engine/AtomDBAPITypes.cc new file mode 100644 index 0000000..2d1a91b --- /dev/null +++ b/src/cpp/query_engine/AtomDBAPITypes.cc @@ -0,0 +1,92 @@ +#include + +#include "Utils.h" +#include "expression_hasher.h" +#include "AtomDBAPITypes.h" + +using namespace query_engine; +using namespace atomdb_api_types; +using namespace commons; + +RedisSet::RedisSet(redisReply *reply) : HandleList() { + this->redis_reply = reply; + this->handles_size = reply->elements; + this->handles = new char *[this->handles_size]; + for (unsigned int i = 0; i < this->handles_size; i++) { + handles[i] = reply->element[i]->str; + } +} + +RedisSet::~RedisSet() { + delete [] this->handles; + freeReplyObject(this->redis_reply); +} + +const char *RedisSet::get_handle(unsigned int index) { + if (index > this->handles_size) { + Utils::error("Handle index out of bounds: " + to_string(index) + " Answer array size: " + to_string(this->handles_size)); + } + return handles[index]; +} + +unsigned int RedisSet::size() { + return this->handles_size; +} + +RedisStringBundle::RedisStringBundle(redisReply *reply) : HandleList() { + unsigned int handle_length = (HANDLE_HASH_SIZE - 1); + this->redis_reply = reply; + this->handles_size = reply->len / handle_length; + this->handles = new char *[this->handles_size]; + for (unsigned int i = 0; i < this->handles_size; i++) { + handles[i] = strndup(reply->str + (i * handle_length), handle_length); + } +} + +RedisStringBundle::~RedisStringBundle() { + for (unsigned int i = 0; i < this->handles_size; i++) { + free(this->handles[i]); + } + delete [] this->handles; + freeReplyObject(this->redis_reply); +} + +const char *RedisStringBundle::get_handle(unsigned int index) { + if (index > this->handles_size) { + Utils::error("Handle index out of bounds: " + to_string(index) + " Answer handles size: " + to_string(this->handles_size)); + } + return handles[index]; +} + +unsigned int RedisStringBundle::size() { + return this->handles_size; +} + +MongodbDocument::MongodbDocument(core::v1::optional& document) { + this->document = document; +} + +MongodbDocument::~MongodbDocument() { +} + +const char *MongodbDocument::get(const string &key) { + // Note for reference: .to_string() instead of .data() would return a std::string + return ((*this->document)[key]).get_string().value.data(); +} + +const char *MongodbDocument::get(const string &array_key, unsigned int index) { + // Note for reference: .to_string() instead of .data() would return a std::string + return ((*this->document)[array_key]).get_array().value[index].get_string().value.data(); +} + +unsigned int MongodbDocument::get_size(const string &array_key) { + // NOTE TO REVIEWER + // TODO: this implementation is wrong and need to be fixed before integration in das-atom-db + // I couldn't figure out a way to discover the number of elements in a BSON array. + //cout << "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" << endl; + //cout << "XXXXXXXXXXXXXXXXXXX MongodbDocument::get_size()" << endl; + //cout << "XXXXXXXXXXXXXXXXXXX MongodbDocument::get_size() length: " << ((*this->document)[array_key]).get_array().value.length() << endl; + //cout << "XXXXXXXXXXXXXXXXXXX MongodbDocument::get_size() HASH: " << HANDLE_HASH_SIZE << endl; + //cout << "XXXXXXXXXXXXXXXXXXX MongodbDocument::get_size() size: " << ((*this->document)[array_key]).get_array().value.length() / HANDLE_HASH_SIZE << endl; + return ((*this->document)[array_key]).get_array().value.length() / HANDLE_HASH_SIZE; +} diff --git a/src/cpp/query_engine/AtomDBAPITypes.h b/src/cpp/query_engine/AtomDBAPITypes.h new file mode 100644 index 0000000..3679508 --- /dev/null +++ b/src/cpp/query_engine/AtomDBAPITypes.h @@ -0,0 +1,112 @@ +#ifndef _QUERY_ENGINE_ATOMDBAPITYPES_H +#define _QUERY_ENGINE_ATOMDBAPITYPES_H + +#include +#include +#include +#include +#include +#include +#include +#include "Utils.h" + +using namespace std; +using namespace commons; + +namespace query_engine { +namespace atomdb_api_types { + + +// ------------------------------------------------------------------------------------------------- +// NOTE TO REVIEWER: +// +// This class will be replaced/integrated by/with classes already implemented in das-atom-db. +// +// However, that classes will need to be revisited in order to allow the methods implemented here +// because although the design of such methods is nasty, they have the string advantage of +// allowing the reuse of structures allocated by the DBMS (Redis an MongoDB) withpout the need +// of re-allocation of other dataclasses. Although this nasty behavior may not be desirable +// outside the DAS bounds, it's quite appealing inside the query engine (and perhaps other +// parts of internal stuff). +// +// I think it's pointless to make any further documentation while we don't make this integrfation. +// ------------------------------------------------------------------------------------------------- + + +class HandleList { + +public: + + HandleList() {} + virtual ~HandleList() {} + + virtual const char *get_handle(unsigned int index) = 0; + virtual unsigned int size() = 0; +}; + +class RedisSet : public HandleList { + +public: + + RedisSet(redisReply *reply); + ~RedisSet(); + + const char *get_handle(unsigned int index); + unsigned int size(); + +private: + + unsigned int handles_size; + char **handles; + redisReply *redis_reply; +}; + +class RedisStringBundle : public HandleList { + +public: + + RedisStringBundle(redisReply *reply); + ~RedisStringBundle(); + + const char *get_handle(unsigned int index); + unsigned int size(); + +private: + + unsigned int handles_size; + char **handles; + redisReply *redis_reply; +}; + +class AtomDocument { + +public: + + AtomDocument() {} + virtual ~AtomDocument() {} + + virtual const char *get(const string &key) = 0; + virtual const char *get(const string &array_key, unsigned int index) = 0; + virtual unsigned int get_size(const string &array_key) = 0; +}; + +class MongodbDocument : public AtomDocument { + +public: + + MongodbDocument(core::v1::optional& document); + ~MongodbDocument(); + + const char *get(const string &key); + virtual const char *get(const string &array_key, unsigned int index); + virtual unsigned int get_size(const string &array_key); + +private: + + core::v1::optional document; +}; + +} // namespace atomdb_api_types +} // namespace query_engine + +#endif // _QUERY_ENGINE_ATOMDBAPITYPES_H diff --git a/src/cpp/query_engine/AtomDBSingleton.cc b/src/cpp/query_engine/AtomDBSingleton.cc new file mode 100644 index 0000000..a037923 --- /dev/null +++ b/src/cpp/query_engine/AtomDBSingleton.cc @@ -0,0 +1,30 @@ +#include "AtomDBSingleton.h" +#include "Utils.h" + +using namespace query_engine; +using namespace commons; + +bool AtomDBSingleton::initialized = false; +shared_ptr AtomDBSingleton::atom_db = shared_ptr{}; + +// -------------------------------------------------------------------------------- +// Public methods + +void AtomDBSingleton::init() { + if (initialized) { + Utils::error("AtomDBSingleton already initialized. AtomDBSingleton::init() should be called only once."); + } else { + AtomDB::initialize_statics(); + atom_db = shared_ptr(new AtomDB()); + initialized = true; + } +} + +shared_ptr AtomDBSingleton::get_instance() { + if (! initialized) { + Utils::error("Uninitialized AtomDBSingleton. AtomDBSingleton::init() must be called before AtomDBSingleton::get_instance()"); + return shared_ptr{}; // To avoid warnings + } else { + return atom_db; + } +} diff --git a/src/cpp/query_engine/AtomDBSingleton.h b/src/cpp/query_engine/AtomDBSingleton.h new file mode 100644 index 0000000..55e4010 --- /dev/null +++ b/src/cpp/query_engine/AtomDBSingleton.h @@ -0,0 +1,36 @@ +#ifndef _QUERY_ENGINE_ATOMDBSINGLETON_H +#define _QUERY_ENGINE_ATOMDBSINGLETON_H + +#include +#include "AtomDB.h" + +using namespace std; + +namespace query_engine { + +// ------------------------------------------------------------------------------------------------- +// NOTE TO REVIEWER: +// +// This class will be replaced/integrated by/with classes already implemented in das-atom-db. +// +// I think it's pointless to make any further documentation while we don't make this integrfation. +// ------------------------------------------------------------------------------------------------- + +class AtomDBSingleton { + +public: + + ~AtomDBSingleton() {} + static void init(); + static shared_ptr get_instance(); + +private: + + AtomDBSingleton() {} + static bool initialized; + static shared_ptr atom_db; +}; + +} // namespace query_engine + +#endif // _QUERY_ENGINE_ATOMDBSINGLETON_H diff --git a/src/cpp/query_engine/BUILD b/src/cpp/query_engine/BUILD new file mode 100644 index 0000000..107238a --- /dev/null +++ b/src/cpp/query_engine/BUILD @@ -0,0 +1,23 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "query_engine_lib", + srcs = glob(["*.cc", "query_element/*.cc"]), + hdrs = glob(["*.h", "query_element/*.h"]), + includes = [".", "query_element"], + deps = [ + "//cpp/utils:utils_lib", + "//cpp/hasher:hasher_lib", + "//cpp/attention_broker:attention_broker_server_lib", + "@com_github_singnet_das_node//:atomspacenode", + ], + linkopts = [ + "-lmbedcrypto", + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + +) diff --git a/src/cpp/query_engine/DASNode.cc b/src/cpp/query_engine/DASNode.cc new file mode 100644 index 0000000..0311725 --- /dev/null +++ b/src/cpp/query_engine/DASNode.cc @@ -0,0 +1,574 @@ +#include +#include "DASNode.h" +#include "LinkTemplate.h" +#include "Terminal.h" +#include "Or.h" +#include "RemoteSink.h" + +using namespace query_engine; + +string DASNode::PATTERN_MATCHING_QUERY = "pattern_matching_query"; + +// ------------------------------------------------------------------------------------------------- +// Constructors and destructors + +DASNode::DASNode(const string &node_id) : StarNode(node_id) { + initialize(); + // SERVER +} + +DASNode::DASNode(const string &node_id, const string &server_id) : StarNode(node_id, server_id) { + initialize(); + // CLIENT +} + +DASNode::~DASNode() { +} + +void DASNode::initialize() { + this->first_query_port = 60000; + this->last_query_port = 61999; + string id = this->node_id(); + this->local_host = id.substr(0, id.find(":")); + if (this->is_server) { + this->next_query_port = this->first_query_port; + } else { + this->next_query_port = (this->first_query_port + this->last_query_port) / 2; + } +} + +// ------------------------------------------------------------------------------------------------- +// Public client API + +RemoteIterator *DASNode::pattern_matcher_query( + const vector &tokens, + const string &context, + bool update_attention_broker) { +#ifdef DEBUG + cout << "DASNode::pattern_matcher_query() BEGIN" << endl; + cout << "DASNode::pattern_matcher_query() tokens.size(): " << tokens.size() << endl; + cout << "DASNode::pattern_matcher_query() context: " << context << endl; + cout << "DASNode::pattern_matcher_query() update_attention_broker: " << update_attention_broker << endl; +#endif + if (this->is_server) { + Utils::error("pattern_matcher_query() is not available in DASNode server."); + } + // TODO XXX change this when requestor is set in basic Message + string query_id = next_query_id(); + vector args = {query_id, context, std::to_string(update_attention_broker)}; + args.insert(args.end(), tokens.begin(), tokens.end()); + send(PATTERN_MATCHING_QUERY, args, this->server_id); +#ifdef DEBUG + cout << "DASNode::pattern_matcher_query() END" << endl; +#endif + return new RemoteIterator(query_id); +} + +// ------------------------------------------------------------------------------------------------- +// Public generic methods + +string DASNode::next_query_id() { + unsigned int port = this->next_query_port++; + unsigned int limit; + if (this->is_server) { + limit = ((this->first_query_port + this->last_query_port) / 2) - 1; + if (this->next_query_port > limit) { + this->next_query_port = this->first_query_port; + } + } else { + limit = this->last_query_port; + if (this->next_query_port > limit) { + this->next_query_port = (this->first_query_port + this->last_query_port) / 2; + } + } +#ifdef DEBUG + cout << "DASNode::next_query_id(): " << this->local_host + ":" + std::to_string(port) << endl; +#endif + return this->local_host + ":" + std::to_string(port); +} + +// ------------------------------------------------------------------------------------------------- +// Messages + +shared_ptr DASNode::message_factory(string &command, vector &args) { + std::shared_ptr message = AtomSpaceNode::message_factory(command, args); + if (message) { + return message; + } + if (command == DASNode::PATTERN_MATCHING_QUERY) { + return std::shared_ptr(new PatternMatchingQuery(command, args)); + } + return std::shared_ptr{}; +} + +QueryElement *PatternMatchingQuery::build_link_template( + vector &tokens, + unsigned int cursor, + stack &element_stack) { + + unsigned int arity = std::stoi(tokens[cursor + 2]); + if (element_stack.size() < arity) { + Utils::error("PatternMatchingQuery message: parse error in tokens - too few arguments for LINK_TEMPLATE"); + } + switch (arity) { + // TODO: consider replacing each "case" below by a pre-processor macro call + case 1: { + array targets; + for (unsigned int i = 0; i < 1; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<1>(tokens[cursor + 1], targets, this->context); + } + case 2: { + array targets; + for (unsigned int i = 0; i < 2; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<2>(tokens[cursor + 1], targets, this->context); + } + case 3: { + array targets; + for (unsigned int i = 0; i < 3; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<3>(tokens[cursor + 1], targets, this->context); + } + case 4: { + array targets; + for (unsigned int i = 0; i < 4; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<4>(tokens[cursor + 1], targets, this->context); + } + case 5: { + array targets; + for (unsigned int i = 0; i < 5; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<5>(tokens[cursor + 1], targets, this->context); + } + case 6: { + array targets; + for (unsigned int i = 0; i < 6; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<6>(tokens[cursor + 1], targets, this->context); + } + case 7: { + array targets; + for (unsigned int i = 0; i < 7; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<7>(tokens[cursor + 1], targets, this->context); + } + case 8: { + array targets; + for (unsigned int i = 0; i < 8; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<8>(tokens[cursor + 1], targets, this->context); + } + case 9: { + array targets; + for (unsigned int i = 0; i < 9; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<9>(tokens[cursor + 1], targets, this->context); + } + case 10: { + array targets; + for (unsigned int i = 0; i < 10; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new LinkTemplate<10>(tokens[cursor + 1], targets, this->context); + } + default: { + Utils::error("PatternMatchingQuery message: max supported arity for LINK_TEMPLATE: 10"); + } + } + return NULL; // Just to avoid warnings. This is not actually reachable. +} + + +QueryElement *PatternMatchingQuery::build_and( + vector &tokens, + unsigned int cursor, + stack &element_stack) { + + unsigned int num_clauses = std::stoi(tokens[cursor + 1]); + if (element_stack.size() < num_clauses) { + Utils::error("PatternMatchingQuery message: parse error in tokens - too few arguments for AND"); + } + switch (num_clauses) { + // TODO: consider replacing each "case" below by a pre-processor macro call + case 1: { + array clauses; + for (unsigned int i = 0; i < 1; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<1>(clauses); + } + case 2: { + array clauses; + for (unsigned int i = 0; i < 2; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<2>(clauses); + } + case 3: { + array clauses; + for (unsigned int i = 0; i < 3; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<3>(clauses); + } + case 4: { + array clauses; + for (unsigned int i = 0; i < 4; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<4>(clauses); + } + case 5: { + array clauses; + for (unsigned int i = 0; i < 5; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<5>(clauses); + } + case 6: { + array clauses; + for (unsigned int i = 0; i < 6; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<6>(clauses); + } + case 7: { + array clauses; + for (unsigned int i = 0; i < 7; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<7>(clauses); + } + case 8: { + array clauses; + for (unsigned int i = 0; i < 8; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<8>(clauses); + } + case 9: { + array clauses; + for (unsigned int i = 0; i < 9; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<9>(clauses); + } + case 10: { + array clauses; + for (unsigned int i = 0; i < 10; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new And<10>(clauses); + } + default: { + Utils::error("PatternMatchingQuery message: max supported num_clauses for AND: 10"); + } + } + return NULL; // Just to avoid warnings. This is not actually reachable. +} + +QueryElement *PatternMatchingQuery::build_or( + vector &tokens, + unsigned int cursor, + stack &element_stack) { + + unsigned int num_clauses = std::stoi(tokens[cursor + 1]); + if (element_stack.size() < num_clauses) { + Utils::error("PatternMatchingQuery message: parse error in tokens - too few arguments for OR"); + } + switch (num_clauses) { + // TODO: consider replacing each "case" below by a pre-processor macro call + case 1: { + array clauses; + for (unsigned int i = 0; i < 1; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<1>(clauses); + } + case 2: { + array clauses; + for (unsigned int i = 0; i < 2; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<2>(clauses); + } + case 3: { + array clauses; + for (unsigned int i = 0; i < 3; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<3>(clauses); + } + case 4: { + array clauses; + for (unsigned int i = 0; i < 4; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<4>(clauses); + } + case 5: { + array clauses; + for (unsigned int i = 0; i < 5; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<5>(clauses); + } + case 6: { + array clauses; + for (unsigned int i = 0; i < 6; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<6>(clauses); + } + case 7: { + array clauses; + for (unsigned int i = 0; i < 7; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<7>(clauses); + } + case 8: { + array clauses; + for (unsigned int i = 0; i < 8; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<8>(clauses); + } + case 9: { + array clauses; + for (unsigned int i = 0; i < 9; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<9>(clauses); + } + case 10: { + array clauses; + for (unsigned int i = 0; i < 10; i++) { + clauses[i] = element_stack.top(); + element_stack.pop(); + } + return new Or<10>(clauses); + } + default: { + Utils::error("PatternMatchingQuery message: max supported num_clauses for OR: 10"); + } + } + return NULL; // Just to avoid warnings. This is not actually reachable. +} + +QueryElement *PatternMatchingQuery::build_link( + vector &tokens, + unsigned int cursor, + stack &element_stack) { + + unsigned int arity = std::stoi(tokens[cursor + 2]); + if (element_stack.size() < arity) { + Utils::error("PatternMatchingQuery message: parse error in tokens - too few arguments for LINK"); + } + switch (arity) { + // TODO: consider replacing each "case" below by a pre-processor macro call + case 1: { + array targets; + for (unsigned int i = 0; i < 1; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<1>(tokens[cursor + 1], targets); + } + case 2: { + array targets; + for (unsigned int i = 0; i < 2; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<2>(tokens[cursor + 1], targets); + } + case 3: { + array targets; + for (unsigned int i = 0; i < 3; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<3>(tokens[cursor + 1], targets); + } + case 4: { + array targets; + for (unsigned int i = 0; i < 4; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<4>(tokens[cursor + 1], targets); + } + case 5: { + array targets; + for (unsigned int i = 0; i < 5; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<5>(tokens[cursor + 1], targets); + } + case 6: { + array targets; + for (unsigned int i = 0; i < 6; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<6>(tokens[cursor + 1], targets); + } + case 7: { + array targets; + for (unsigned int i = 0; i < 7; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<7>(tokens[cursor + 1], targets); + } + case 8: { + array targets; + for (unsigned int i = 0; i < 8; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<8>(tokens[cursor + 1], targets); + } + case 9: { + array targets; + for (unsigned int i = 0; i < 9; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<9>(tokens[cursor + 1], targets); + } + case 10: { + array targets; + for (unsigned int i = 0; i < 10; i++) { + targets[i] = element_stack.top(); + element_stack.pop(); + } + return new Link<10>(tokens[cursor + 1], targets); + } + default: { + Utils::error("PatternMatchingQuery message: max supported arity for LINK: 10"); + } + } + return NULL; // Just to avoid warnings. This is not actually reachable. +} + +PatternMatchingQuery::PatternMatchingQuery(string command, vector &tokens) { + +#ifdef DEBUG + cout << "PatternMatchingQuery::PatternMatchingQuery() BEGIN" << endl; +#endif + + stack execution_stack; + stack element_stack; + this->requestor_id = tokens[0]; + this->context = tokens[1]; + this->update_attention_broker = (tokens[2] == "1"); + unsigned int cursor = 3; // TODO XXX: change this when requestor is set in basic Message + unsigned int tokens_count = tokens.size(); + +#ifdef DEBUG + cout << "PatternMatchingQuery::PatternMatchingQuery() tokens_count: " << tokens_count << endl; +#endif + while (cursor < tokens_count) { + execution_stack.push(cursor); + if ((tokens[cursor] == "VARIABLE") || (tokens[cursor] == "AND") || ((tokens[cursor] == "OR"))) { + cursor += 2; + } else { + cursor += 3; + } + } + if (cursor != tokens_count) { + Utils::error("PatternMatchingQuery message: parse error in tokens"); + } + + while (! execution_stack.empty()) { + cursor = execution_stack.top(); + if (tokens[cursor] == "NODE") { + element_stack.push(new Node(tokens[cursor + 1], tokens[cursor + 2])); + } else if (tokens[cursor] == "VARIABLE") { + element_stack.push(new Variable(tokens[cursor + 1])); + } else if (tokens[cursor] == "LINK") { + element_stack.push(build_link(tokens, cursor, element_stack)); + } else if (tokens[cursor] == "LINK_TEMPLATE") { + element_stack.push(build_link_template(tokens, cursor, element_stack)); + } else if (tokens[cursor] == "AND") { + element_stack.push(build_and(tokens, cursor, element_stack)); + } else if (tokens[cursor] == "OR") { + element_stack.push(build_or(tokens, cursor, element_stack)); + } else { + Utils::error("Invalid token " + tokens[cursor] + " in PatternMatchingQuery message"); + } + execution_stack.pop(); + } + + if (element_stack.size() != 1) { + Utils::error("PatternMatchingQuery message: parse error in tokens (trailing elements)"); + } + this->root_query_element = element_stack.top(); + element_stack.pop(); +#ifdef DEBUG + cout << "PatternMatchingQuery::PatternMatchingQuery() END" << endl; +#endif +} + +void PatternMatchingQuery::act(shared_ptr node) { +#ifdef DEBUG + cout << "PatternMatchingQuery::act() BEGIN" << endl; + cout << "PatternMatchingQuery::act() this->requestor_id: " << this->requestor_id << endl; +#endif + auto das_node = dynamic_pointer_cast(node); + + // TODO XXX Remove memory leak + RemoteSink *remote_sink = new RemoteSink( + this->root_query_element, + das_node->next_query_id(), + this->requestor_id, + this->update_attention_broker, + this->context); + +#ifdef DEBUG + cout << "PatternMatchingQuery::act() END" << endl; +#endif +} diff --git a/src/cpp/query_engine/DASNode.h b/src/cpp/query_engine/DASNode.h new file mode 100644 index 0000000..b2f73a2 --- /dev/null +++ b/src/cpp/query_engine/DASNode.h @@ -0,0 +1,84 @@ +#ifndef _QUERY_ENGINE_DASNODE_H +#define _QUERY_ENGINE_DASNODE_H + +#define DEBUG + +#include +#include "StarNode.h" +#include "RemoteIterator.h" + +using namespace std; +using namespace atom_space_node; +using namespace query_element; + +namespace query_engine { + +/** + * + */ +class DASNode : public StarNode { + +public: + + static string PATTERN_MATCHING_QUERY; + + DASNode(const string &node_id); + DASNode(const string &node_id, const string &server_id); + ~DASNode(); + + RemoteIterator *pattern_matcher_query( + const vector &tokens, + const string &context = "", + bool update_attention_broker = false); + string next_query_id(); + + virtual shared_ptr message_factory(string &command, vector &args); + +private: + + void initialize(); + + string local_host; + unsigned int next_query_port; + unsigned int first_query_port; + unsigned int last_query_port; +}; + +class PatternMatchingQuery : public Message { + +public: + + PatternMatchingQuery(string command, vector &tokens); + void act(shared_ptr node); + +private: + + QueryElement *build_link_template( + vector &tokens, + unsigned int cursor, + stack &element_stack); + + QueryElement *build_and( + vector &tokens, + unsigned int cursor, + stack &element_stack); + + QueryElement *build_or( + vector &tokens, + unsigned int cursor, + stack &element_stack); + + QueryElement *build_link( + vector &tokens, + unsigned int cursor, + stack &element_stack); + + QueryElement *root_query_element; + string requestor_id; + string context; + bool update_attention_broker; +}; + +} // namespace query_engine + +#endif // _QUERY_ENGINE_DASNODE_H diff --git a/src/cpp/query_engine/QueryAnswer.cc b/src/cpp/query_engine/QueryAnswer.cc new file mode 100644 index 0000000..57bb20d --- /dev/null +++ b/src/cpp/query_engine/QueryAnswer.cc @@ -0,0 +1,269 @@ +#include "QueryAnswer.h" +#include "Utils.h" +#include +#include +#include + +using namespace query_engine; +using namespace commons; + +// ------------------------------------------------------------------------------------------------- +// Assignment + +Assignment::Assignment() { + this->size = 0; +} + +Assignment::~Assignment() { +} + +bool Assignment::assign(const char *label, const char *value) { + for (unsigned int i = 0; i < this->size; i++) { + // if label is already present, return true iff its value is the same + if (strncmp(label, this->labels[i], MAX_VARIABLE_NAME_SIZE) == 0) { + return (strncmp(value, this->values[i], HANDLE_HASH_SIZE) == 0); + } + } + // label is not present, so makes the assignment and return true + labels[this->size] = label; + values[this->size] = value; + this->size++; + if (this->size == MAX_NUMBER_OF_VARIABLES_IN_QUERY) { + Utils::error( + "Assignment size exceeds the maximal number of allowed variables in a query: " + + std::to_string(MAX_NUMBER_OF_VARIABLES_IN_QUERY)); + } + return true; +} + +bool Assignment::is_compatible(const Assignment &other) { + for (unsigned int i = 0; i < this->size; i++) { + for (unsigned int j = 0; j < other.size; j++) { + if ((strncmp(this->labels[i], other.labels[j], MAX_VARIABLE_NAME_SIZE) == 0) && + (strncmp(this->values[i], other.values[j], HANDLE_HASH_SIZE) != 0)) { + return false; + } + } + } + return true; +} + +void Assignment::copy_from(const Assignment &other) { + this->size = other.size; + unsigned int num_bytes = this->size * sizeof(char *); + memcpy((void *) this->labels, (const void *) other.labels, num_bytes); + memcpy((void *) this->values, (const void *) other.values, num_bytes); +} + +void Assignment::add_assignments(const Assignment &other) { + bool already_contains; + for (unsigned int j = 0; j < other.size; j++) { + already_contains = false; + for (unsigned int i = 0; i < this->size; i++) { + if (strncmp(this->labels[i], other.labels[j], MAX_VARIABLE_NAME_SIZE) == 0) { + already_contains = true; + break; + } + } + if (! already_contains) { + this->labels[this->size] = other.labels[j]; + this->values[this->size] = other.values[j]; + this->size++; + } + } +} + +const char *Assignment::get(const char *label) { + for (unsigned int i = 0; i < this->size; i++) { + if (strncmp(label, this->labels[i], MAX_VARIABLE_NAME_SIZE) == 0) { + return this->values[i]; + } + } + return NULL; +} + +unsigned int Assignment::variable_count() { + return this->size; +} + +string Assignment::to_string() { + string answer = "{"; + for (unsigned int i = 0; i < this->size; i++) { + answer += "(" + string(this->labels[i]) + ": " + string(this->values[i]) + ")"; + if (i != (this->size - 1)) { + answer += ", "; + } + } + answer += "}"; + return answer; +} + +// ------------------------------------------------------------------------------------------------- +// QueryAnswer + + +QueryAnswer::QueryAnswer() : QueryAnswer(0.0) { +} + +QueryAnswer::QueryAnswer(double importance) { + this->importance = importance; + this->handles_size = 0; +} + +QueryAnswer::QueryAnswer(const char *handle, double importance) { + this->importance = importance; + this->handles[0] = handle; + this->handles_size = 1; +} + +QueryAnswer::~QueryAnswer() { +} + +void QueryAnswer::add_handle(const char *handle) { + this->handles[this->handles_size++] = handle; +} + +QueryAnswer *QueryAnswer::copy(QueryAnswer *base) { // Static method + QueryAnswer *copy = new QueryAnswer(base->importance); + copy->assignment.copy_from(base->assignment); + copy->handles_size = base->handles_size; + memcpy( + (void *) copy->handles, + (const void *) base->handles, + base->handles_size * sizeof(char *)); + return copy; +} + +bool QueryAnswer::merge(QueryAnswer *other, bool merge_handles) { + if (this->assignment.is_compatible(other->assignment)) { + this->assignment.add_assignments(other->assignment); + bool already_exist; + if (merge_handles) { + this->importance = fmax(this->importance, other->importance); + for (unsigned int j = 0; j < other->handles_size; j++) { + already_exist = false; + for (unsigned int i = 0; i < this->handles_size; i++) { + if (strncmp(this->handles[i], other->handles[j], HANDLE_HASH_SIZE) == 0) { + already_exist = true; + break; + } + } + if (! already_exist) { + this->handles[this->handles_size++] = other->handles[j]; + } + } + } + return true; + } else { + return false; + } +} + +string QueryAnswer::to_string() { + string answer = "QueryAnswer<" + std::to_string(this->handles_size) + ","; + answer += std::to_string(this->assignment.variable_count()) + "> ["; + for (unsigned int i = 0; i < this->handles_size; i++) { + answer += string(this->handles[i]); + if (i != (this->handles_size - 1)) { + answer += ", "; + } + } + answer += "] " + this->assignment.to_string() + " " + std::to_string(this->importance); + //answer += "] " + this->assignment.to_string(); + return answer; +} + +const string &QueryAnswer::tokenize() { + // char_count is computed to be slightly larger than actually required by assuming + // e.g. 3 digits to represent sizes + char importance_buffer[13]; + sprintf(importance_buffer, "%.10f", this->importance); + unsigned int char_count = + 13 // importance with 10 decimals + space + + 4 // (up to 3 digits) to represent this->handles_size + space + + this->handles_size * (HANDLE_HASH_SIZE + 1) // handles + spaces + + 4 // (up to 3 digits) to represent this->assignment.size + space + + this->assignment.size * (MAX_VARIABLE_NAME_SIZE + HANDLE_HASH_SIZE + 2); // labelhandle + + this->token_representation.clear(); + this->token_representation.reserve(char_count); + string space = " "; + this->token_representation += importance_buffer; + this->token_representation += space; + this->token_representation += std::to_string(this->handles_size); + this->token_representation += space; + for (unsigned int i = 0; i < this->handles_size; i++) { + this->token_representation += handles[i]; + this->token_representation += space; + } + this->token_representation += std::to_string(this->assignment.size); + this->token_representation += space; + for (unsigned int i = 0; i < this->assignment.size; i++) { + this->token_representation += this->assignment.labels[i]; + this->token_representation += space; + this->token_representation += this->assignment.values[i]; + this->token_representation += space; + } + + return this->token_representation; +} + +static inline void read_token( + const char *token_string, + unsigned int &cursor, + char *token, + unsigned int token_size) { + + unsigned int cursor_token = 0; + while (token_string[cursor] != ' ') { + if ((cursor_token == token_size) || (token_string[cursor] == '\0')) { + Utils::error("Invalid token string"); + } + token[cursor_token++] = token_string[cursor++]; + } + token[cursor_token] = '\0'; + cursor++; +} + +void QueryAnswer::untokenize(const string &tokens) { + + const char *token_string = tokens.c_str(); + char number[4]; + char importance[13]; + char handle[HANDLE_HASH_SIZE]; + char label[MAX_VARIABLE_NAME_SIZE]; + + unsigned int cursor = 0; + + read_token(token_string, cursor, importance, 13); + this->importance = std::stod(importance); + + read_token(token_string, cursor, number, 4); + this->handles_size = (unsigned int) std::stoi(number); + if (this->handles_size > MAX_NUMBER_OF_OPERATION_CLAUSES) { + Utils::error("Invalid handles_size: " + std::to_string(this->handles_size) + " untokenizing QueryAnswer"); + } + + for (unsigned int i = 0; i < this->handles_size; i++) { + read_token(token_string, cursor, handle, HANDLE_HASH_SIZE); + this->handles[i] = strdup(handle); + } + + read_token(token_string, cursor, number, 4); + this->assignment.size = (unsigned int) std::stoi(number); + + if (this->assignment.size > MAX_NUMBER_OF_VARIABLES_IN_QUERY) { + Utils::error("Invalid number of assignments: " + std::to_string(this->assignment.size) + " untokenizing QueryAnswer"); + } + + for (unsigned int i = 0; i < this->assignment.size; i++) { + read_token(token_string, cursor, label, MAX_VARIABLE_NAME_SIZE); + read_token(token_string, cursor, handle, HANDLE_HASH_SIZE); + this->assignment.labels[i] = strdup(label); + this->assignment.values[i] = strdup(handle); + } + + if (token_string[cursor] != '\0') { + Utils::error("Invalid token string - invalid text after QueryAnswer definition"); + } +} diff --git a/src/cpp/query_engine/QueryAnswer.h b/src/cpp/query_engine/QueryAnswer.h new file mode 100644 index 0000000..84a0ea9 --- /dev/null +++ b/src/cpp/query_engine/QueryAnswer.h @@ -0,0 +1,267 @@ +#ifndef _QUERY_ENGINE_QUERYANSWER_H +#define _QUERY_ENGINE_QUERYANSWER_H + +#include +#include "expression_hasher.h" + +// If any of these constants are set to numbers greater than 999, we need +// to fix QueryAnswer.tokenize() properly +#define MAX_VARIABLE_NAME_SIZE ((unsigned int) 100) +#define MAX_NUMBER_OF_VARIABLES_IN_QUERY ((unsigned int) 100) +#define MAX_NUMBER_OF_OPERATION_CLAUSES ((unsigned int) 100) + +using namespace std; + +namespace query_engine { + +/** + * This class is the representation of a set of variable assignments. It's a set because each + * variable can be assigned to exactly one value and the order of assignments is irrelevant. + * + * "label1" -> "value1" + * "label2" -> "value2" + * ... + * "labelN" -> "valueN" + */ +class Assignment { + + friend class QueryAnswer; + + public: + + /** + * Basic constructor. + */ + Assignment(); + + /** + * Destructor. + */ + ~Assignment(); + + /** + * Assign a value to a label. + * + * If the label have already an assigned value, assign() will check if the value is the + * same. If not, nothing is done and false is returned. If the value is the same, or + * if the label haven't been assigned yet, true is returned. + * + * @param label Label + * @param value Value to be assigned to the passed label. + * @return true iff the label have no value assigned to it or if the passed value is + * the same as the currently assigned value. + */ + bool assign(const char *label, const char *value); + + /** + * Returns the value assigned to a given label or NULL if no value is assigned to it. + * + * @param label Label to be search for. + * @return The value assigned to a given label or NULL if no value is assigned to it. + */ + const char *get(const char *label); + + /** + * Returns true if the passed Assignment is compatible with this one or false otherwise. + * + * For two Assignments to be considered compatible, all the labels they share must be + * assigned to the same value. Labels defined in only one of the Assignments aren't + * taken into account (so assignments with no common labels will always be compatible). + * + * Empty Assignments are compatible with any other Assignment. + */ + bool is_compatible(const Assignment &other); + + /** + * Shallow copy operation. No allocation of labels or values are performed. + * + * @param other Assignment to be copied from. + */ + void copy_from(const Assignment &other); + + /** + * Adds assignments from other Assignment by making a shallow copy of labels and values. + * + * Labels present in both Assignments are disregarded. So, for instance, if 'this' has: + * + * "label1"-> "value1" + * "label2"-> "value2" + * "label3"-> "value3" + * + * and 'other' has: + * + * "label1"-> "valueX" + * "label2"-> "value2" + * "label4"-> "value4" + * + * The result in 'this' after add_assignment() would be: + * + * "label1"-> "value1" + * "label2"-> "value2" + * "label3"-> "value3" + * "label4"-> "value4" + */ + void add_assignments(const Assignment &other); + + /** + * Returns the number of labels in this assignment. + * + * @return The number of labels in this assignment. + */ + unsigned int variable_count(); + + /** + * Returns a string representation of this Node (mainly for debugging; not optimized to + * production environment). + */ + string to_string(); + + private: + + const char *labels[MAX_NUMBER_OF_VARIABLES_IN_QUERY]; + const char *values[MAX_NUMBER_OF_VARIABLES_IN_QUERY]; + unsigned int size; +}; + +/** + * This is a candidate answer for a query. + * + * Objects of this class are moved through the flow of answers in the query tree. + * They have a set of handles, an Assignment and an attached importance value which + * is calculated using the importance of the elements which have been operated to + * make the answer. + * + * The set of handles represents Links that, together, represent a candidate answer + * to the query, under the constraints of the attached assignment of variables. For instance, + * suppose we have a query like: + * + * AND + * Inheritance + * A + * $v1 + * Inheritance + * $v1 + * B + * + * One possible candidate answer could be the pair of links: + * + * (Inheritance A S) and (Inheritance S B) + * + * with the attached assignment: + * + * $v1 -> S + */ +class QueryAnswer { + +public: + + /** + * Handles which are the constituents of this QueryAnswer. + */ + const char *handles[MAX_NUMBER_OF_OPERATION_CLAUSES]; + + /** + * Number of handles in this QueryAnswer. + */ + unsigned int handles_size; + + /** + * Estimated importance of this QueryAnswer based on the importance of its constituents. + */ + double importance; + + /** + * Underlying assignment of variables which led to this QueryAnswer. + */ + Assignment assignment; + + /** + * Constructor. + * + * @param handle First handle in this QueryAnswer. + * @param importance Estimated importance of this QueryAnswer. + */ + QueryAnswer(const char *handle, double importance); + + /** + * Constructor. + * + * @param importance Estimated importance of this QueryAnswer. + */ + QueryAnswer(double importance); + + /** + * Empty constructor. + */ + QueryAnswer(); + + /** + * Destructor. + */ + ~QueryAnswer(); + + /** + * Adds a handle to this QueryAnswer. + * + * @param handles Handle to be added to this QueryAnswer. + */ + void add_handle(const char *handle); + + /** + * Merges this QueryAnswer with the passed one. + * + * @param other QueryAnswer to be merged in this one. + * @param merge_handles A flag (defaulted to true) to indicate whether the handles should be + * merged (in addition to the assignments). + */ + bool merge(QueryAnswer *other, bool merge_handles = true); + + /** + * Make a shallow copy of the passed QueryAnswer. + * + * A new QueryAnswer object is allocated but the assignment and the handles are shallow-copied. + */ + static QueryAnswer *copy(QueryAnswer *base); + + /** + * Tokenizes the QueryAnswer in a single std::string object (tokens separated by spaces). + * + * The tokenized string looks like this: + * + * N H1 H2 ... HN M L1 V1 L2 V2 ... LM VM + * + * N is the number of handles in the QueryAnswer and M is the number of assignments. Hi are the + * handles and Li Vi are the assignments Li -> Vi + * + * @return A std::string with tokens separated by spaces which can be used to rebuild this QueryAnswer. + */ + const string& tokenize(); + + /** + * Rebuild a QueryAnswer baesd in a list of tokens given in a std::string with tokens separated by spaces. + * + * The tokenized string looks like this: + * + * N H1 H2 ... HN M L1 V1 L2 V2 ... LM VM + * + * N is the number of handles in the QueryAnswer and M is the number of assignments. Hi are the + * handles and Li Vi are the assignments Li -> Vi + * + * @param tokens A std::string with the list of tokens separated by spaces. + */ + void untokenize(const string &tokens); + + /** + * Returns a string representation of this Variable (mainly for debugging; not optimized to + * production environment). + */ + string to_string(); + +private: + + string token_representation; +}; + +} // namespace query_engine + +#endif // _QUERY_ENGINE_QUERYANSWER_H diff --git a/src/cpp/query_engine/QueryNode.cc b/src/cpp/query_engine/QueryNode.cc new file mode 100644 index 0000000..abd0c7d --- /dev/null +++ b/src/cpp/query_engine/QueryNode.cc @@ -0,0 +1,232 @@ +#include "QueryNode.h" +#include "LeadershipBroker.h" +#include "MessageBroker.h" +#include "Utils.h" + +using namespace query_node; +using namespace std; + +string QueryNode::QUERY_ANSWER_TOKENS_FLOW_COMMAND = "query_answer_tokens_flow"; +string QueryNode::QUERY_ANSWER_FLOW_COMMAND = "query_answer_flow"; +string QueryNode::QUERY_ANSWERS_FINISHED_COMMAND = "query_answers_finished"; + +// -------------------------------------------------------------------------------- +// Public methods + +QueryNode::QueryNode( + const string &node_id, + bool is_server, + MessageBrokerType messaging_backend) : + AtomSpaceNode(node_id, LeadershipBrokerType::SINGLE_MASTER_SERVER, messaging_backend) { + + this->is_server = is_server; + this->query_answer_processor = NULL; + this->query_answers_finished_flag = false; + this->shutdown_flag = false; + if (messaging_backend == MessageBrokerType::RAM) { + this->requires_serialization = false; + } else { + this->requires_serialization = true; + } +} + +QueryNode::~QueryNode() { +} + +void QueryNode::graceful_shutdown() { + if (is_shutting_down()) { + return; + } + AtomSpaceNode::graceful_shutdown(); + this->shutdown_flag_mutex.lock(); + this->shutdown_flag = true; + this->shutdown_flag_mutex.unlock(); + if (this->query_answer_processor != NULL) { + this->query_answer_processor->join(); + this->query_answer_processor = NULL; + } +} + +bool QueryNode::is_shutting_down() { + bool answer; + this->shutdown_flag_mutex.lock(); + answer = this->shutdown_flag; + this->shutdown_flag_mutex.unlock(); + return answer; +} + +void QueryNode::query_answers_finished() { + this->query_answers_finished_flag_mutex.lock(); + this->query_answers_finished_flag = true; + this->query_answers_finished_flag_mutex.unlock(); +} + +bool QueryNode::is_query_answers_finished() { + bool answer; + this->query_answers_finished_flag_mutex.lock(); + answer = this->query_answers_finished_flag; + this->query_answers_finished_flag_mutex.unlock(); + return answer; +} + +shared_ptr QueryNode::message_factory(string &command, vector &args) { + std::shared_ptr message = AtomSpaceNode::message_factory(command, args); + if (message) { + return message; + } + if (command == QueryNode::QUERY_ANSWER_FLOW_COMMAND) { + return std::shared_ptr(new QueryAnswerFlow(command, args)); + } else if (command == QueryNode::QUERY_ANSWER_TOKENS_FLOW_COMMAND) { + return std::shared_ptr(new QueryAnswerTokensFlow(command, args)); + } else if (command == QueryNode::QUERY_ANSWERS_FINISHED_COMMAND) { + return std::shared_ptr(new QueryAnswersFinished(command, args)); + } + return std::shared_ptr{}; +} + +void QueryNode::add_query_answer(QueryAnswer *query_answer) { + if (is_query_answers_finished()) { + Utils::error("Invalid addition of new query answer."); + } else { + this->query_answer_queue.enqueue((void *) query_answer); + } +} + +QueryAnswer *QueryNode::pop_query_answer() { + return (QueryAnswer *) this->query_answer_queue.dequeue(); +} + +bool QueryNode::is_query_answers_empty() { + return this->query_answer_queue.empty(); +} + +QueryNodeServer::QueryNodeServer( + const string &node_id, + MessageBrokerType messaging_backend) : + QueryNode(node_id, true, messaging_backend) { + + this->join_network(); + this->query_answer_processor = new thread( + &QueryNodeServer::query_answer_processor_method, + this); +} + +QueryNodeServer::~QueryNodeServer() { + graceful_shutdown(); + if (this->query_answer_processor != NULL) { + this->query_answer_processor->join(); + this->query_answer_processor = NULL; + } +} + +void QueryNodeServer::node_joined_network(const string &node_id) { + this->add_peer(node_id); +} + +string QueryNodeServer::cast_leadership_vote() { + return this->node_id(); +} + +void QueryNodeServer::query_answer_processor_method() { + while (! is_shutting_down()) { + Utils::sleep(); + } +} + +void QueryNodeClient::query_answer_processor_method() { + QueryAnswer *query_answer; + vector args; + bool answers_finished_flag = false; + while (! is_shutting_down()) { + while ((query_answer = (QueryAnswer *) this->query_answer_queue.dequeue()) != NULL) { + if (this->requires_serialization) { + string tokens = query_answer->tokenize(); + args.push_back(tokens); + } else { + args.push_back(to_string((unsigned long) query_answer)); + } + } + if (args.empty()) { + // The order of the AND clauses below matters + if (! answers_finished_flag && is_query_answers_finished() && this->query_answer_queue.empty()) { + this->send(QueryNode::QUERY_ANSWERS_FINISHED_COMMAND, args, this->server_id); + answers_finished_flag = true; + } + } else { + if (this->requires_serialization) { + this->send(QueryNode::QUERY_ANSWER_TOKENS_FLOW_COMMAND, args, this->server_id); + } else { + this->send(QueryNode::QUERY_ANSWER_FLOW_COMMAND, args, this->server_id); + } + args.clear(); + } + Utils::sleep(); + } +} + +QueryNodeClient::QueryNodeClient( + const string &node_id, + const string &server_id, + MessageBrokerType messaging_backend) : + QueryNode(node_id, true, messaging_backend) { + + this->query_answer_processor = new thread( + &QueryNodeClient::query_answer_processor_method, + this); + this->server_id = server_id; + this->add_peer(server_id); + this->join_network(); +} + +QueryNodeClient::~QueryNodeClient() { + graceful_shutdown(); + if (this->query_answer_processor != NULL) { + this->query_answer_processor->join(); + this->query_answer_processor = NULL; + } +} + +void QueryNodeClient::node_joined_network(const string &node_id) { + // do nothing +} + +string QueryNodeClient::cast_leadership_vote() { + return this->server_id; +} + +QueryAnswerFlow::QueryAnswerFlow(string command, vector &args) { + for (auto pointer_string: args) { + QueryAnswer *query_answer = (QueryAnswer *) stoul(pointer_string); + this->query_answers.push_back(query_answer); + } +} + +void QueryAnswerFlow::act(shared_ptr node) { + auto query_node = dynamic_pointer_cast(node); + for (auto query_answer: this->query_answers) { + query_node->add_query_answer(query_answer); + } +} + +QueryAnswerTokensFlow::QueryAnswerTokensFlow(string command, vector &args) { + for (auto tokens: args) { + this->query_answers_tokens.push_back(tokens); + } +} + +void QueryAnswerTokensFlow::act(shared_ptr node) { + auto query_node = dynamic_pointer_cast(node); + for (auto tokens: this->query_answers_tokens) { + QueryAnswer *query_answer = new QueryAnswer(); + query_answer->untokenize(tokens); + query_node->add_query_answer(query_answer); + } +} + +QueryAnswersFinished::QueryAnswersFinished(string command, vector &args) { +} + +void QueryAnswersFinished::act(shared_ptr node) { + auto query_node = dynamic_pointer_cast(node); + query_node->query_answers_finished(); +} diff --git a/src/cpp/query_engine/QueryNode.h b/src/cpp/query_engine/QueryNode.h new file mode 100644 index 0000000..98c63f9 --- /dev/null +++ b/src/cpp/query_engine/QueryNode.h @@ -0,0 +1,124 @@ +#ifndef _QUERY_NODE_QUERYNODE_H +#define _QUERY_NODE_QUERYNODE_H + +#include +#include +#include "AtomSpaceNode.h" +#include "SharedQueue.h" +#include "QueryAnswer.h" + +using namespace std; +using namespace atom_space_node; +using namespace query_engine; + +namespace query_node { + +/** + * + */ +class QueryNode : public AtomSpaceNode { + +public: + + QueryNode( + const string &node_id, + bool is_server, + MessageBrokerType messaging_backend = MessageBrokerType::RAM); + virtual ~QueryNode(); + virtual shared_ptr message_factory(string &command, vector &args); + virtual void graceful_shutdown(); + bool is_shutting_down(); + void query_answers_finished(); + bool is_query_answers_finished(); + void add_query_answer(QueryAnswer *query_answer); + QueryAnswer *pop_query_answer(); + bool is_query_answers_empty(); + virtual void query_answer_processor_method() = 0; + + static string QUERY_ANSWER_FLOW_COMMAND; + static string QUERY_ANSWER_TOKENS_FLOW_COMMAND; + static string QUERY_ANSWERS_FINISHED_COMMAND; + +protected: + + SharedQueue query_answer_queue; + thread *query_answer_processor; + bool requires_serialization; + +private: + + bool is_server; + bool shutdown_flag; + mutex shutdown_flag_mutex; + bool query_answers_finished_flag; + mutex query_answers_finished_flag_mutex; +}; + +class QueryNodeServer : public QueryNode { + +public: + + QueryNodeServer( + const string &node_id, + MessageBrokerType messaging_backend = MessageBrokerType::RAM); + virtual ~QueryNodeServer(); + + void node_joined_network(const string &node_id); + string cast_leadership_vote(); + void query_answer_processor_method(); +}; + +class QueryNodeClient : public QueryNode { + +public: + + QueryNodeClient( + const string &node_id, + const string &server_id, + MessageBrokerType messaging_backend = MessageBrokerType::RAM); + virtual ~QueryNodeClient(); + + void node_joined_network(const string &node_id); + string cast_leadership_vote(); + void query_answer_processor_method(); + +private: + + string server_id; +}; + +class QueryAnswerFlow : public Message { + +public: + + QueryAnswerFlow(string command, vector &args); + void act(shared_ptr node); + +private: + + vector query_answers; +}; + +class QueryAnswerTokensFlow : public Message { + +public: + + QueryAnswerTokensFlow(string command, vector &args); + void act(shared_ptr node); + +private: + + vector query_answers_tokens; +}; + +class QueryAnswersFinished : public Message { + +public: + + QueryAnswersFinished(string command, vector &args); + void act(shared_ptr node); +}; + +} // namespace query_node + +#endif // _QUERY_NODE_QUERYNODE_H diff --git a/src/cpp/query_engine/StarNode.cc b/src/cpp/query_engine/StarNode.cc new file mode 100644 index 0000000..9d1483d --- /dev/null +++ b/src/cpp/query_engine/StarNode.cc @@ -0,0 +1,50 @@ +#include "StarNode.h" +#include "LeadershipBroker.h" +#include "MessageBroker.h" + +using namespace atom_space_node; +using namespace std; + +// ------------------------------------------------------------------------------------------------- +// Constructors and destructors + +StarNode::StarNode( + const string &node_id, + MessageBrokerType messaging_backend) : + AtomSpaceNode(node_id, LeadershipBrokerType::SINGLE_MASTER_SERVER, messaging_backend) { + + this->is_server = true; + this->join_network(); +} + +StarNode::StarNode( + const string &node_id, + const string &server_id, + MessageBrokerType messaging_backend) : + AtomSpaceNode(node_id, LeadershipBrokerType::SINGLE_MASTER_SERVER, messaging_backend) { + + this->server_id = server_id; + this->is_server = false; + this->add_peer(server_id); + this->join_network(); +} + +StarNode::~StarNode() { +} + +// ------------------------------------------------------------------------------------------------- +// DistributedAlgorithmNode virtual API + +void StarNode::node_joined_network(const string &node_id) { + if (this->is_server) { + this->add_peer(node_id); + } +} + +string StarNode::cast_leadership_vote() { + if (this->is_server) { + return this->node_id(); + } else { + return this->server_id; + } +} diff --git a/src/cpp/query_engine/StarNode.h b/src/cpp/query_engine/StarNode.h new file mode 100644 index 0000000..65311ba --- /dev/null +++ b/src/cpp/query_engine/StarNode.h @@ -0,0 +1,79 @@ +#ifndef _QUERY_NODE_STARNODE_H +#define _QUERY_NODE_STARNODE_H + +#include +#include "AtomSpaceNode.h" + +using namespace std; + +namespace atom_space_node { + +/** + * Node in a "star" topology with one single server (which knows every other nodes in the network) + * and N nodes (which know only the server). + * + * Use the different constructors to choose from client or server. + */ +class StarNode : public AtomSpaceNode { + +public: + + // -------------------------------------------------------------------------------------------- + // Constructors and destructors + + /** + * Server constructor. + * + * @param node_id ID of this node in the network. + * @param messaging_backend Type of network communication which will be used by the nodes. + * in the network to exchange messages. Defaulted to GRPC. + */ + StarNode( + const string &node_id, + MessageBrokerType messaging_backend = MessageBrokerType::GRPC); + + /** + * Client constructor. + * + * @param node_id ID of this node in the network. + * @param server_id ID of the server node. + * @param messaging_backend Type of network communication which will be used by the nodes + * in the network to exchange messages. Defaulted to GRPC. + */ + StarNode( + const string &node_id, + const string &server_id, + MessageBrokerType messaging_backend = MessageBrokerType::GRPC); + + /** + * Destructor + */ + virtual ~StarNode(); + + // -------------------------------------------------------------------------------------------- + // AtomSpaceNode virtual API + + /** + * Method called when a new node is inserted in the network after this one has already joined. + * Server nodes will keep track of all newly inserted nodes. Client nodes disregard the info. + * + * @param node_id ID of the newly inserted node. + */ + void node_joined_network(const string &node_id); + + /** + * Method called when a leadershipo election is requested. + * + * Server nodes votes in themselves for leader while client node votes in their server. + */ + string cast_leadership_vote(); + +protected: + + bool is_server; + string server_id; +}; + +} // namespace atom_space_node + +#endif // _QUERY_NODE_STARNODE_H diff --git a/src/cpp/query_engine/query_element/And.h b/src/cpp/query_engine/query_element/And.h new file mode 100644 index 0000000..2447713 --- /dev/null +++ b/src/cpp/query_engine/query_element/And.h @@ -0,0 +1,303 @@ +#ifndef _QUERY_ELEMENT_AND_H +#define _QUERY_ELEMENT_AND_H + +#include +#include +#include "Operator.h" + +using namespace std; + +namespace query_element { + + +/** + * QueryElement representing an AND logic operator. + * + * And operates on N clauses. Each clause can be either a Source or another Operator. + */ +template +class And : public Operator { + +public: + + // -------------------------------------------------------------------------------------------- + // Constructors and destructors + + /** + * Constructor. + * + * @param clauses Array with N clauses (each clause is supposed to be a Source or an Operator). + */ + And(QueryElement **clauses) : Operator(clauses) { + initialize(clauses); + } + + /** + * Constructor. + * + * @param clauses Array with N clauses (each clause is supposed to be a Source or an Operator). + */ + And(const array &clauses) : Operator(clauses) { + initialize((QueryElement **) clauses.data()); + } + + /** + * Destructor. + */ + ~And() { + graceful_shutdown(); + } + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + virtual void setup_buffers() { + Operator::setup_buffers(); + this->operator_thread = new thread(&And::and_operator_method, this); + } + + virtual void graceful_shutdown() { + Operator::graceful_shutdown(); + if (this->operator_thread != NULL) { + this->operator_thread->join(); + this->operator_thread = NULL; + } + } + + // -------------------------------------------------------------------------------------------- + // Private stuff + +private: + + class CandidateRecord { + public: + QueryAnswer *answer[N]; + unsigned int index[N]; + double fitness; + CandidateRecord() { + } + CandidateRecord(const CandidateRecord &other) { + this->fitness = other.fitness; + memcpy((void *) this->index, (const void *) other.index, N * sizeof(unsigned int)); + memcpy( + (void *) this->answer, + (const void *) other.answer, + N * sizeof(QueryAnswer *)); + } + CandidateRecord& operator=(const CandidateRecord &other) { + this->fitness = other.fitness; + memcpy((void *) this->index, (const void *) other.index, N * sizeof(unsigned int)); + memcpy( + (void *) this->answer, + (const void *) other.answer, + N * sizeof(QueryAnswer *)); + return *this; + } + bool operator<(const CandidateRecord &other) const { + return this->fitness < other.fitness; + } + bool operator>(const CandidateRecord &other) const { + return this->fitness > other.fitness; + } + bool operator==(const CandidateRecord &other) const { + for (unsigned int i = 0; i < N; i++) { + if (this->index[i] != other.index[i]) { + return false; + } + } + return true; + } + }; + + struct hash_function { + size_t operator()(const CandidateRecord& record) const { + size_t hash = record.index[0]; + size_t power = 1; + for (unsigned int i = 1; i < N; i++) { + power *= N; + hash += record.index[i] * power; + } + return hash; + } + }; + + vector query_answer[N]; + unsigned int next_input_to_process[N]; + priority_queue border; + unordered_set visited; + bool all_answers_arrived[N]; + bool no_more_answers_to_arrive; + thread *operator_thread; + + void initialize(QueryElement **clauses) { + this->operator_thread = NULL; + for (unsigned int i = 0; i < N; i++) { + this->next_input_to_process[i] = 0; + this->all_answers_arrived[i] = false; + } + this->no_more_answers_to_arrive = false; + this->id = "And("; + for (unsigned int i = 0; i < N; i++) { + this->id += clauses[i]->id; + if (i != (N - 1)) { + this->id += ", "; + } + } + this->id += ")"; + } + + bool ready_to_process_candidate() { + for (unsigned int i = 0; i < N; i++) { + if ((! this->all_answers_arrived[i]) && + (this->query_answer[i].size() <= (this->next_input_to_process[i] + 1))) { + return false; + } + } + return true; + } + + void ingest_newly_arrived_answers() { + if (this->no_more_answers_to_arrive) { + return; + } + QueryAnswer *answer; + unsigned int all_arrived_count = 0; + bool no_new_answer = true; + for (unsigned int i = 0; i < N; i++) { + while ((answer = this->input_buffer[i]->pop_query_answer()) != NULL) { + no_new_answer = false; + this->query_answer[i].push_back(answer); + } + if (this->input_buffer[i]->is_query_answers_empty() && + this->input_buffer[i]->is_query_answers_finished()) { + + this->all_answers_arrived[i] = true; + all_arrived_count++; + } + } + if (all_arrived_count == N) { + this->no_more_answers_to_arrive = true; + } else { + if (no_new_answer) { + Utils::sleep(); + } + } + } + + void operate_candidate(const CandidateRecord &candidate) { + QueryAnswer *new_query_answer = QueryAnswer::copy(candidate.answer[0]); + for (unsigned int i = 1; i < N; i++) { + if (! new_query_answer->merge(candidate.answer[i])) { + delete new_query_answer; + return; + } + } + this->output_buffer->add_query_answer(new_query_answer); + } + + bool processed_all_input() { + if (this->border.size() > 0) { + return false; + } else { + for (unsigned int i = 0; i < N; i++) { + if (this->next_input_to_process[i] < this->query_answer[i].size()) { + return false; + } + } + } + return true; + } + + void expand_border(const CandidateRecord &last_used_candidate) { + CandidateRecord candidate; + unsigned int index_in_queue; + bool abort_candidate; + for (unsigned int new_candidate_count = 0; new_candidate_count < N; new_candidate_count++) { + abort_candidate = false; + candidate.fitness = 1.0; + for (unsigned int answer_queue_index = 0; answer_queue_index < N; answer_queue_index++) { + index_in_queue = last_used_candidate.index[answer_queue_index]; + if (answer_queue_index == new_candidate_count) { + index_in_queue++; + if (index_in_queue < this->query_answer[answer_queue_index].size()) { + if (index_in_queue == this->next_input_to_process[answer_queue_index]) { + this->next_input_to_process[answer_queue_index]++; + } + } else { + abort_candidate = true; + break; + } + } + candidate.answer[answer_queue_index] = + this->query_answer[answer_queue_index][index_in_queue]; + candidate.index[answer_queue_index] = index_in_queue; + candidate.fitness *= candidate.answer[answer_queue_index]->importance; + } + if (abort_candidate) { + continue; + } + if (visited.find(candidate) == visited.end()) { + this->border.push(candidate); + this->visited.insert(candidate); + } + } + } + + void and_operator_method() { + + do { + if (QueryElement::is_flow_finished() || + this->output_buffer->is_query_answers_finished()) { + + return; + } + + do { + if (QueryElement::is_flow_finished()) { + return; + } + ingest_newly_arrived_answers(); + } while (! ready_to_process_candidate()); + + if (processed_all_input()) { + bool all_finished_flag = true; + for (unsigned int i = 0; i < N; i++) { + if (! this->input_buffer[i]->is_query_answers_finished()) { + all_finished_flag = false; + break; + } + } + if (all_finished_flag && + ! this->output_buffer->is_query_answers_finished() && + // processed_all_input() is double-checked on purpose to avoid race condition + processed_all_input()) { + this->output_buffer->query_answers_finished(); + } + Utils::sleep(); + continue; + } + + if (this->border.size() == 0) { + CandidateRecord candidate; + double fitness = 1.0; + for (unsigned int i = 0; i < N; i++) { + candidate.answer[i] = this->query_answer[i][this->next_input_to_process[i]], + candidate.index[i] = this->next_input_to_process[i]; + this->next_input_to_process[i]++; + fitness *= candidate.answer[i]->importance; + } + candidate.fitness = fitness; + this->border.push(candidate); + this->visited.insert(candidate); + } + CandidateRecord candidate = this->border.top(); + operate_candidate(candidate); + expand_border(candidate); + this->border.pop(); + } while (true); + } +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_AND_H diff --git a/src/cpp/query_engine/query_element/Iterator.cc b/src/cpp/query_engine/query_element/Iterator.cc new file mode 100644 index 0000000..4f4f17d --- /dev/null +++ b/src/cpp/query_engine/query_element/Iterator.cc @@ -0,0 +1,26 @@ +#include "Iterator.h" + +using namespace query_element; + +// ------------------------------------------------------------------------------------------------- +// Public methods + +Iterator::Iterator( + QueryElement *precedent, + bool delete_precedent_on_destructor) : + Sink(precedent, "Iterator(" + precedent->id + ")", delete_precedent_on_destructor) { +} + +Iterator::~Iterator() { +} + +bool Iterator::finished() { + // The order of the AND clauses below matters + return ( + this->input_buffer->is_query_answers_finished() && + this->input_buffer->is_query_answers_empty()); +} + +QueryAnswer *Iterator::pop() { + return (QueryAnswer *) this->input_buffer->pop_query_answer(); +} diff --git a/src/cpp/query_engine/query_element/Iterator.h b/src/cpp/query_engine/query_element/Iterator.h new file mode 100644 index 0000000..d8ca3e9 --- /dev/null +++ b/src/cpp/query_engine/query_element/Iterator.h @@ -0,0 +1,58 @@ +#ifndef _QUERY_ELEMENT_ITERATOR_H +#define _QUERY_ELEMENT_ITERATOR_H + +#include "Sink.h" +#include "QueryAnswer.h" + +using namespace std; +using namespace query_engine; + +namespace query_element { + +/** + * Concrete Sink that provides an iterator API to give access to the query answers. + * + * NB This is not a std::iterator as the behavior we'd expect of a std::iterator + * doesn't fit well with the asynchronous nature of QueryElement processing. + * + * Instead, this class provides only two methods: one to pop and return the next + * query answers and another to check if more answers can still be expected. + * + */ +class Iterator : public Sink { + +public: + + /** + * Constructor expects that the QueryElement below in the tree is already constructed. + */ + Iterator(QueryElement *precedent, bool delete_precedent_on_destructor = false); + ~Iterator(); + + // -------------------------------------------------------------------------------------------- + // Public Iterator API + + /** + * Return true when all query answers has been processed AND all the query answers + * that reached this QueryElement has been pop'ed out using the method pop(). + * + * @return true iff all query answers has been processed AND all the query answers + * that reached this QueryElement has been pop'ed out using the method pop(). + */ + bool finished(); + + /** + * Return the next query answer or NULL if none are currently available. + * + * NB a NULL return DOESN'T mean that the query answers are over. It means that there + * are no query answers available now. Because of the asynchronous nature of QueryElement + * processing, more query answers can arrive later. + * + * @return the next query answer or NULL if none are currently available. + */ + QueryAnswer *pop(); +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_ITERATOR_H diff --git a/src/cpp/query_engine/query_element/LinkTemplate.h b/src/cpp/query_engine/query_element/LinkTemplate.h new file mode 100644 index 0000000..c650338 --- /dev/null +++ b/src/cpp/query_engine/query_element/LinkTemplate.h @@ -0,0 +1,497 @@ +#ifndef _QUERY_ELEMENT_LINKTEMPLATE_H +#define _QUERY_ELEMENT_LINKTEMPLATE_H + +#include + +#include "QueryNode.h" +#include "Source.h" +#include "Iterator.h" +#include "And.h" +#include "Terminal.h" +#include "AtomDBSingleton.h" +#include "AtomDBAPITypes.h" +#include "QueryAnswer.h" +#include "expression_hasher.h" +#include "SharedQueue.h" + +#include "AttentionBrokerServer.h" +#include "attention_broker.grpc.pb.h" +#include +#include "attention_broker.pb.h" + +#define MAX_GET_IMPORTANCE_BUNDLE_SIZE ((unsigned int) 100000) + +using namespace std; +using namespace query_engine; +using namespace attention_broker_server; + +namespace query_element { + +/** + * Concrete Source that searches for a pattern in the AtomDB and feeds the QueryElement up in the + * query tree with the resulting links. + * + * A pattern is something like: + * + * Similarity + * Human + * $v1 + * + * In the example, any links of type "Similarity" pointing to Human as the first target would be + * returned. These returned links are then fed into the subsequent QueryElement in the tree. + * + * LinkTemplate query the AtomDB for the links that match the pattern. In addition to this, it + * attaches values for any variables in the pattern and sorts all the AtomDB answers by importance + * (by querying the AttentionBroker) before following up the links (most important ones first). + * + * An arbitrary number of nested levels are allowed. For instance: + * + * Expression + * Symbol A + * Symbol B + * $v1 + * Expression + * Symbol C + * $v2 + * Expression + * $v1 + * $v2 + * Expression + * Symbol X + * Symbol Y + * Symbol Z + * + * Returned links are guaranteed to satisfy all variable settings properly. + */ +template +class LinkTemplate : public Source { + +public: + + // -------------------------------------------------------------------------------------------- + // Constructors and destructors + + /** + * Constructor expects an array of QueryElements which can be Terminals or nested LinkTemplate. + * + * @param type Link type or WILDCARD to indicate that the link type doesn't matter. + * @param targets An array with targets which can each be a Terminal or a nested LinkTemplate. + * @param context An optional string defining the context used by the AttentionBroker to + * consider STI (short term importance). + */ + LinkTemplate( + const string &type, + const array &targets, + const string &context = "") { + + this->context = context; + this->arity = ARITY; + this->type = type; + this->target_template = targets; + this->fetch_finished = false; + this->atom_document = NULL; + this->local_answers = NULL; + this->local_answers_size = 0; + this->local_buffer_processor = NULL; + bool wildcard_flag = (type == AtomDB::WILDCARD); + this->handle_keys[0] = (wildcard_flag ? + (char *) AtomDB::WILDCARD.c_str() : + named_type_hash((char *) type.c_str())); + for (unsigned int i = 1; i <= ARITY; i++) { + // It's safe to get stored shared_ptr's raw pointer here because handle_keys[] + // is used solely in this scope so it's guaranteed that handle will not be freed. + if (targets[i - 1]->is_terminal) { + this->handle_keys[i] = ((Terminal *) targets[i - 1])->handle.get(); + } else { + this->handle_keys[i] = (char *) AtomDB::WILDCARD.c_str(); + this->inner_template.push_back(targets[i - 1]); + } + } + this->handle = shared_ptr(composite_hash(this->handle_keys, ARITY + 1)); + if (! wildcard_flag) { + free(this->handle_keys[0]); + } + // This is correct. id is not necessarily a handle but an identifier. It just happens + // that we want the string for this identifier to be the same as the string representing + // the handle. + this->id = this->handle.get() + std::to_string(LinkTemplate::next_instance_count()); + } + + /** + * Destructor. + */ + virtual ~LinkTemplate() { +#ifdef DEBUG + cout << "LinkTemplate::LinkTemplate() DESTRUCTOR BEGIN" << endl; +#endif + graceful_shutdown(); + local_answers_mutex.lock(); + if (local_answers_size > 0) { + delete [] this->atom_document; + delete [] this->local_answers; + delete [] this->next_inner_answer; + } + local_answers_mutex.unlock(); +#ifdef DEBUG + cout << "LinkTemplate::LinkTemplate() DESTRUCTOR END" << endl; +#endif + } + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + /** + * Gracefully shuts down this QueryElement's processor thread. + */ + virtual void graceful_shutdown() { +#ifdef DEBUG + cout << "LinkTemplate::graceful_shutdown() BEGIN" << endl; +#endif + set_flow_finished(); + if (this->local_buffer_processor != NULL) { + this->local_buffer_processor->join(); + this->local_buffer_processor = NULL; + } + Source::graceful_shutdown(); +#ifdef DEBUG + cout << "LinkTemplate::graceful_shutdown() END" << endl; +#endif + } + + virtual void setup_buffers() { +#ifdef DEBUG + cout << "LinkTemplate::setup_buffers() BEGIN" << endl; +#endif + Source::setup_buffers(); + if (this->inner_template.size() > 0) { + switch(this->inner_template.size()) { + case 1: { + this->inner_template_iterator = shared_ptr(new Iterator( + inner_template[0] + )); + break; + } + case 2: { + this->inner_template_iterator = shared_ptr(new Iterator( + new And<2>({ + inner_template[0], + inner_template[1] + }), + true + )); + break; + } + case 3: { + this->inner_template_iterator = shared_ptr(new Iterator( + new And<3>({ + inner_template[0], + inner_template[1], + inner_template[2] + }), + true + )); + break; + } + case 4: { + this->inner_template_iterator = shared_ptr(new Iterator( + new And<4>({ + inner_template[0], + inner_template[1], + inner_template[2], + inner_template[3] + }), + true + )); + break; + } + default: { + Utils::error("Invalid number of inner templates (> 4) in link template."); + } + } + } + this->local_buffer_processor = new thread( + &LinkTemplate::local_buffer_processor_method, + this); + fetch_links(); +#ifdef DEBUG + cout << "LinkTemplate::setup_buffers() END" << endl; +#endif + } + +private: + + struct less_than_query_answer { + inline bool operator() (const QueryAnswer *qa1, const QueryAnswer *qa2) { + // Reversed check as we want descending sort + return (qa1->importance > qa2->importance); + } + }; + + // -------------------------------------------------------------------------------------------- + // Private methods + + void increment_local_answers_size() { + local_answers_mutex.lock(); + this->local_answers_size++; + local_answers_mutex.unlock(); + } + + unsigned int get_local_answers_size() { + unsigned int answer; + local_answers_mutex.lock(); + answer = this->local_answers_size; + local_answers_mutex.unlock(); + return answer; + } + + void get_importance( + const dasproto::HandleList &handle_list, + dasproto::ImportanceList &importance_list) { + + auto stub = dasproto::AttentionBroker::NewStub(grpc::CreateChannel( + this->attention_broker_address, + grpc::InsecureChannelCredentials())); + + if (handle_list.list_size() <= MAX_GET_IMPORTANCE_BUNDLE_SIZE) { + stub->get_importance(new grpc::ClientContext(), handle_list, &importance_list); + return; + } + +#ifdef DEBUG + cout << "get_importance() paginating" << endl; + unsigned int page_count = 1; +#endif + + dasproto::HandleList small_handle_list; + dasproto::ImportanceList small_importance_list; + unsigned int remaining = handle_list.list_size(); + unsigned int cursor = 0; + while (remaining > 0) { +#ifdef DEBUG + cout << "get_importance() page: " << page_count++ << endl; +#endif + for (unsigned int i = 0; i < MAX_GET_IMPORTANCE_BUNDLE_SIZE; i++) { + if (cursor == handle_list.list_size()) { + break; + } + small_handle_list.add_list(handle_list.list(cursor++)); + remaining--; + } +#ifdef DEBUG + cout << "discharging: " << small_handle_list.list_size() << endl; +#endif + stub->get_importance(new grpc::ClientContext(), small_handle_list, &small_importance_list); + for (unsigned int i = 0; i < small_importance_list.list_size(); i++) { + importance_list.add_list(small_importance_list.list(i)); + } + small_handle_list.clear_list(); + small_importance_list.clear_list(); + } + } + + void fetch_links() { +#ifdef DEBUG + cout << "fetch_links() BEGIN" << endl; + cout << "fetch_links() Pattern handle: " << this->handle << endl; +#endif + shared_ptr db = AtomDBSingleton::get_instance(); + this->fetch_result = db->query_for_pattern(this->handle); + unsigned int answer_count = this->fetch_result->size(); +#ifdef DEBUG + cout << "fetch_links() ac: " << answer_count << endl; +#endif + QueryAnswer *query_answer; + vector fetched_answers; + if (answer_count > 0) { + dasproto::HandleList handle_list; + handle_list.set_context(this->context); + for (unsigned int i = 0; i < answer_count; i++) { + handle_list.add_list(this->fetch_result->get_handle(i)); + } + dasproto::ImportanceList importance_list; + get_importance(handle_list, importance_list); + if (importance_list.list_size() != answer_count) { + Utils::error("Invalid AttentionBroker answer. Size: " + + std::to_string(importance_list.list_size()) + + " Expected size: " + std::to_string(answer_count)); + } + this->atom_document = new shared_ptr[answer_count]; + this->local_answers = new QueryAnswer *[answer_count]; + this->next_inner_answer = new unsigned int[answer_count]; + for (unsigned int i = 0; i < answer_count; i++) { + this->atom_document[i] = db->get_atom_document(this->fetch_result->get_handle(i)); + query_answer = new QueryAnswer(this->fetch_result->get_handle(i), importance_list.list(i)); + const char *s = this->atom_document[i]->get("targets", 0); + for (unsigned int j = 0; j < this->arity; j++) { + if (this->target_template[j]->is_terminal) { + Terminal *terminal = (Terminal *) this->target_template[j]; + if (terminal->is_variable) { + if (! query_answer->assignment.assign( + terminal->name.c_str(), + this->atom_document[i]->get("targets", j))) { + Utils::error( + "Error assigning variable: " + + terminal->name + + " a value: " + + string(this->atom_document[i]->get("targets", j))); + } + } + } + } + fetched_answers.push_back(query_answer); + } + std::sort(fetched_answers.begin(), fetched_answers.end(), less_than_query_answer()); + for (unsigned int i = 0; i < answer_count; i++) { + if (this->inner_template.size() == 0) { + this->local_buffer.enqueue((void *) fetched_answers[i]); + } else { + this->local_answers[i] = fetched_answers[i]; + this->next_inner_answer[i] = 0; + this->increment_local_answers_size(); + } + } + if (this->inner_template.size() == 0) { + set_flow_finished(); + } + } else { + set_flow_finished(); + } +#ifdef DEBUG + cout << "fetch_links() END" << endl; +#endif + } + + bool is_feasible(unsigned int index) { + unsigned int inner_answers_size = inner_answers.size(); + unsigned int cursor = this->next_inner_answer[index]; + while (cursor < inner_answers_size) { + if (this->inner_answers[cursor] != NULL) { + bool passed_first_check = true; + unsigned int arity = this->atom_document[index]->get_size("targets"); + unsigned int target_cursor = 0; + for (unsigned int i = 0; i < arity; i++) { + // Note to reviewer: pointer comparison is correct here + if (this->handle_keys[i + 1] == (char *) AtomDB::WILDCARD.c_str()) { + if (target_cursor > this->inner_answers[cursor]->handles_size) { + Utils::error("Invalid query answer in inner link template match"); + } + if (strncmp( + this->atom_document[index]->get("targets", i), + this->inner_answers[cursor]->handles[target_cursor++], + HANDLE_HASH_SIZE)) { + + passed_first_check = false; + break; + } + } + } + if (passed_first_check && + this->local_answers[index]->merge(this->inner_answers[cursor], false)) { + + this->inner_answers[cursor] = NULL; + return true; + } + } + this->next_inner_answer[index]++; + cursor++; + } + return false; + } + + bool ingest_newly_arrived_answers() { + bool flag = false; + QueryAnswer *query_answer; + while ((query_answer = this->inner_template_iterator->pop()) != NULL) { + this->inner_answers.push_back(query_answer); + flag = true; + } + return flag; + } + + void local_buffer_processor_method() { + if (this->inner_template.size() == 0) { + while (! (this->is_flow_finished() && this->local_buffer.empty())) { + QueryAnswer *query_answer; + while ((query_answer = (QueryAnswer *) this->local_buffer.dequeue()) != NULL) { + this->output_buffer->add_query_answer(query_answer); + } + Utils::sleep(); + } + } else { + while (! this->is_flow_finished()) { + unsigned int size = get_local_answers_size(); + if (ingest_newly_arrived_answers()) { + for (unsigned int i = 0; i < size; i++) { + if (this->local_answers[i] != NULL) { + if (is_feasible(i)) { + this->output_buffer->add_query_answer(this->local_answers[i]); + this->local_answers[i] = NULL; + } else { + if (this->inner_template_iterator->finished()) { + this->local_answers[i] = NULL; + } + } + } + } + } else { + if (this->inner_template_iterator->finished()) { + for (unsigned int i = 0; i < size; i++) { + if (this->local_answers[i] != NULL) { + if (is_feasible(i)) { + this->output_buffer->add_query_answer(this->local_answers[i]); + } + this->local_answers[i] = NULL; + } + } + } else { + Utils::sleep(); + } + } + bool finished_flag = true; + for (unsigned int i = 0; i < size; i++) { + if (this->local_answers[i] != NULL) { + finished_flag = false; + break; + } + } + if (this->inner_template_iterator->finished()) { + set_flow_finished(); + } + } + } + this->output_buffer->query_answers_finished(); + } + + static unsigned int next_instance_count() { + static unsigned int instance_count = 0; + return instance_count++; + } + +private: + + string type; + array target_template; + unsigned int arity; + shared_ptr handle; + char *handle_keys[ARITY + 1]; + shared_ptr fetch_result; + vector> atom_documents; + vector inner_template; + SharedQueue local_buffer; + thread *local_buffer_processor; + bool fetch_finished; + mutex fetch_finished_mutex; + shared_ptr target_buffer[ARITY]; + shared_ptr inner_template_iterator; + shared_ptr *atom_document; + QueryAnswer **local_answers; + unsigned int *next_inner_answer; + vector inner_answers; + unsigned int local_answers_size; + mutex local_answers_mutex; + string context; +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_LINKTEMPLATE_H diff --git a/src/cpp/query_engine/query_element/Operator.h b/src/cpp/query_engine/query_element/Operator.h new file mode 100644 index 0000000..f6e3686 --- /dev/null +++ b/src/cpp/query_engine/query_element/Operator.h @@ -0,0 +1,121 @@ +#ifndef _QUERY_ELEMENT_OPERATOR_H +#define _QUERY_ELEMENT_OPERATOR_H + +#include +#include "QueryElement.h" +#include "QueryAnswer.h" + +using namespace std; + +namespace query_element { + +/** + * Superclass for elements which represent logic operators on LinkTemplate results (e.g. AND, + * OR and NOT). + * + * Operator adds the required QueryNode elements to connect either with: + * + * - one or more QueryElement downstream in the query tree (each of them can be either + * Operator or SOurce). + * - one QueryElement upstream in the query tree which can be another Operator or a Sink. + */ +template +class Operator : public QueryElement { + +public: + + // -------------------------------------------------------------------------------------------- + // Constructors and destructors + + /** + * Constructor. + * + * @param clauses Array of QueryElement, each of them a clause in the operation. + */ + Operator(const array &clauses) { + initialize((QueryElement **) clauses.data()); + } + + /** + * Constructor. + * + * @param clauses Array of QueryElement, each of them a clause in the operation. + */ + Operator(QueryElement **clauses) { + initialize(clauses); + } + + /** + * Destructor. + */ + ~Operator() { + this->graceful_shutdown(); + } + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + /** + * Sets up buffers for communication between this operator and its upstream and downstream + * QueryElements. Initializes a single QueryNodeClient for the upstream connection and + * N QueryNodeServer elements for the downstream connections, each corresponding to a clause + * in the operation. + */ + virtual void setup_buffers() { + if (this->subsequent_id == "") { + Utils::error("Invalid empty parent id"); + } + if (this->id == "") { + Utils::error("Invalid empty id"); + } + + this->output_buffer = shared_ptr(new QueryNodeClient(this->id, this->subsequent_id)); + string server_node_id; + for (unsigned int i = 0; i < N; i++) { + server_node_id = this->id + "_" + to_string(i); + this->input_buffer[i] = shared_ptr(new QueryNodeServer(server_node_id)); + this->precedent[i]->subsequent_id = server_node_id; + this->precedent[i]->setup_buffers(); + } + + } + + /** + * Gracefully shuts down the QueryNodes attached to the upstream and downstream communication + * in the query tree. + */ + virtual void graceful_shutdown() { + if (is_flow_finished()) { + return; + } + for (unsigned int i = 0; i < N; i++) { + this->precedent[i]->graceful_shutdown(); + } + set_flow_finished(); + this->output_buffer->graceful_shutdown(); + for (unsigned int i = 0; i < N; i++) { + this->input_buffer[i]->graceful_shutdown(); + } + } + +protected: + + QueryElement *precedent[N]; + shared_ptr input_buffer[N]; + shared_ptr output_buffer; + +private: + + void initialize(QueryElement **clauses) { + if (N > MAX_NUMBER_OF_OPERATION_CLAUSES) { + Utils::error("Operation exceeds max number of clauses: " + to_string(N)); + } + for (unsigned int i = 0; i < N; i++) { + precedent[i] = clauses[i]; + } + } +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_OPERATOR_H diff --git a/src/cpp/query_engine/query_element/Or.h b/src/cpp/query_engine/query_element/Or.h new file mode 100644 index 0000000..a60a0d9 --- /dev/null +++ b/src/cpp/query_engine/query_element/Or.h @@ -0,0 +1,210 @@ +#ifndef _QUERY_ELEMENT_OR_H +#define _QUERY_ELEMENT_OR_H + +#include +#include +#include "Operator.h" + +using namespace std; + +namespace query_element { + + +/** + * QueryElement representing an OR logic operator. + * + * Or operates on N clauses. Each clause can be either a Source or another Operator. + */ +template +class Or : public Operator { + +public: + + // -------------------------------------------------------------------------------------------- + // Constructors and destructors + + /** + * Constructor. + * + * @param clauses Array with N clauses (each clause is supposed to be a Source or an Operator). + */ + Or(QueryElement **clauses) : Operator(clauses) { + initialize(clauses); + } + + /** + * Constructor. + * + * @param clauses Array with N clauses (each clause is supposed to be a Source or an Operator). + */ + Or(const array &clauses) : Operator(clauses) { + initialize((QueryElement **) clauses.data()); + } + + /** + * Destructor. + */ + ~Or() { + graceful_shutdown(); + } + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + virtual void setup_buffers() { + Operator::setup_buffers(); + this->operator_thread = new thread(&Or::or_operator_method, this); + } + + virtual void graceful_shutdown() { + Operator::graceful_shutdown(); + if (this->operator_thread != NULL) { + this->operator_thread->join(); + this->operator_thread = NULL; + } + } + + // -------------------------------------------------------------------------------------------- + // Private stuff + +private: + + vector query_answer[N]; + unsigned int next_input_to_process[N]; + bool all_answers_arrived[N]; + bool no_more_answers_to_arrive; + thread *operator_thread; + + void initialize(QueryElement **clauses) { + this->operator_thread = NULL; + for (unsigned int i = 0; i < N; i++) { + this->next_input_to_process[i] = 0; + this->all_answers_arrived[i] = false; + } + this->no_more_answers_to_arrive = false; + this->id = "Or("; + for (unsigned int i = 0; i < N; i++) { + this->id += clauses[i]->id; + if (i != (N - 1)) { + this->id += ", "; + } + } + this->id += ")"; + } + + bool ready_to_process_candidate() { + for (unsigned int i = 0; i < N; i++) { + if ((! this->all_answers_arrived[i]) && + (this->query_answer[i].size() <= (this->next_input_to_process[i] + 1))) { + return false; + } + } + return true; + } + + void ingest_newly_arrived_answers() { + if (this->no_more_answers_to_arrive) { + return; + } + QueryAnswer *answer; + unsigned int all_arrived_count = 0; + bool no_new_answer = true; + for (unsigned int i = 0; i < N; i++) { + while ((answer = this->input_buffer[i]->pop_query_answer()) != NULL) { + no_new_answer = false; + this->query_answer[i].push_back(answer); + } + if (this->input_buffer[i]->is_query_answers_empty() && + this->input_buffer[i]->is_query_answers_finished()) { + + this->all_answers_arrived[i] = true; + all_arrived_count++; + } + } + if (all_arrived_count == N) { + this->no_more_answers_to_arrive = true; + } else { + if (no_new_answer) { + Utils::sleep(); + } + } + } + + bool processed_all_input() { + for (unsigned int i = 0; i < N; i++) { + if (this->next_input_to_process[i] < this->query_answer[i].size()) { + return false; + } + } + return true; + } + + unsigned int select_answer() { + unsigned int best_index; + double best_importance = -1; + for (unsigned int i = 0; i < N; i++) { + if (this->next_input_to_process[i] < this->query_answer[i].size()) { + if (this->query_answer[i][this->next_input_to_process[i]]->importance > best_importance) { + best_importance = this->query_answer[i][this->next_input_to_process[i]]->importance; + best_index = i; + } + } + } + if (best_importance < 0) { + Utils::error ("Invalid state in OR operation"); + } + return best_index; + } + + void or_operator_method() { + + do { + if (QueryElement::is_flow_finished() || + this->output_buffer->is_query_answers_finished()) { + + return; + } + + do { + if (QueryElement::is_flow_finished()) { + return; + } + ingest_newly_arrived_answers(); + } while (! ready_to_process_candidate()); + + cout << "XXXXXXX 1" << endl; + if (processed_all_input()) { + cout << "XXXXXXX 2" << endl; + bool all_finished_flag = true; + for (unsigned int i = 0; i < N; i++) { + if (! this->input_buffer[i]->is_query_answers_finished()) { + all_finished_flag = false; + break; + } + } + cout << "XXXXXXX 3" << endl; + if (all_finished_flag && + ! this->output_buffer->is_query_answers_finished() && + // processed_all_input() is double-checked on purpose to avoid race condition + processed_all_input()) { + this->output_buffer->query_answers_finished(); + } + cout << "XXXXXXX 4" << endl; + Utils::sleep(); + continue; + } + cout << "XXXXXXX 5" << endl; + + unsigned int selected_clause = select_answer(); + cout << "XXXXXXX 6" << endl; + QueryAnswer *selected_query_answer = this->query_answer[selected_clause][this->next_input_to_process[selected_clause]++]; + cout << std::to_string(selected_clause) << ": " << selected_query_answer->to_string() << endl; + this->output_buffer->add_query_answer(selected_query_answer); + cout << "XXXXXXX 7" << endl; + } while (true); + } +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_OR_H diff --git a/src/cpp/query_engine/query_element/QueryElement.cc b/src/cpp/query_engine/query_element/QueryElement.cc new file mode 100644 index 0000000..af326ac --- /dev/null +++ b/src/cpp/query_engine/query_element/QueryElement.cc @@ -0,0 +1,31 @@ +#include "QueryElement.h" + +using namespace query_element; + +// ------------------------------------------------------------------------------------------------ +// Constructors and destructors + +QueryElement::QueryElement() { + this->flow_finished = false; + this->is_terminal = false; +} + +QueryElement::~QueryElement() { +} + +// ------------------------------------------------------------------------------------------------ +// Protected methods + +void QueryElement::set_flow_finished() { + this->flow_finished_mutex.lock(); + this->flow_finished = true; + this->flow_finished_mutex.unlock(); +} + +bool QueryElement::is_flow_finished() { + bool answer; + this->flow_finished_mutex.lock(); + answer = this->flow_finished; + this->flow_finished_mutex.unlock(); + return answer; +} diff --git a/src/cpp/query_engine/query_element/QueryElement.h b/src/cpp/query_engine/query_element/QueryElement.h new file mode 100644 index 0000000..fc702b3 --- /dev/null +++ b/src/cpp/query_engine/query_element/QueryElement.h @@ -0,0 +1,118 @@ +#ifndef _QUERY_ELEMENT_QUERYELEMENT_H +#define _QUERY_ELEMENT_QUERYELEMENT_H + +#include +#include +#include "Utils.h" +#include "QueryNode.h" + +#define DEBUG + +using namespace std; +using namespace query_node; +using namespace commons; + +namespace query_element { + +/** + * Basic element in the class hierarchy which represents boolean logical expression involving + * nodes, links and patterns. + * + * Boolean logical expressions are formed by logical operators (AND, OR, NOT) and + * operands (Node, Link and LinkTemplate). Nested expression are allowed. AND and OR may operate + * on any number (> 1) of arguments while NOT takes a single argument. + * + * Nodes are defined by type+name. Links are defined by type+targets. LinkTemplates are defined + * like Links, where the Link type and any number of targets may be wildcards (actually, wildcards + * are named variables which are unified as the query is executed). LinkTemplates can also be + * nested, i.e., one of the targets of a LinkTemplate can be another LinkTemplate. + * + * There's no limit in the number of nesting levels of LinkTemplates or boolean expressions. + * + * A query can be understood as a tree whose nodes are QueryElements. Internal nodes are + * logical operators and leaves are either Links or LinkTemplates (nested or not). + * + * The query engine we implement here uses the Nodes/Links values that satisfy the leaves in this + * tree and flows them up through the internal nodes (logical operators) until they reach the root + * of the tree. In this path, some links are dropped because they don't satisfy the properties + * required by the operators or they don't satisfy a proper unification in the set of variables. + * + * Links that reach the root of the tree are considered actual query answers. + * + * Each QueryElement is an element in a distributed algorithm, with one or more threads processing + * its inputs and generating outputs according to the logic of each element. A communication + * framework is used to flow the links up through the tree using our DistributedAlgorithmNode + * which is essentially a framework to implement the basic functionalities required by a + * distributed algorithm. Since this framework allows communication either intra-process and + * extra-process (in the same machine or in different ones), we can have QueryElements of the + * same tree (i.e. of the same query) being processed in different machines or all of them in the + * same machine (either in the same process or in different processes). + */ +class QueryElement { + +public: + + string id; + string subsequent_id; + + /** + * Basic constructor which solely initialize variables. + */ + QueryElement(); + + /** + * Destructor. + */ + virtual ~QueryElement(); + + // -------------------------------------------------------------------------------------------- + // API to be extended by concrete subclasses + + /** + * Setup QueryNodes used by concrete implementations of QueryElements. This method is called + * after all ids and other topological-related setup in the query tree is finished. + */ + virtual void setup_buffers() = 0; + + /** + * Synchronously request this QueryElement to shutdown any threads it may have spawned. + */ + virtual void graceful_shutdown() = 0; + + /** + * Indicates whether this QueryElement is a Terminal (i.e. Node, Link or Variable). + */ + bool is_terminal; + +protected: + + /** + * Return true iff this QueryElement have finished its work in the flow of links up through + * the query tree. + * + * When this method return true, it means that all the QueryElements below than in the chain + * have already provided all the links they are supposed to and this QueryElement have already + * processed all of them and delivered all the links that are supposed to pass through the flow + * to the upper element in the tree. In other words, this QueryElement have no further work + * to do. + * + * @return true iff this QueryElement have finished its work in the flow of links up througth + * the query tree. + */ + bool is_flow_finished(); + + /** + * Sets a flag to indicate that this QueryElement have finished its work in the query. See + * comments in method is_flow_finished(). + */ + void set_flow_finished(); + +private: + + bool flow_finished; + mutex flow_finished_mutex; +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_QUERYELEMENT_H diff --git a/src/cpp/query_engine/query_element/RemoteIterator.cc b/src/cpp/query_engine/query_element/RemoteIterator.cc new file mode 100644 index 0000000..ec4785c --- /dev/null +++ b/src/cpp/query_engine/query_element/RemoteIterator.cc @@ -0,0 +1,39 @@ +#include "RemoteIterator.h" + +using namespace query_element; + +// ------------------------------------------------------------------------------------------------ +// Constructors and destructors + +RemoteIterator::RemoteIterator(const string &local_id) { + this->local_id = local_id; + setup_buffers(); +} + +RemoteIterator::~RemoteIterator() { + graceful_shutdown(); +} + +// ------------------------------------------------------------------------------------------------- +// Public methods + +void RemoteIterator::setup_buffers() { + this->remote_input_buffer = shared_ptr(new QueryNodeServer( + this->local_id, + MessageBrokerType::GRPC)); +} + +void RemoteIterator::graceful_shutdown() { + this->remote_input_buffer->graceful_shutdown(); +} + +bool RemoteIterator::finished() { + // The order of the AND clauses below matters + return ( + this->remote_input_buffer->is_query_answers_finished() && + this->remote_input_buffer->is_query_answers_empty()); +} + +QueryAnswer *RemoteIterator::pop() { + return (QueryAnswer *) this->remote_input_buffer->pop_query_answer(); +} diff --git a/src/cpp/query_engine/query_element/RemoteIterator.h b/src/cpp/query_engine/query_element/RemoteIterator.h new file mode 100644 index 0000000..d5558ba --- /dev/null +++ b/src/cpp/query_engine/query_element/RemoteIterator.h @@ -0,0 +1,77 @@ +#ifndef _QUERY_ELEMENT_REMOTEITERATOR_H +#define _QUERY_ELEMENT_REMOTEITERATOR_H + +#include "QueryElement.h" + +using namespace std; + +namespace query_element { + +/** + * A special case of QueryElement because RemoteIterator is not actually an element of the + * query tree itself but rather a utility class used to remotely connect to the sink of a query + * tree (RemoteSink). + * + * Basically, the goal of this class is to allow a caller to request a query execution remotely + * and iterate through the results using the RemoteIterator. + * + * NB Like Iterator in this same package, this is not a std::iterator as the behavior we'd expect + * of a std::iterator doesn't fit well with the asynchronous nature of QueryElement processing. + * Instead, this class provides only two methods: one to pop and return the next + * query answers and another to check if more answers can still be expected. + */ +class RemoteIterator : public QueryElement { + +public: + + /** + * Constructor. + * + * @param local_id The id of this element in the network which connects to the RemoteSink. + * Typically is something like "host:port". + */ + RemoteIterator(const string &local_id); + + /** + * Destructor. + */ + ~RemoteIterator(); + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + virtual void graceful_shutdown(); + virtual void setup_buffers(); + + // -------------------------------------------------------------------------------------------- + // Iterator API + + /** + * Return true when all query answers has been processed AND all the query answers + * that reached this QueryElement has been pop'ed out using the method pop(). + * + * @return true iff all query answers has been processed AND all the query answers + * that reached this QueryElement has been pop'ed out using the method pop(). + */ + bool finished(); + + /** + * Return the next query answer or NULL if none are currently available. + * + * NB a NULL return DOESN'T mean that the query answers are over. It means that there + * are no query answers available now. Because of the asynchronous nature of QueryElement + * processing, more query answers can arrive later. + * + * @return the next query answer or NULL if none are currently available. + */ + QueryAnswer *pop(); + +private: + + shared_ptr remote_input_buffer; + string local_id; +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_REMOTEITERATOR_H diff --git a/src/cpp/query_engine/query_element/RemoteSink.cc b/src/cpp/query_engine/query_element/RemoteSink.cc new file mode 100644 index 0000000..1133051 --- /dev/null +++ b/src/cpp/query_engine/query_element/RemoteSink.cc @@ -0,0 +1,302 @@ +#include +#include +#include +#include "RemoteSink.h" +#include "AtomDBSingleton.h" +#include "AtomDBAPITypes.h" + +#include "AttentionBrokerServer.h" +#include "attention_broker.grpc.pb.h" +#include +#include "attention_broker.pb.h" + +#define MAX_CORRELATIONS_WITHOUT_STIMULATE ((unsigned int) 1000) +#define MAX_STIMULATE_COUNT ((unsigned int) 1) + +using namespace query_element; + +string RemoteSink::DEFAULT_ATTENTION_BROKER_PORT = "37007"; + +// ------------------------------------------------------------------------------------------------- +// Constructors and destructors + +RemoteSink::RemoteSink( + QueryElement *precedent, + const string &local_id, + const string &remote_id, + bool update_attention_broker_flag, + const string &context, + bool delete_precedent_on_destructor) : + Sink(precedent, "RemoteSink(" + precedent->id + ")", delete_precedent_on_destructor, false) { +#ifdef DEBUG + cout << "RemoteSink::RemoteSink() BEGIN" << endl; + cout << "RemoteSink::RemoteSink() local_id: " << local_id << endl; + cout << "RemoteSink::RemoteSink() remote_id: " << remote_id << endl; +#endif + + this->attention_broker_address = "localhost:" + RemoteSink::DEFAULT_ATTENTION_BROKER_PORT; + this->query_context = context; + this->local_id = local_id; + this->remote_id = remote_id; + this->queue_processor = NULL; + this->attention_broker_postprocess = NULL; + this->update_attention_broker_flag = update_attention_broker_flag; + RemoteSink::setup_buffers(); + Sink::setup_buffers(); + this->queue_processor = new thread(&RemoteSink::queue_processor_method, this); + if (this->update_attention_broker_flag) { + this->attention_broker_postprocess = new \ + thread(&RemoteSink::attention_broker_postprocess_method, this); + } +#ifdef DEBUG + cout << "RemoteSink::RemoteSink() END" << endl; +#endif +} + +RemoteSink::~RemoteSink() { + graceful_shutdown(); +} + +// ------------------------------------------------------------------------------------------------- +// Public methods + +void RemoteSink::setup_buffers() { +#ifdef DEBUG + cout << "RemoteSink::setup_buffers() BEGIN" << endl; +#endif + this->remote_output_buffer = shared_ptr(new QueryNodeClient( + this->local_id, + this->remote_id, + MessageBrokerType::GRPC)); +#ifdef DEBUG + cout << "RemoteSink::setup_buffers() END" << endl; +#endif +} + +void RemoteSink::graceful_shutdown() { +#ifdef DEBUG + cout << "RemoteSink::graceful_shutdown() BEGIN" << endl; +#endif + Sink::graceful_shutdown(); + set_flow_finished(); + set_attention_broker_postprocess_finished(); + if (this->queue_processor != NULL) { + this->queue_processor->join(); + } + if (this->attention_broker_postprocess != NULL) { + this->attention_broker_postprocess->join(); + } + this->remote_output_buffer->graceful_shutdown(); +#ifdef DEBUG + cout << "RemoteSink::graceful_shutdown() END" << endl; +#endif +} + +// ------------------------------------------------------------------------------------------------- +// Private methods + +void RemoteSink::queue_processor_method() { +#ifdef DEBUG + cout << "RemoteSink::queue_processor_method() BEGIN" << endl; +#endif + do { + if (is_flow_finished() || + (this->input_buffer->is_query_answers_finished() && + this->input_buffer->is_query_answers_empty())) { + + break; + } + bool idle_flag = true; + QueryAnswer *query_answer; + while ((query_answer = this->input_buffer->pop_query_answer()) != NULL) { + this->remote_output_buffer->add_query_answer(query_answer); + if (this->update_attention_broker_flag) { + this->attention_broker_queue.enqueue((void *) query_answer); + //update_attention_broker((QueryAnswer *) query_answer); + } + idle_flag = false; + } + if (idle_flag) { + Utils::sleep(); + } + } while (true); +#ifdef DEBUG + cout << "RemoteSink::queue_processor_method() ready to return" << endl; +#endif + this->remote_output_buffer->query_answers_finished(); + set_flow_finished(); + set_attention_broker_postprocess_finished(); +#ifdef DEBUG + cout << "RemoteSink::queue_processor_method() END" << endl; +#endif +} + +/* +static bool visit_function(HandleTrie::TrieNode *node, void *data) { + ((unordered_map *) data)->insert({ + node->suffix, + ((AccumulatorValue *) node->value)->count + }); + return false; +} +*/ + +void RemoteSink::attention_broker_postprocess_method() { + + // GRPC stuff + + // TODO: XXX Review allocation performance in all this method + set single_answer; // Auxiliary set of handles. + unordered_map joint_answer; // Auxiliary joint count of handles. + //HandleTrie *joint_answer = new HandleTrie(HANDLE_HASH_SIZE - 1); + + // Protobuf data structures + dasproto::HandleList *handle_list; // will contain single_answer (can't be used directly + // because it's a list, not a set. + dasproto::HandleCount handle_count; // Counting of how many times each handle appeared + // in all single_entry + dasproto::Ack *ack; // Command return + + shared_ptr db = AtomDBSingleton::get_instance(); + shared_ptr query_result; + stack execution_stack; + unsigned int weight_sum; + unsigned int correlated_count = 0; + unsigned int stimulated_count = 0; + +#ifdef DEBUG + unsigned int count_total_processed = 0; +#endif + + //handle_list.set_context(this->query_context); + do { + if (is_attention_broker_postprocess_finished() || + (is_flow_finished() && this->attention_broker_queue.empty())) { + break; + } + bool idle_flag = true; + QueryAnswer *query_answer; + string handle; + unsigned int count; + while ((query_answer = (QueryAnswer *) this->attention_broker_queue.dequeue()) != NULL) { + if (stimulated_count == MAX_STIMULATE_COUNT) { + continue; + } +#ifdef DEBUG + count_total_processed++; + if ((count_total_processed % 1000) == 0) { + cout << "RemoteSink::attention_broker_postprocess_method() count_total_processed: " << count_total_processed << endl; + } +#endif + for (unsigned int i = 0; i < query_answer->handles_size; i++) { + execution_stack.push(string(query_answer->handles[i])); + } + while (! execution_stack.empty()) { + handle = execution_stack.top(); + execution_stack.pop(); + // Updates single_answer (correlation) + single_answer.insert(handle); + // Updates joint answer (stimulation) + if (joint_answer.find(handle) != joint_answer.end()) { + count = joint_answer[handle] + 1; + } else { + count = 1; + } + joint_answer[handle] = count; + //joint_answer->insert(handle, new AccumulatorValue()); + // Gets targets and stack them + query_result = db->query_for_targets((char *) handle.c_str()); + if (query_result != NULL) { // if handle is link + unsigned int query_result_size = query_result->size(); + for (unsigned int i = 0; i < query_result_size; i++) { + execution_stack.push(string(query_result->get_handle(i))); + } + } + } + //handle_list.mutable_list()->Clear(); + //handle_list.clear_list(); + auto stub = dasproto::AttentionBroker::NewStub(grpc::CreateChannel( + this->attention_broker_address, + grpc::InsecureChannelCredentials())); // XXXXX Move this up + handle_list = new dasproto::HandleList(); + handle_list->set_context(this->query_context); + for (auto handle_it: single_answer) { + handle_list->add_list(handle_it); + } + single_answer.clear(); + ack = new dasproto::Ack(); +#ifdef DEBUG + //cout << "RemoteSink::attention_broker_postprocess_method() requesting CORRELATE" << endl; +#endif + stub->correlate(new grpc::ClientContext(), *handle_list, ack); + if (ack->msg() != "CORRELATE") { + Utils::error("Failed GRPC command: AttentionBroker::correlate()"); + } + idle_flag = false; + if (++correlated_count == MAX_CORRELATIONS_WITHOUT_STIMULATE) { + correlated_count = 0; + for (auto const& pair: joint_answer) { + (*handle_count.mutable_map())[pair.first] = pair.second; + weight_sum += pair.second; + } + (*handle_count.mutable_map())["SUM"] = weight_sum; + ack = new dasproto::Ack(); + auto stub = dasproto::AttentionBroker::NewStub(grpc::CreateChannel( + this->attention_broker_address, + grpc::InsecureChannelCredentials())); +#ifdef DEBUG + cout << "RemoteSink::attention_broker_postprocess_method() requesting STIMULATE" << endl; +#endif + handle_count.set_context(this->query_context); + stub->stimulate(new grpc::ClientContext(), handle_count, ack); + stimulated_count++; + if (ack->msg() != "STIMULATE") { + Utils::error("Failed GRPC command: AttentionBroker::stimulate()"); + } + joint_answer.clear(); + } + } + if (idle_flag) { + Utils::sleep(); + } + } while (true); + //joint_answer->traverse(true, &visit_function, &joint_answer_map); + if (correlated_count > 0) { + weight_sum = 0; + for (auto const& pair: joint_answer) { + (*handle_count.mutable_map())[pair.first] = pair.second; + weight_sum += pair.second; + } + (*handle_count.mutable_map())["SUM"] = weight_sum; + ack = new dasproto::Ack(); + auto stub = dasproto::AttentionBroker::NewStub(grpc::CreateChannel( + this->attention_broker_address, + grpc::InsecureChannelCredentials())); +#ifdef DEBUG + cout << "RemoteSink::attention_broker_postprocess_method() requesting STIMULATE" << endl; +#endif + handle_count.set_context(this->query_context); + stub->stimulate(new grpc::ClientContext(), handle_count, ack); + stimulated_count++; + if (ack->msg() != "STIMULATE") { + Utils::error("Failed GRPC command: AttentionBroker::stimulate()"); + } + } + //delete joint_answer; + set_attention_broker_postprocess_finished(); +} + +void RemoteSink::set_attention_broker_postprocess_finished() { + this->attention_broker_postprocess_finished_mutex.lock(); + this->attention_broker_postprocess_finished = true; + this->attention_broker_postprocess_finished_mutex.unlock(); +} + +bool RemoteSink::is_attention_broker_postprocess_finished() { + bool answer; + this->attention_broker_postprocess_finished_mutex.lock(); + answer = this->attention_broker_postprocess_finished; + this->attention_broker_postprocess_finished_mutex.unlock(); + return answer; +} + diff --git a/src/cpp/query_engine/query_element/RemoteSink.h b/src/cpp/query_engine/query_element/RemoteSink.h new file mode 100644 index 0000000..17cfb86 --- /dev/null +++ b/src/cpp/query_engine/query_element/RemoteSink.h @@ -0,0 +1,102 @@ +#ifndef _QUERY_ELEMENT_REMOTESINK_H +#define _QUERY_ELEMENT_REMOTESINK_H + +#include "Sink.h" +#include "SharedQueue.h" +//#include "HandleTrie.h" + +using namespace std; +//using namespace attention_broker_server; + +namespace query_element { + +/** + * A special sink which forwards the query results to a remote QueryElement (e.g. a RemoteIterator). + */ +class RemoteSink: public Sink { + +public: + + /** + * Constructor. + * + * @param precedent QueryElement just below in the query tree. + * @param local_id ID of this element in the network connecting to the remote + * peer (typically "host:port"). + * @param remote_id network ID of the remote peer (typically "host:port"). + * @param delete_precedent_on_destructor If true, the destructor of this QueryElement will + * also destruct the passed precedent QueryElement (defaulted to false). + */ + RemoteSink( + QueryElement *precedent, + const string &local_id, + const string &remote_id, + bool update_attention_broker_flag = false, + const string &context = "", + bool delete_precedent_on_destructor = false); + + /** + * Destructor. + */ + ~RemoteSink(); + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + /** + * Setups QueryNode elements related to the communication with the remote element. + */ + virtual void setup_buffers(); + + /** + * Gracefully shuts down the queue processor thread and the remote communication QueryNodes + * present in this QueryElement. + */ + virtual void graceful_shutdown(); + + /** + * Sources tipically need to communicate with the AttentionBroker in order to sort links + * by importance. AttentionBroker is supposed to be running in the same machine as all + * Source elements so only a port number is required. Here we provide a default value + * in the case none is passed in constructor. + */ + static string DEFAULT_ATTENTION_BROKER_PORT; + +private: + + shared_ptr remote_output_buffer; + string local_id; + string remote_id; + thread *queue_processor; + thread *attention_broker_postprocess; + SharedQueue attention_broker_queue; + bool attention_broker_postprocess_finished; + mutex attention_broker_postprocess_finished_mutex; + string attention_broker_address; + bool update_attention_broker_flag; + string query_context; + + void queue_processor_method(); + void attention_broker_postprocess_method(); + bool is_attention_broker_postprocess_finished(); + void set_attention_broker_postprocess_finished(); + //void update_attention_broker(QueryAnswer *query_answer); + +}; + +/* +class AccumulatorValue: public HandleTrie::TrieValue { +public: + unsigned int count; + AccumulatorValue() { + this->count = 1; + } + void merge(TrieValue *other) { + count += ((AccumulatorValue *) other)->count; + } +}; +*/ + +} // namespace query_element + +#endif // _QUERY_ELEMENT_REMOTESINK_H diff --git a/src/cpp/query_engine/query_element/Sink.cc b/src/cpp/query_engine/query_element/Sink.cc new file mode 100644 index 0000000..4cb3ea6 --- /dev/null +++ b/src/cpp/query_engine/query_element/Sink.cc @@ -0,0 +1,47 @@ +#include "Sink.h" + +using namespace query_element; + +// ------------------------------------------------------------------------------------------------ +// Constructors and destructors + +Sink::Sink( + QueryElement *precedent, + const string &id, + bool delete_precedent_on_destructor, + bool setup_buffers_flag) { + + this->precedent = precedent; + this->id = id; + this->delete_precedent_on_destructor = delete_precedent_on_destructor; + if (setup_buffers_flag) { + setup_buffers(); + } +} + +Sink::~Sink() { + this->input_buffer->graceful_shutdown(); + if (this->delete_precedent_on_destructor) { + delete this->precedent; + } +} + +// ------------------------------------------------------------------------------------------------ +// Public methods + +void Sink::setup_buffers() { + if (this->subsequent_id != "") { + Utils::error("Invalid non-empty subsequent id: " + this->subsequent_id); + } + if (this->id == "") { + Utils::error("Invalid empty id"); + } + this->input_buffer = shared_ptr(new QueryNodeServer(this->id)); + this->precedent->subsequent_id = this->id; + this->precedent->setup_buffers(); +} + +void Sink::graceful_shutdown() { + this->input_buffer->graceful_shutdown(); + this->precedent->graceful_shutdown(); +} diff --git a/src/cpp/query_engine/query_element/Sink.h b/src/cpp/query_engine/query_element/Sink.h new file mode 100644 index 0000000..de36e88 --- /dev/null +++ b/src/cpp/query_engine/query_element/Sink.h @@ -0,0 +1,70 @@ +#ifndef _QUERY_ELEMENT_SINK_H +#define _QUERY_ELEMENT_SINK_H + +#include "QueryElement.h" + +using namespace std; + +namespace query_element { + +/** + * Superclass for elements that represent the root in a query tree of QueryElement. + * + * It's a "sink" in the sense of being an element where the flow of links stops, going + * nowhere further. + * + * Sink adds the required DistributedAlgorithmNode (actually a specialized version of it + * named QueryNode) and exposes a public API to interact with it transparently. Basically, + * a server version of QueryNode (i.e. a ServerQueryNode) is setup to communicate with + * a remote ClientQueryNode which is located in the QueryElement just below in the query tree. + */ +class Sink : public QueryElement { + +public: + + /** + * Constructor expects that the QueryElement below in the tree is already constructed. + * + * @param precedent QueryElement just below in the query tree. + * @param id Unique id for this QueryElement. + * @param delete_precedent_on_destructor If true, the destructor of this QueryElement will + * also destruct the passed precedent QueryElement (defaulted to false). + */ + Sink( + QueryElement *precedent, + const string &id, + bool delete_precedent_on_destructor = false, + bool setup_buffers_flag = true); + + /** + * Destructor. + */ + virtual ~Sink(); + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + /** + * Gracefully shuts down the QueryNode. + */ + virtual void graceful_shutdown(); + + /** + * Setup a ServerQueryNode to commnunicate with one or more QueryElement just below in the + * query tree. + */ + virtual void setup_buffers(); + +protected: + + shared_ptr input_buffer; + QueryElement *precedent; + +private: + + bool delete_precedent_on_destructor; +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_SINK_H diff --git a/src/cpp/query_engine/query_element/Source.cc b/src/cpp/query_engine/query_element/Source.cc new file mode 100644 index 0000000..518535e --- /dev/null +++ b/src/cpp/query_engine/query_element/Source.cc @@ -0,0 +1,36 @@ +#include "Source.h" + +using namespace query_element; + +string Source::DEFAULT_ATTENTION_BROKER_PORT = "37007"; + +// ------------------------------------------------------------------------------------------------ +// Constructors and destructors + +Source::Source(const string &attention_broker_address) { + this->attention_broker_address = attention_broker_address; +} + +Source::Source() : Source("localhost:" + Source::DEFAULT_ATTENTION_BROKER_PORT) { +} + +Source::~Source() { + this->output_buffer->graceful_shutdown(); +} + +// ------------------------------------------------------------------------------------------------ +// Public methods + +void Source::setup_buffers() { + if (this->subsequent_id == "") { + Utils::error("Invalid empty parent id"); + } + if (this->id == "") { + Utils::error("Invalid empty id"); + } + this->output_buffer = shared_ptr(new QueryNodeClient(this->id, this->subsequent_id)); +} + +void Source::graceful_shutdown() { + this->output_buffer->graceful_shutdown(); +} diff --git a/src/cpp/query_engine/query_element/Source.h b/src/cpp/query_engine/query_element/Source.h new file mode 100644 index 0000000..d8a8aa0 --- /dev/null +++ b/src/cpp/query_engine/query_element/Source.h @@ -0,0 +1,67 @@ +#ifndef _QUERY_ELEMENT_SOURCE_H +#define _QUERY_ELEMENT_SOURCE_H + +#include "QueryElement.h" + +using namespace std; + +namespace query_element { + +/** + * Superclass for elements that represent leaves in the query tree of QueryElement. + * + * Source adds the required DistributedAlgorithmNode (actually a specialized version of it + * named QueryNode) and exposes a public API to interact with it transparently. Basically, + * a client version of QueryNode (i.e. a ClientQueryNode) is setup to communicate with + * a remote ServerQueryNode which is located in the QueryElement just above in the query tree. + */ +class Source : public QueryElement { + +public: + + /** + * Sources tipically need to communicate with the AttentionBroker in order to sort links + * by importance. AttentionBroker is supposed to be running in the same machine as all + * Source elements so only a port number is required. Here we provide a default value + * in the case none is passed in constructor. + */ + static string DEFAULT_ATTENTION_BROKER_PORT; + + /** + * Constructor which also sets a value for AttentionBroker address + */ + Source(const string &attention_broker_address); + + /** + * Basic empty constructor. + */ + Source(); + + /** + * Destructor. + */ + virtual ~Source(); + + // -------------------------------------------------------------------------------------------- + // QueryElement API + + /** + * Gracefully shuts down the QueryNode. + */ + virtual void graceful_shutdown(); + + /** + * Setup a ClientQueryNode to commnunicate with the upper QueryElement. + */ + virtual void setup_buffers(); + +protected: + + string attention_broker_address; + shared_ptr output_buffer; + QueryElement *subsequent; +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_SOURCE_H diff --git a/src/cpp/query_engine/query_element/Terminal.h b/src/cpp/query_engine/query_element/Terminal.h new file mode 100644 index 0000000..555ea22 --- /dev/null +++ b/src/cpp/query_engine/query_element/Terminal.h @@ -0,0 +1,211 @@ +#ifndef _QUERY_ELEMENT_TERMINAL_H +#define _QUERY_ELEMENT_TERMINAL_H + +#include +#include +#include "QueryElement.h" +#include "AtomDB.h" +#include "expression_hasher.h" + +using namespace std; +using namespace query_engine; + +namespace query_element { + +// ------------------------------------------------------------------------------------------------- +// Abstract Terminal superclass + +/** + * A QueryElement which represents terminals (i.e. Nodes, Links and Variables) in the query tree. + */ +class Terminal : public QueryElement { + +protected: + + /** + * Protected constructor. + */ + Terminal() : QueryElement() { + this->handle = shared_ptr{}; + this->is_variable = false; + this->is_terminal = true; // overrrides QueryElement default + } + +public: + + /** + * Destructor. + */ + ~Terminal() {}; + + /** + * Empty implementation. There are no QueryNode element to setup. + */ + void virtual setup_buffers() {} + + /** + * Empty implementation. There are no QueryNode element or local thread to shut down. + */ + void virtual graceful_shutdown() {} + + /** + * Returns a string representation of this Terminal (mainly for debugging; not optimized to + * production environment). + */ + virtual string to_string() = 0; + + /** + * A flag to indicate whether this Terminal is a Variable or not. + */ + bool is_variable; + + /** + * Handle of the terminal. + */ + shared_ptr handle; + + /** + * Name of the terminal. + * + * Actually, only Nodes and Variables have names; Links' name is an empty string. + */ + string name; +}; + +// ------------------------------------------------------------------------------------------------- +// Node + +/** + * QueryElement which represents a node. + */ +class Node : public Terminal { + +public: + + /** + * Constructor. + * + * @param type Type of the node. + * @param name Name of the node. + */ + Node(const string &type, const string &name) : Terminal() { + this->type = type; + this->name = name; + this->handle = shared_ptr(terminal_hash((char *) type.c_str(), (char *) name.c_str())); + } + + /** + * Returns a string representation of this Node (mainly for debugging; not optimized to + * production environment). + */ + string to_string() { + return "<" + this->type + ", " + this->name + ", " + string(this->handle.get()) + ">"; + } + + /** + * Type of this node. + */ + string type; +}; + +// ------------------------------------------------------------------------------------------------- +// Link + +/** + * QueryElement which represents a link. + */ +template +class Link : public Terminal { + +public: + + /** + * Constructor. + * + * @param type Type of the Link. + * @params targets Array with targets of the Link. Targets are supposed to be + * handles (i.e. strings). No nesting of Nodes or other Links are allowed. + */ + Link(const string &type, const array &targets) : Terminal() { + this->name = ""; + this->type = type; + this->targets = targets; + this->arity = ARITY; + char *handle_keys[ARITY + 1]; + handle_keys[0] = (char *) named_type_hash((char *) type.c_str()); + for (unsigned int i = 1; i < (ARITY + 1); i++) { + if (targets[i - 1]->is_terminal && ! ((Terminal *) targets[i - 1])->is_variable) { + handle_keys[i] = ((Terminal *) targets[i - 1])->handle.get(); + } else { + Utils::error("Invalid Link definition"); + } + } + this->handle = shared_ptr(composite_hash(handle_keys, ARITY + 1)); + free(handle_keys[0]); + } + + /** + * Returns a string representation of this Node (mainly for debugging; not optimized to + * production environment). + */ + string to_string() { + string answer = "(" + this->type + ", ["; + for (unsigned int i = 0; i < this->arity; i++) { + answer += ((Terminal *) this->targets[i])->to_string(); + if (i != (this->arity - 1)) { + answer += ", "; + } + } + answer += "], " + string(this->handle.get()) + ")"; + return answer; + } + + /** + * Type of the Link + */ + string type; + + /** + * Arity of the Link. + */ + unsigned int arity = ARITY; + + /** + * Targets of the Link. + */ + array targets; +}; + +// ------------------------------------------------------------------------------------------------- +// Variable + +/** + * QueryElement which represents a variable. + */ +class Variable : public Terminal { + +public: + + /** + * Constructor. + * + * @param name Name of the Variable. + */ + Variable(const string &name) : Terminal() { + this->name = name; + this->handle = shared_ptr(strdup((char *) AtomDB::WILDCARD.c_str())); + this->is_variable = true; + } + + /** + * Returns a string representation of this Variable (mainly for debugging; not optimized to + * production environment). + */ + string to_string() { + return "$(" + this->name + ")"; + } +}; + +} // namespace query_element + +#endif // _QUERY_ELEMENT_TERMINAL_H diff --git a/src/cpp/tests/BUILD b/src/cpp/tests/BUILD new file mode 100644 index 0000000..03b76ca --- /dev/null +++ b/src/cpp/tests/BUILD @@ -0,0 +1,326 @@ +cc_test( + name = "shared_queue_test", + srcs = ["shared_queue_test.cc"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "request_selector_test", + srcs = ["request_selector_test.cc"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "attention_broker_server_test", + srcs = ["attention_broker_server_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "//cpp/utils:utils_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "worker_threads_test", + srcs = ["worker_threads_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "handle_trie_test", + srcs = ["handle_trie_test.cc", "test_utils.cc", "test_utils.h"], + size = "medium", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "hebbian_network_test", + srcs = ["hebbian_network_test.cc", "test_utils.cc", "test_utils.h"], + size = "medium", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "//cpp/utils:utils_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "hebbian_network_updater_test", + srcs = ["hebbian_network_updater_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "stimulus_spreader_test", + srcs = ["stimulus_spreader_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "@com_github_singnet_das_proto//:attention_broker_cc_grpc", + "@com_github_grpc_grpc//:grpc++", + "@com_github_grpc_grpc//:grpc++_reflection", + "//cpp/attention_broker:attention_broker_server_lib", + "@mbedcrypto//:lib", + ], + linkstatic = 1 +) + +cc_test( + name = "link_template_test", + srcs = ["link_template_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkopts = [ + #"-L/opt/3rd-party/mbedcrypto", + "-lmbedcrypto", + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + linkstatic = 1 +) + +cc_test( + name = "nested_link_template_test", + srcs = ["nested_link_template_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkopts = [ + "-lmbedcrypto", + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + linkstatic = 1 +) + +cc_test( + name = "query_answer_test", + srcs = ["query_answer_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkopts = [ + ], + linkstatic = 1 +) + +cc_test( + name = "and_operator_test", + srcs = ["and_operator_test.cc", "test_utils.cc", "test_utils.h"], + size = "medium", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkopts = [ + #"-L/opt/3rd-party/mbedcrypto", + "-lmbedcrypto", + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + linkstatic = 1 +) + +cc_test( + name = "iterator_test", + srcs = ["iterator_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkopts = [ + #"-L/opt/3rd-party/mbedcrypto", + "-lmbedcrypto", + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + linkstatic = 1 +) + +cc_test( + name = "remote_sink_iterator_test", + srcs = ["remote_sink_iterator_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkopts = [ + "-lmbedcrypto", + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + linkstatic = 1 +) + +cc_test( + name = "das_node_test", + srcs = ["das_node_test.cc", "test_utils.cc", "test_utils.h"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkopts = [ + "-lmbedcrypto", + "-L/usr/local/lib", + "-lhiredis_cluster", + "-lhiredis", + "-lmongocxx", + "-lbsoncxx", + ], + linkstatic = 1 +) + +cc_test( + name = "query_node_test", + srcs = ["query_node_test.cc"], + size = "small", + copts = [ + "-Iexternal/gtest/googletest/include", + "-Iexternal/gtest/googletest", + ], + deps = [ + "@com_github_google_googletest//:gtest_main", + "//cpp/query_engine:query_engine_lib", + ], + linkstatic = 1 +) diff --git a/src/cpp/tests/and_operator_test.cc b/src/cpp/tests/and_operator_test.cc new file mode 100644 index 0000000..2d8f43d --- /dev/null +++ b/src/cpp/tests/and_operator_test.cc @@ -0,0 +1,251 @@ +#include +#include +#include "gtest/gtest.h" + +#include "Source.h" +#include "Sink.h" +#include "QueryAnswer.h" +#include "And.h" +#include "test_utils.h" + +using namespace query_engine; +using namespace query_element; + +#define SLEEP_DURATION ((unsigned int) 1000) + +class TestSource : public Source { + + public: + + TestSource(unsigned int count) { + this->id = "TestSource_" + to_string(count); + } + + ~TestSource() { + } + + void add( + const char *handle, + double importance, + const array &labels, + const array &values, + bool sleep_flag = true) { + + QueryAnswer *query_answer = new QueryAnswer(handle, importance); + for (unsigned int i = 0; i < labels.size(); i++) { + query_answer->assignment.assign(labels[i], values[i]); + } + this->output_buffer->add_query_answer(query_answer); + if (sleep_flag) { + Utils::sleep(SLEEP_DURATION); + } + } + + void query_answers_finished() { + return this->output_buffer->query_answers_finished(); + } +}; + +class TestSink : public Sink { + public: + TestSink(QueryElement *precedent) : + Sink(precedent, "TestSink(" + precedent->id + ")") { + } + ~TestSink() { + } + bool empty() { return this->input_buffer->is_query_answers_empty(); } + bool finished() { return this->input_buffer->is_query_answers_finished(); } + QueryAnswer *pop() { return this->input_buffer->pop_query_answer(); } +}; + +void check_query_answer( + string tag, + QueryAnswer *query_answer, + double importance, + unsigned int handles_size, + const array &handles) { + + cout << "check_query_answer(" + tag + ")" << endl; + EXPECT_TRUE(double_equals(query_answer->importance, importance)); + EXPECT_EQ(query_answer->handles_size, 2); + for (unsigned int i = 0; i < handles_size; i++) { + EXPECT_TRUE(strcmp(query_answer->handles[i], handles[i]) == 0); + } +} + +TEST(AndOperator, basics) { + + TestSource source1(1); + TestSource source2(2); + And<2> and_operator({&source1, &source2}); + TestSink sink(&and_operator); + QueryAnswer *query_answer; + + EXPECT_TRUE(sink.empty()); EXPECT_FALSE(sink.finished()); + + // -------------------------------------------------- + // Expected processing order: + // h1: 0.5, 0.4, 0.3 + // h2: 0.3, 0.2, 0.1 + // 0 0 - 0.15 + // 1 0 - 0.12 INVALID + // 0 1 - 0.10 + // 2 0 - 0.09 + // 1 1 - 0.08 + // 2 1 - 0.06 + // 0 2 - 0.05 + // 1 2 - 0.04 + // 2 2 - 0.03 + // -------------------------------------------------- + + source1.add("h1_0", 0.5, {"v1_0"}, {"1"}); + source2.add("h2_0", 0.3, {"v1_1"}, {"2"}); + source2.add("h2_1", 0.2, {"v2_1"}, {"1"}); + EXPECT_TRUE(sink.empty()); EXPECT_FALSE(sink.finished()); + source1.add("h1_1", 0.4, {"v1_1"}, {"1"}); + EXPECT_FALSE(sink.empty()); EXPECT_FALSE(sink.finished()); + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + EXPECT_TRUE(sink.empty()); EXPECT_FALSE(sink.finished()); + check_query_answer("1", query_answer, 0.5, 2, {"h1_0", "h2_0"}); + EXPECT_TRUE(strcmp(query_answer->assignment.get("v1_0"), "1") == 0); + EXPECT_TRUE(strcmp(query_answer->assignment.get("v1_1"), "2") == 0); + source1.add("h1_2", 0.3, {"v1_2"}, {"1"}); + EXPECT_TRUE(sink.empty()); EXPECT_FALSE(sink.finished()); + source2.add("h2_2", 0.1, {"v2_2"}, {"1"}); + EXPECT_TRUE(sink.empty()); EXPECT_FALSE(sink.finished()); + source1.query_answers_finished(); + EXPECT_TRUE(sink.empty()); EXPECT_FALSE(sink.finished()); + source2.query_answers_finished(); + Utils::sleep(SLEEP_DURATION); + EXPECT_FALSE(sink.empty()); EXPECT_TRUE(sink.finished()); + + // {"h1_1", "h2_0"} is not popped because it's invalid + + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + check_query_answer("3", query_answer, 0.5, 2, {"h1_0", "h2_1"}); + + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + check_query_answer("4", query_answer, 0.3, 2, {"h1_2", "h2_0"}); + + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + check_query_answer("5", query_answer, 0.4, 2, {"h1_1", "h2_1"}); + + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + check_query_answer("6", query_answer, 0.3, 2, {"h1_2", "h2_1"}); + + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + check_query_answer("7", query_answer, 0.5, 2, {"h1_0", "h2_2"}); + + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + check_query_answer("8", query_answer, 0.4, 2, {"h1_1", "h2_2"}); + + EXPECT_FALSE((query_answer = sink.pop()) == NULL); + EXPECT_TRUE(sink.empty()); + check_query_answer("9", query_answer, 0.3, 2, {"h1_2", "h2_2"}); + Utils::sleep(SLEEP_DURATION); + + EXPECT_TRUE(sink.empty()); EXPECT_TRUE(sink.finished()); +} + +TEST(AndOperator, operation_logic) { + + class ImportanceFitnessPair { + public: + double importance; + double fitness; + ImportanceFitnessPair() {} + ImportanceFitnessPair(const ImportanceFitnessPair &other) { + this->importance = other.importance; + this->fitness = other.fitness; + } + ImportanceFitnessPair& operator=(const ImportanceFitnessPair &other) { + this->importance = other.importance; + this->fitness = other.fitness; + return *this; + } + bool operator<(const ImportanceFitnessPair &other) const { + return this->fitness < other.fitness; + } + bool operator>(const ImportanceFitnessPair &other) const { + return this->fitness > other.fitness; + } + }; + + cout << "SETUP" << endl; + + unsigned int clause_count = 3; + unsigned int link_count = 100; + array, 3> importance; + priority_queue fitness_heap; + ImportanceFitnessPair pair; + QueryAnswer *query_answer; + TestSource *source[3]; + for (unsigned int clause = 0; clause < clause_count; clause++) { + source[clause] = new TestSource(clause); + } + And<3> *and_operator = new And<3>((QueryElement **) source); + TestSink *sink = new TestSink(and_operator); + + for (unsigned int clause = 0; clause < clause_count; clause++) { + for (unsigned int link = 0; link < link_count; link++) { + importance[clause][link] = random_importance(); + } + std::sort( + std::begin(importance[clause]), + std::end(importance[clause]), + std::greater{}); + } + + cout << "QUEUES POPULATION" << endl; + + for (unsigned int clause = 0; clause < clause_count; clause++) { + for (unsigned int link = 0; link < link_count; link++) { + source[clause]->add( + random_handle().c_str(), + importance[clause][link], + {"v"}, + {"1"}, + false); + } + source[clause]->query_answers_finished(); + } + + cout << "MATRIX POPULATION" << endl; + + for (unsigned int i = 0; i < link_count; i++) { + for (unsigned int j = 0; j < link_count; j++) { + for (unsigned int k = 0; k < link_count; k++) { + pair.importance = importance[0][i] > importance[1][j] ? importance[0][i] : importance[1][j]; + pair.importance = importance[2][k] > pair.importance ? importance[2][k] : pair.importance; + pair.fitness = importance[0][i] * importance[1][j] * importance[2][k]; + fitness_heap.push(pair); + } + } + } + + Utils::sleep(5000); + cout << "TEST CHECKS" << endl; + + unsigned int count = 0; + while (! (sink->empty() && sink->finished())) { + if (sink->empty()) { + Utils::sleep(); + continue; + } + EXPECT_FALSE((query_answer = sink->pop()) == NULL); + pair = fitness_heap.top(); + cout << count << " CHECK: " << query_answer->importance << " " << pair.importance << " (" << pair.fitness << ")" << endl; + EXPECT_TRUE(double_equals(query_answer->importance, pair.importance)); + fitness_heap.pop(); + count++; + } + + Utils::sleep(5000); + cout << "TEAR DOWN" << endl; + + delete sink; + delete and_operator; + for (unsigned int clause = 0; clause < clause_count; clause++) { + delete source[clause]; + } +} diff --git a/src/cpp/tests/attention_broker_server_test.cc b/src/cpp/tests/attention_broker_server_test.cc new file mode 100644 index 0000000..42856a7 --- /dev/null +++ b/src/cpp/tests/attention_broker_server_test.cc @@ -0,0 +1,84 @@ +#include +#include + +#include +#include +#include + +#include "gtest/gtest.h" + +#include "common.pb.h" +#include "attention_broker.grpc.pb.h" +#include "attention_broker.pb.h" + +#include "AttentionBrokerServer.h" +#include "Utils.h" +#include "test_utils.h" + +using namespace attention_broker_server; +using namespace commons; + +bool importance_equals(ImportanceType importance, double v2) { + double v1 = (double) importance; + return fabs(v2 - v1) < 0.001; +} + +TEST(AttentionBrokerTest, basics) { + + AttentionBrokerServer service; + dasproto::Empty empty; + dasproto::HandleCount handle_count; + dasproto::HandleList handle_list; + dasproto::Ack ack; + dasproto::ImportanceList importance_list; + ServerContext *context = NULL; + + service.ping(context, &empty, &ack); + EXPECT_EQ(ack.msg(), "PING"); + service.stimulate(context, &handle_count, &ack); + EXPECT_EQ(ack.msg(), "STIMULATE"); + service.correlate(context, &handle_list, &ack); + EXPECT_EQ(ack.msg(), "CORRELATE"); + service.get_importance(context, &handle_list, &importance_list); + EXPECT_EQ(importance_list.list_size(), 0); +} + +TEST(AttentionBrokerTest, get_importance) { + + string *handles = build_handle_space(4); + + AttentionBrokerServer service; + dasproto::HandleList handle_list0; + dasproto::HandleList handle_list1; + dasproto::HandleList handle_list2; + dasproto::HandleCount handle_count; + dasproto::Ack ack; + dasproto::ImportanceList importance_list1; + dasproto::ImportanceList importance_list2; + ServerContext *context = NULL; + + (*handle_count.mutable_map())[handles[0]] = 1; + (*handle_count.mutable_map())[handles[1]] = 1; + (*handle_count.mutable_map())["SUM"] = 2; + + handle_list0.add_list(handles[0]); + handle_list0.add_list(handles[1]); + handle_list0.add_list(handles[2]); + handle_list0.add_list(handles[3]); + handle_list1.add_list(handles[0]); + handle_list1.add_list(handles[1]); + handle_list2.add_list(handles[2]); + handle_list2.add_list(handles[3]); + + service.correlate(context, &handle_list0, &ack); + Utils::sleep(1000); + service.stimulate(context, &handle_count, &ack); + Utils::sleep(1000); + service.get_importance(context, &handle_list1, &importance_list1); + service.get_importance(context, &handle_list2, &importance_list2); + + EXPECT_TRUE(importance_list1.list(0) > 0.4); + EXPECT_TRUE(importance_list1.list(1) > 0.4); + EXPECT_TRUE(importance_list2.list(0) < 0.1); + EXPECT_TRUE(importance_list2.list(1) < 0.1); +} diff --git a/src/cpp/tests/das_node_test.cc b/src/cpp/tests/das_node_test.cc new file mode 100644 index 0000000..0644f20 --- /dev/null +++ b/src/cpp/tests/das_node_test.cc @@ -0,0 +1,154 @@ +#include +#include "gtest/gtest.h" + +#include "DASNode.h" +#include "AtomDBSingleton.h" +#include "AtomDB.h" +#include "Utils.h" + +#include "test_utils.h" + +using namespace query_engine; + +string handle_to_atom(const char *handle) { + + shared_ptr db = AtomDBSingleton::get_instance(); + shared_ptr document = db->get_atom_document(handle); + shared_ptr targets = db->query_for_targets((char *) handle); + string answer; + + if (targets != NULL) { + // is link + answer += "<"; + answer += document->get("named_type"); + answer += ": ["; + for (unsigned int i = 0; i < targets->size(); i++) { + answer += handle_to_atom(targets->get_handle(i)); + if (i < (targets->size() - 1)) { + answer += ", "; + } + } + answer += ">"; + } else { + // is node + answer += "("; + answer += document->get("named_type"); + answer += ": "; + answer += document->get("name"); + answer += ")"; + } + + return answer; +} + +void check_query( + vector &query, + unsigned int expected_count, + DASNode *das, + DASNode *requestor, + const string &context) { + + cout << "XXXXXXXXXXXXXXXX DASNode.queries CHECK BEGIN" << endl; + QueryAnswer *query_answer; + RemoteIterator *response = requestor->pattern_matcher_query(query, context); + unsigned int count = 0; + while (! response->finished()) { + while ((query_answer = response->pop()) == NULL) { + if (response->finished()) { + break; + } else { + Utils::sleep(); + } + } + if (query_answer != NULL) { + cout << "XXXXX " << query_answer->to_string() << endl; + //cout << "XXXXX " << handle_to_atom(query_answer->handles[0]) << endl; + count++; + } + } + EXPECT_EQ(count, expected_count); + delete response; + cout << "XXXXXXXXXXXXXXXX DASNode.queries CHECK END" << endl; +} + +TEST(DASNode, queries) { + + cout << "XXXXXXXXXXXXXXXX DASNode.queries BEGIN" << endl; + + setenv("DAS_REDIS_HOSTNAME", "ninjato", 1); + setenv("DAS_REDIS_PORT", "29000", 1); + setenv("DAS_USE_REDIS_CLUSTER", "false", 1); + setenv("DAS_MONGODB_HOSTNAME", "ninjato", 1); + setenv("DAS_MONGODB_PORT", "28000", 1); + setenv("DAS_MONGODB_USERNAME", "dbadmin", 1); + setenv("DAS_MONGODB_PASSWORD", "dassecret", 1); + AtomDBSingleton::init(); + + string das_id = "localhost:31700"; + string requestor_id = "localhost:31701"; + DASNode *das = new DASNode(das_id); + Utils::sleep(1000); + DASNode *requestor = new DASNode(requestor_id, das_id); + Utils::sleep(1000); + + vector q1 = { + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Similarity", + "VARIABLE", "v1", + "VARIABLE", "v2" + }; + + vector q2 = { + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Similarity", + "NODE", "Symbol", "\"human\"", + "VARIABLE", "v1" + }; + + vector q3 = { + "AND", "2", + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Similarity", + "VARIABLE", "v1", + "NODE", "Symbol", "\"human\"", + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Inheritance", + "VARIABLE", "v1", + "NODE", "Symbol", "\"plant\"" + }; + + vector q4 = { + "AND", "2", + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Similarity", + "VARIABLE", "v1", + "VARIABLE", "v2", + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Similarity", + "VARIABLE", "v2", + "VARIABLE", "v3" + }; + + vector q5 = { + "OR", "2", + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Similarity", + "VARIABLE", "v1", + "NODE", "Symbol", "\"human\"", + "LINK_TEMPLATE", "Expression", "3", + "NODE", "Symbol", "Similarity", + "VARIABLE", "v1", + "NODE", "Symbol", "\"snake\"" + }; + + check_query(q1, 14, das, requestor, "DASNode.queries"); + check_query(q2, 3, das, requestor, "DASNode.queries"); + check_query(q3, 1, das, requestor, "DASNode.queries"); + check_query(q4, 26, das, requestor, "DASNode.queries"); // TODO: FIX THIS count should be == 1 + check_query(q5, 5, das, requestor, "DASNode.queries"); + + //delete(requestor); // TODO: Uncomment this + //delete(das); // TODO: Uncomment this + + cout << "XXXXXXXXXXXXXXXX DASNode.queries END" << endl; +} diff --git a/src/cpp/tests/handle_trie_test.cc b/src/cpp/tests/handle_trie_test.cc new file mode 100644 index 0000000..10d936f --- /dev/null +++ b/src/cpp/tests/handle_trie_test.cc @@ -0,0 +1,553 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "Utils.h" +#include "expression_hasher.h" +#include "HandleTrie.h" +#include "RequestSelector.h" +#include "test_utils.h" + +#define HANDLE_SPACE_SIZE ((unsigned int) 100) + +using namespace attention_broker_server; +using namespace std; + +class TestValue: public HandleTrie::TrieValue { + public: + unsigned int count; + TestValue(int count = 1) { + this->count = count; + } + void merge(TrieValue *other) { + + } +}; + +class AccumulatorValue: public HandleTrie::TrieValue { + public: + unsigned int count; + AccumulatorValue() { + this->count = 1; + } + void merge(TrieValue *other) { + count += ((AccumulatorValue *) other)->count; + } +}; + +char R_TLB[16] = { + '0', + '1', + '2', + '3', + '4', + '5', + '6', + '7', + '8', + '9', + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', +}; + +bool visit1(HandleTrie::TrieNode *node, void *data) { + TestValue *value = (TestValue *) node->value; + value->count += *((unsigned int *) data); + return false; +} + +bool visit2(HandleTrie::TrieNode *node, void *data) { + TestValue *value = (TestValue *) node->value; + value->count += 1; + return false; +} + +void visitor3(HandleTrie *trie, unsigned int n) { + for (unsigned int i = 0; i < n; i++) { + trie->traverse(Utils::flip_coin(), &visit2, NULL); + } +} + + +TEST(HandleTrieTest, basics) { + + HandleTrie trie(4); + TestValue *value; + + trie.insert("ABCD", new TestValue(3)); + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 3); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value == NULL); + + trie.insert("ABCF", new TestValue(4)); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 4); + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 3); + + trie.insert("ABFD", new TestValue(5)); + value = (TestValue *) trie.lookup("ABFD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 5); + + trie.insert("FBCD", new TestValue(6)); + value = (TestValue *) trie.lookup("FBCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 6); + + trie.insert("AFCD", new TestValue(7)); + value = (TestValue *) trie.lookup("AFCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 7); + + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 3); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 4); + value = (TestValue *) trie.lookup("ABFD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 5); + value = (TestValue *) trie.lookup("FBCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 6); + value = (TestValue *) trie.lookup("AFCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 7); + + value = (TestValue *) trie.lookup("ABFF"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("AFCF"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("AFFD"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("FBCF"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("FBFD"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("FFCD"); + EXPECT_TRUE(value == NULL); + + unsigned int delta = 7; + trie.traverse(false, &visit1, &delta); + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 3 + delta); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 4 + delta); + value = (TestValue *) trie.lookup("ABFD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 5 + delta); + value = (TestValue *) trie.lookup("FBCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 6 + delta); + value = (TestValue *) trie.lookup("AFCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 7 + delta); +} + +TEST(HandleTrieTest, traverse) { + + HandleTrie trie(4); + TestValue *value; + + trie.insert("ABCD", new TestValue(3)); + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 3); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value == NULL); + + trie.insert("ABCF", new TestValue(4)); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 4); + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 3); + + trie.insert("ABFD", new TestValue(5)); + value = (TestValue *) trie.lookup("ABFD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 5); + + trie.insert("FBCD", new TestValue(6)); + value = (TestValue *) trie.lookup("FBCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 6); + + trie.insert("AFCD", new TestValue(7)); + value = (TestValue *) trie.lookup("AFCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 7); + + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 3); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 4); + value = (TestValue *) trie.lookup("ABFD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 5); + value = (TestValue *) trie.lookup("FBCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 6); + value = (TestValue *) trie.lookup("AFCD"); + EXPECT_TRUE(value != NULL); + EXPECT_TRUE(value->count == 7); + + value = (TestValue *) trie.lookup("ABFF"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("AFCF"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("AFFD"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("FBCF"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("FBFD"); + EXPECT_TRUE(value == NULL); + value = (TestValue *) trie.lookup("FFCD"); + EXPECT_TRUE(value == NULL); + + vector visitors; + unsigned int n_visits = 100000; + unsigned int n_threads = 32; + for (unsigned int i = 0; i < n_threads; i++) { + visitors.push_back(new thread(&visitor3, &trie, n_visits)); + } + for (thread *t: visitors) { + t->join(); + } + + value = (TestValue *) trie.lookup("ABCD"); + EXPECT_TRUE(value->count == 3 + n_visits * n_threads); + value = (TestValue *) trie.lookup("ABCF"); + EXPECT_TRUE(value->count == 4 + n_visits * n_threads); + value = (TestValue *) trie.lookup("ABFD"); + EXPECT_TRUE(value->count == 5 + n_visits * n_threads); + value = (TestValue *) trie.lookup("FBCD"); + EXPECT_TRUE(value->count == 6 + n_visits * n_threads); + value = (TestValue *) trie.lookup("AFCD"); + EXPECT_TRUE(value->count == 7 + n_visits * n_threads); +} + +TEST(HandleTrieTest, merge) { + + HandleTrie trie(4); + + trie.insert("ABCD", new AccumulatorValue()); + trie.insert("ABCF", new AccumulatorValue()); + trie.insert("ABFD", new AccumulatorValue()); + trie.insert("ABFF", new AccumulatorValue()); + trie.insert("AFCD", new AccumulatorValue()); + trie.insert("AFCF", new AccumulatorValue()); + trie.insert("AFFD", new AccumulatorValue()); + trie.insert("AFFF", new AccumulatorValue()); + trie.insert("FBCD", new AccumulatorValue()); + trie.insert("FBCF", new AccumulatorValue()); + trie.insert("FBFD", new AccumulatorValue()); + trie.insert("FBFF", new AccumulatorValue()); + trie.insert("FFCD", new AccumulatorValue()); + trie.insert("FFCF", new AccumulatorValue()); + trie.insert("FFFD", new AccumulatorValue()); + trie.insert("FFFF", new AccumulatorValue()); + + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABCD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABCF"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABFD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABFF"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFCD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFCF"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFFD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFFF"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBCD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBCF"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBFD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBFF"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFCD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFCF"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFFD"))->count == 1); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFFF"))->count == 1); + + trie.insert("ABFF", new AccumulatorValue()); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABFF"))->count == 2); + trie.insert("FFCD", new AccumulatorValue()); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFCD"))->count == 2); + + trie.insert("ABCD", new AccumulatorValue()); + trie.insert("ABCF", new AccumulatorValue()); + trie.insert("ABFD", new AccumulatorValue()); + trie.insert("ABFF", new AccumulatorValue()); + trie.insert("AFCD", new AccumulatorValue()); + trie.insert("AFCF", new AccumulatorValue()); + trie.insert("AFFD", new AccumulatorValue()); + trie.insert("AFFF", new AccumulatorValue()); + trie.insert("FBCD", new AccumulatorValue()); + trie.insert("FBCF", new AccumulatorValue()); + trie.insert("FBFD", new AccumulatorValue()); + trie.insert("FBFF", new AccumulatorValue()); + trie.insert("FFCD", new AccumulatorValue()); + trie.insert("FFCF", new AccumulatorValue()); + trie.insert("FFFD", new AccumulatorValue()); + trie.insert("FFFF", new AccumulatorValue()); + + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABCD"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABCF"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABFD"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("ABFF"))->count == 3); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFCD"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFCF"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFFD"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("AFFF"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBCD"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBCF"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBFD"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FBFF"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFCD"))->count == 3); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFCF"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFFD"))->count == 2); + EXPECT_TRUE(((AccumulatorValue *) trie.lookup("FFFF"))->count == 2); +} + +TEST(HandleTrieTest, random_stress) { + + char buffer[1000]; + map baseline; + TestValue *value; + + for (unsigned int key_size: {2, 5, 10, 100}) { + baseline.clear(); + HandleTrie *trie = new HandleTrie(key_size); + for (unsigned int i = 0; i < 100000; i++) { + for (unsigned int j = 0; j < key_size; j++) { + buffer[j] = R_TLB[(rand() % 16)]; + } + buffer[key_size] = 0; + string s = buffer; + if (baseline.find(s) == baseline.end()) { + baseline[s] = 0; + } + baseline[s] = baseline[s] + 1; + value = (TestValue *) trie->lookup(s); + if (value == NULL) { + value = new TestValue(); + trie->insert(s, value); + } else { + value->count += 1; + } + } + for (auto const& pair : baseline) { + value = (TestValue *) trie->lookup(pair.first); + EXPECT_TRUE(value != NULL); + EXPECT_EQ(pair.second, value->count); + } + delete trie; + } +} + +TEST(HandleTrieTest, hasher) { + char buffer[1000]; + map baseline; + TestValue *value; + + for (unsigned int key_count: {1, 2, 5}) { + baseline.clear(); + unsigned int key_size = (HANDLE_HASH_SIZE - 1) * key_count; + HandleTrie *trie = new HandleTrie(key_size); + for (unsigned int i = 0; i < 100000; i++) { + for (unsigned int j = 0; j < key_size; j++) { + buffer[j] = R_TLB[(rand() % 16)]; + } + buffer[key_size] = 0; + string s = buffer; + if (baseline.find(s) == baseline.end()) { + baseline[s] = 0; + } + baseline[s] = baseline[s] + 1; + value = (TestValue *) trie->lookup(s); + if (value == NULL) { + value = new TestValue(); + trie->insert(s, value); + } else { + value->count += 1; + } + } + for (auto const& pair : baseline) { + value = (TestValue *) trie->lookup(pair.first); + EXPECT_EQ(pair.second, value->count); + } + delete trie; + } +} + +TEST(HandleTrieTest, benchmark) { + char buffer[1000]; + map baseline; + TestValue *value; + StopWatch timer_std; + StopWatch timer_trie; + unsigned int n_insertions = 1000000; + + timer_std.start(); + for (unsigned int key_count: {1, 2, 5}) { + unsigned int key_size = (HANDLE_HASH_SIZE - 1) * key_count; + for (unsigned int i = 0; i < n_insertions; i++) { + for (unsigned int j = 0; j < key_size; j++) { + buffer[j] = R_TLB[(rand() % 16)]; + } + buffer[key_size] = 0; + string s = buffer; + if (baseline.find(s) == baseline.end()) { + baseline[s] = 0; + } + baseline[s] = baseline[s] + 1; + } + } + timer_std.stop(); + + timer_trie.start(); + for (unsigned int key_count: {1, 2, 5}) { + unsigned int key_size = (HANDLE_HASH_SIZE - 1) * key_count; + HandleTrie *trie = new HandleTrie(key_size); + for (unsigned int i = 0; i < n_insertions; i++) { + for (unsigned int j = 0; j < key_size; j++) { + buffer[j] = R_TLB[(rand() % 16)]; + } + buffer[key_size] = 0; + string s = buffer; + value = (TestValue *) trie->lookup(s); + if (value == NULL) { + value = new TestValue(); + trie->insert(s, value); + } else { + value->count += 1; + } + } + } + timer_trie.stop(); + cout << "=======================================================" << endl; + cout << "stdlib: " + timer_std.str_time() << endl; + cout << "trie: " + timer_trie.str_time() << endl; + cout << "=======================================================" << endl; + //EXPECT_EQ(true, false); +} + +void producer(HandleTrie *trie, unsigned int n_insertions) { + char buffer[1000]; + AccumulatorValue *value; + unsigned int key_size = HANDLE_HASH_SIZE - 1; + for (unsigned int i = 0; i < n_insertions; i++) { + for (unsigned int j = 0; j < key_size; j++) { + buffer[j] = R_TLB[(rand() % 16)]; + } + buffer[key_size] = 0; + string s = buffer; + value = new AccumulatorValue(); + trie->insert(s, value); + } +} + +void visitor(HandleTrie *trie, unsigned int n_visits) { + char buffer[1000]; + unsigned int key_size = HANDLE_HASH_SIZE - 1; + for (unsigned int i = 0; i < n_visits; i++) { + for (unsigned int j = 0; j < key_size; j++) { + buffer[j] = R_TLB[(rand() % 16)]; + } + buffer[key_size] = 0; + string s = buffer; + trie->lookup(s); + } +} + +void producer2(HandleTrie *trie, unsigned int n_insertions, string *handles) { + AccumulatorValue *value; + for (unsigned int i = 0; i < n_insertions; i++) { + string s = handles[rand() % HANDLE_SPACE_SIZE]; + value = new AccumulatorValue(); + trie->insert(s, value); + } +} + +void visitor2(HandleTrie *trie, unsigned int n_visits, string *handles) { + for (unsigned int i = 0; i < n_visits; i++) { + string s = handles[rand() % HANDLE_SPACE_SIZE]; + trie->lookup(s); + } +} + +TEST(HandleTrieTest, multithread) { + vector producers; + vector visitors; + unsigned int n_insertions = 10000; + unsigned int n_visits = 10000; + StopWatch timer; + timer.start(); + for (int n_producers: {2, 10, 100}) { + for (int n_visitors: {2, 10, 100}) { + unsigned int key_size = HANDLE_HASH_SIZE - 1; + HandleTrie *trie = new HandleTrie(key_size); + producers.clear(); + visitors.clear(); + for (int i = 0; i < n_producers; i++) { + producers.push_back(new thread(&producer, trie, n_insertions)); + } + for (int i = 0; i < n_visitors; i++) { + visitors.push_back(new thread(&visitor, trie, n_visits)); + } + for (thread *t: producers) { + t->join(); + } + for (thread *t: visitors) { + t->join(); + } + delete trie; + } + } + timer.stop(); +} + +TEST(HandleTrieTest, multithread_limited_handle_set) { + string handles[HANDLE_SPACE_SIZE]; + for (unsigned int i = 0; i < HANDLE_SPACE_SIZE; i++) { + handles[i] = random_handle(); + } + vector producers; + vector visitors; + unsigned int n_insertions = 100000; + unsigned int n_visits = 100000; + for (int n_producers: {2, 10, 100}) { + for (int n_visitors: {2, 10, 100}) { + unsigned int key_size = HANDLE_HASH_SIZE - 1; + HandleTrie *trie = new HandleTrie(key_size); + for (int i = 0; i < n_producers; i++) { + producers.push_back(new thread(&producer2, trie, n_insertions, handles)); + } + for (int i = 0; i < n_visitors; i++) { + visitors.push_back(new thread(&visitor2, trie, n_visits, handles)); + } + for (thread *t: producers) { + t->join(); + } + for (thread *t: visitors) { + t->join(); + } + delete trie; + producers.clear(); + visitors.clear(); + } + } +} diff --git a/src/cpp/tests/hebbian_network_test.cc b/src/cpp/tests/hebbian_network_test.cc new file mode 100644 index 0000000..5f5518c --- /dev/null +++ b/src/cpp/tests/hebbian_network_test.cc @@ -0,0 +1,126 @@ +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "Utils.h" +#include "test_utils.h" +#include "HebbianNetwork.h" +#include "expression_hasher.h" + +using namespace attention_broker_server; +using namespace commons; + +TEST(HebbianNetwork, basics) { + + HebbianNetwork network; + + string h1 = prefixed_random_handle("a"); + string h2 = prefixed_random_handle("b"); + string h3 = prefixed_random_handle("d"); + string h4 = prefixed_random_handle("d"); + string h5 = prefixed_random_handle("e"); + + HebbianNetwork::Node *n1 = network.add_node(h1); + HebbianNetwork::Node *n2 = network.add_node(h2); + HebbianNetwork::Node *n3 = network.add_node(h3); + HebbianNetwork::Node *n4 = network.add_node(h4); + + EXPECT_TRUE(network.get_node_count(h1) == 1); + EXPECT_TRUE(network.get_node_count(h2) == 1); + EXPECT_TRUE(network.get_node_count(h3) == 1); + EXPECT_TRUE(network.get_node_count(h4) == 1); + EXPECT_TRUE(network.get_node_count(h5) == 0); + network.add_node(h5); + EXPECT_TRUE(network.get_node_count(h5) == 1); + + network.add_symmetric_edge(h1, h2, n1, n2); + network.add_symmetric_edge(h1, h3, n1, n3); + network.add_symmetric_edge(h1, h4, n1, n4); + network.add_symmetric_edge(h1, h2, n1, n2); + + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h2) == 2); + EXPECT_TRUE(network.get_asymmetric_edge_count(h2, h1) == 2); + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h3) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h3, h1) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h4) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h4, h1) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h5) == 0); + EXPECT_TRUE(network.get_asymmetric_edge_count(h5, h1) == 0); +} + +TEST(HebbianNetwork, stress) { + + HebbianNetwork network; + StopWatch timer_insertion; + StopWatch timer_lookup; + StopWatch timer_total; + unsigned int handle_space_size = 500; + unsigned int num_insertions = (handle_space_size * 2) * (handle_space_size * 2); + unsigned int num_lookups = 10 * num_insertions; + + string *handles = build_handle_space(handle_space_size); + + timer_insertion.start(); + timer_total.start(); + + for (unsigned int i = 0; i < num_insertions; i++) { + string h1 = handles[rand() % handle_space_size]; + string h2 = handles[rand() % handle_space_size]; + HebbianNetwork::Node *n1 = network.add_node(h1); + HebbianNetwork::Node *n2 = network.add_node(h2); + network.add_symmetric_edge(h1, h2, n1, n2); + } + + timer_insertion.stop(); + timer_lookup.start(); + + for (unsigned int i = 0; i < num_lookups; i++) { + string h1 = handles[rand() % handle_space_size]; + string h2 = handles[rand() % handle_space_size]; + network.get_node_count(h1); + network.get_node_count(h2); + network.get_asymmetric_edge_count(h1, h2); + } + + timer_lookup.stop(); + timer_total.stop(); + + cout << "==================================================================" << endl; + cout << "Insertions: " << timer_insertion.str_time() << endl; + cout << "Lookups: " << timer_lookup.str_time() << endl; + cout << "Total: " << timer_total.str_time() << endl; + cout << "==================================================================" << endl; + //EXPECT_TRUE(false); +} + +TEST(HebbianNetwork, alienate_tokens) { + HebbianNetwork network; + EXPECT_TRUE(network.alienate_tokens() == 1.0); + EXPECT_TRUE(network.alienate_tokens() == 0.0); + EXPECT_TRUE(network.alienate_tokens() == 0.0); +} + +bool visit1(HandleTrie::TrieNode *node, void *data) { + ((HebbianNetwork::Node *) node->value)->importance = 1.0; + return false; +} + +bool visit2( + HandleTrie::TrieNode *node, + HebbianNetwork::Node *source, + forward_list &targets, + unsigned int targets_size, + ImportanceType sum_weights, + void *data) { + + unsigned int fan_max = *((unsigned int *) data); + double stimulus = 1.0 / (double) fan_max; + for (auto target: targets) { + target->importance += stimulus; + source->importance -= stimulus; + } + return false; +} diff --git a/src/cpp/tests/hebbian_network_updater_test.cc b/src/cpp/tests/hebbian_network_updater_test.cc new file mode 100644 index 0000000..5fd4db7 --- /dev/null +++ b/src/cpp/tests/hebbian_network_updater_test.cc @@ -0,0 +1,75 @@ +#include +#include + +#include "gtest/gtest.h" +#include "common.pb.h" +#include "attention_broker.grpc.pb.h" +#include "attention_broker.pb.h" +#include "test_utils.h" +#include "expression_hasher.h" +#include "HebbianNetwork.h" +#include "HebbianNetworkUpdater.h" + +using namespace attention_broker_server; + +TEST(HebbianNetworkUpdater, correlation) { + string *handles = build_handle_space(6); + HebbianNetwork *network = new HebbianNetwork(); + dasproto::HandleList *request; + ExactCountHebbianUpdater *updater = \ + (ExactCountHebbianUpdater *) HebbianNetworkUpdater::factory(HebbianNetworkUpdaterType::EXACT_COUNT); + + request = new dasproto::HandleList(); + request->set_hebbian_network((unsigned long) network); + request->add_list(handles[0]); + request->add_list(handles[1]); + request->add_list(handles[2]); + request->add_list(handles[3]); + updater->correlation(request); + + EXPECT_TRUE(network->get_node_count(handles[0]) == 1); + EXPECT_TRUE(network->get_node_count(handles[1]) == 1); + EXPECT_TRUE(network->get_node_count(handles[2]) == 1); + EXPECT_TRUE(network->get_node_count(handles[3]) == 1); + EXPECT_TRUE(network->get_node_count(handles[4]) == 0); + EXPECT_TRUE(network->get_node_count(handles[5]) == 0); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[0], handles[1]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[0], handles[2]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[0], handles[3]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[2]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[3]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[2], handles[3]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[2]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[4]) == 0); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[5]) == 0); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[2], handles[4]) == 0); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[2], handles[5]) == 0); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[4], handles[5]) == 0); + + request = new dasproto::HandleList(); + request->set_hebbian_network((unsigned long) network); + request->add_list(handles[1]); + request->add_list(handles[2]); + request->add_list(handles[4]); + request->add_list(handles[5]); + updater->correlation(request); + + EXPECT_TRUE(network->get_node_count(handles[0]) == 1); + EXPECT_TRUE(network->get_node_count(handles[1]) == 2); + EXPECT_TRUE(network->get_node_count(handles[2]) == 2); + EXPECT_TRUE(network->get_node_count(handles[3]) == 1); + EXPECT_TRUE(network->get_node_count(handles[4]) == 1); + EXPECT_TRUE(network->get_node_count(handles[5]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[0], handles[1]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[0], handles[2]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[0], handles[3]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[2]) == 2); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[3]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[2], handles[3]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[2]) == 2); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[4]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[1], handles[5]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[2], handles[4]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[2], handles[5]) == 1); + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[4], handles[5]) == 1); +} diff --git a/src/cpp/tests/iterator_test.cc b/src/cpp/tests/iterator_test.cc new file mode 100644 index 0000000..73a767e --- /dev/null +++ b/src/cpp/tests/iterator_test.cc @@ -0,0 +1,123 @@ +#include +#include "gtest/gtest.h" + +#include "QueryNode.h" +#include "LinkTemplate.h" +#include "AtomDBSingleton.h" +#include "test_utils.h" +#include "Iterator.h" + +using namespace query_engine; +using namespace query_element; +using namespace query_node; + +class TestQueryElement : public QueryElement { + public: + TestQueryElement(const string &id) { + this->id = id; + } + void setup_buffers() {} + void graceful_shutdown() {} +}; + +TEST(Iterator, basics) { + string client_id = "no_query_element"; + TestQueryElement dummy(client_id); + Iterator query_answer_iterator(&dummy); + string server_id = query_answer_iterator.id; + EXPECT_FALSE(server_id == ""); + QueryNodeClient client_node(client_id, query_answer_iterator.id); + + EXPECT_FALSE(query_answer_iterator.finished()); + + QueryAnswer *qa; + QueryAnswer qa0("h0", 0.0); + QueryAnswer qa1("h1", 0.1); + QueryAnswer qa2("h2", 0.2); + + client_node.add_query_answer(&qa0); + client_node.add_query_answer(&qa1); + Utils::sleep(1000); + + EXPECT_FALSE(query_answer_iterator.finished()); + qa = query_answer_iterator.pop(); + EXPECT_TRUE(strcmp(qa->handles[0], "h0") == 0); + EXPECT_TRUE(double_equals(qa->importance, 0.0)); + + EXPECT_FALSE(query_answer_iterator.finished()); + qa = query_answer_iterator.pop(); + EXPECT_TRUE(strcmp(qa->handles[0], "h1") == 0); + EXPECT_TRUE(double_equals(qa->importance, 0.1)); + + qa = query_answer_iterator.pop(); + EXPECT_TRUE(qa == NULL); + EXPECT_FALSE(query_answer_iterator.finished()); + + client_node.add_query_answer(&qa2); + EXPECT_FALSE(client_node.is_query_answers_finished()); + EXPECT_FALSE(query_answer_iterator.finished()); + client_node.query_answers_finished(); + EXPECT_TRUE(client_node.is_query_answers_finished()); + EXPECT_FALSE(query_answer_iterator.finished()); + Utils::sleep(1000); + + EXPECT_FALSE(query_answer_iterator.finished()); + qa = query_answer_iterator.pop(); + EXPECT_TRUE(strcmp(qa->handles[0], "h2") == 0); + EXPECT_TRUE(double_equals(qa->importance, 0.2)); + EXPECT_TRUE(query_answer_iterator.finished()); +} + +TEST(Iterator, link_template_integration) { + + setenv("DAS_REDIS_HOSTNAME", "ninjato", 1); + setenv("DAS_REDIS_PORT", "29000", 1); + setenv("DAS_USE_REDIS_CLUSTER", "false", 1); + setenv("DAS_MONGODB_HOSTNAME", "ninjato", 1); + setenv("DAS_MONGODB_PORT", "28000", 1); + setenv("DAS_MONGODB_USERNAME", "dbadmin", 1); + setenv("DAS_MONGODB_PASSWORD", "dassecret", 1); + + AtomDBSingleton::init(); + string expression = "Expression"; + string symbol = "Symbol"; + + Variable v1("v1"); + Variable v2("v2"); + Variable v3("v3"); + Node similarity(symbol, "Similarity"); + Node human(symbol, "\"human\""); + + LinkTemplate<3> link_template("Expression", {&similarity, &human, &v1}); + Iterator query_answer_iterator(&link_template); + + string monkey_handle = string(terminal_hash((char *) symbol.c_str(), (char *) "\"monkey\"")); + string chimp_handle = string(terminal_hash((char *) symbol.c_str(), (char *) "\"chimp\"")); + string ent_handle = string(terminal_hash((char *) symbol.c_str(), (char *) "\"ent\"")); + bool monkey_flag = false; + bool chimp_flag = false; + bool ent_flag = false; + QueryAnswer *query_answer; + while (! query_answer_iterator.finished()) { + query_answer = query_answer_iterator.pop(); + if (query_answer != NULL) { + string var = string(query_answer->assignment.get("v1")); + //EXPECT_TRUE(double_equals(query_answer->importance, 0.0)); + if (var == monkey_handle) { + // TODO: perform extra checks + monkey_flag = true; + } else if (var == chimp_handle) { + // TODO: perform extra checks + chimp_flag = true; + } else if (var == ent_handle) { + // TODO: perform extra checks + ent_flag = true; + } else { + FAIL(); + } + } + } + EXPECT_TRUE(monkey_flag); + EXPECT_TRUE(chimp_flag); + EXPECT_TRUE(ent_flag); +} diff --git a/src/cpp/tests/link_template_test.cc b/src/cpp/tests/link_template_test.cc new file mode 100644 index 0000000..1d7f50f --- /dev/null +++ b/src/cpp/tests/link_template_test.cc @@ -0,0 +1,67 @@ +#include +#include "gtest/gtest.h" + +#include "QueryNode.h" +#include "LinkTemplate.h" +#include "AtomDBSingleton.h" +#include "test_utils.h" + +using namespace query_engine; +using namespace query_element; + +TEST(LinkTemplate, basics) { + + setenv("DAS_REDIS_HOSTNAME", "ninjato", 1); + setenv("DAS_REDIS_PORT", "29000", 1); + setenv("DAS_USE_REDIS_CLUSTER", "false", 1); + setenv("DAS_MONGODB_HOSTNAME", "ninjato", 1); + setenv("DAS_MONGODB_PORT", "28000", 1); + setenv("DAS_MONGODB_USERNAME", "dbadmin", 1); + setenv("DAS_MONGODB_PASSWORD", "dassecret", 1); + + string server_node_id = "SERVER"; + QueryNodeServer server_node(server_node_id); + + AtomDBSingleton::init(); + string expression = "Expression"; + string symbol = "Symbol"; + + Variable v1("v1"); + Variable v2("v2"); + Variable v3("v3"); + Node similarity(symbol, "Similarity"); + Node human(symbol, "\"human\""); + + LinkTemplate<3> link_template1("Expression", {&similarity, &human, &v1}); + link_template1.subsequent_id = server_node_id; + link_template1.setup_buffers(); + //link_template1.fetch_links(); + Utils::sleep(1000); + + string monkey_handle = string(terminal_hash((char *) symbol.c_str(), (char *) "\"monkey\"")); + string chimp_handle = string(terminal_hash((char *) symbol.c_str(), (char *) "\"chimp\"")); + string ent_handle = string(terminal_hash((char *) symbol.c_str(), (char *) "\"ent\"")); + bool monkey_flag = false; + bool chimp_flag = false; + bool ent_flag = false; + QueryAnswer *query_answer; + while ((query_answer = server_node.pop_query_answer()) != NULL) { + string var = string(query_answer->assignment.get("v1")); + //EXPECT_TRUE(double_equals(query_answer->importance, 0.0)); + if (var == monkey_handle) { + // TODO: perform extra checks + monkey_flag = true; + } else if (var == chimp_handle) { + // TODO: perform extra checks + chimp_flag = true; + } else if (var == ent_handle) { + // TODO: perform extra checks + ent_flag = true; + } else { + FAIL(); + } + } + EXPECT_TRUE(monkey_flag); + EXPECT_TRUE(chimp_flag); + EXPECT_TRUE(ent_flag); +} diff --git a/src/cpp/tests/nested_link_template_test.cc b/src/cpp/tests/nested_link_template_test.cc new file mode 100644 index 0000000..be460e9 --- /dev/null +++ b/src/cpp/tests/nested_link_template_test.cc @@ -0,0 +1,81 @@ +#include +#include "gtest/gtest.h" + +#include "QueryNode.h" +#include "AtomDB.h" +#include "LinkTemplate.h" +#include "AtomDBSingleton.h" +#include "test_utils.h" + +using namespace query_engine; +using namespace query_element; + +TEST(LinkTemplate, basics) { + + setenv("DAS_REDIS_HOSTNAME", "ninjato", 1); + setenv("DAS_REDIS_PORT", "29000", 1); + setenv("DAS_USE_REDIS_CLUSTER", "false", 1); + setenv("DAS_MONGODB_HOSTNAME", "ninjato", 1); + setenv("DAS_MONGODB_PORT", "28000", 1); + setenv("DAS_MONGODB_USERNAME", "dbadmin", 1); + setenv("DAS_MONGODB_PASSWORD", "dassecret", 1); + + AtomDBSingleton::init(); + string expression = "Expression"; + string symbol = "Symbol"; + + Variable v1("v1"); + Variable v2("v2"); + Variable v3("v3"); + Node similarity(symbol, "Similarity"); + Node odd_link(symbol, "OddLink"); + + LinkTemplate<3> inner_template(expression, {&similarity, &v1, &v2}); + LinkTemplate<2> outter_template(expression, {&odd_link, &inner_template}); + + Iterator iterator(&outter_template); + + QueryAnswer *query_answer; + unsigned int count = 0; + while (! iterator.finished()) { + if ((query_answer = iterator.pop()) == NULL) { + Utils::sleep(); + } else { + EXPECT_TRUE(double_equals(query_answer->importance, 0.0)); + count++; + } + } + EXPECT_EQ(count, 8); +} + +TEST(LinkTemplate, nested_variables) { + + string expression = "Expression"; + string symbol = "Symbol"; + + Variable v1("v1"); + Variable v2("v2"); + Variable v3("v3"); + Node similarity(symbol, "Similarity"); + Node odd_link(symbol, "OddLink"); + Node human(symbol, "\"human\""); + + LinkTemplate<3> inner_template(expression, {&similarity, &v1, &v2}); + LinkTemplate<2> outter_template(expression, {&odd_link, &inner_template}); + LinkTemplate<3> human_template(expression, {&similarity, &v1, &human}); + And<2> and_operator({&human_template, &outter_template}); + + Iterator iterator(&and_operator); + + QueryAnswer *query_answer; + unsigned int count = 0; + while (! iterator.finished()) { + if ((query_answer = iterator.pop()) == NULL) { + Utils::sleep(); + } else { + //EXPECT_TRUE(double_equals(query_answer->importance, 0.0)); + count++; + } + } + EXPECT_EQ(count, 1); +} diff --git a/src/cpp/tests/query_answer_test.cc b/src/cpp/tests/query_answer_test.cc new file mode 100644 index 0000000..4ff839d --- /dev/null +++ b/src/cpp/tests/query_answer_test.cc @@ -0,0 +1,148 @@ +#include +#include +#include "gtest/gtest.h" + +#include "QueryAnswer.h" +#include "Utils.h" +#include "test_utils.h" + +using namespace query_engine; +using namespace commons; + +TEST(QueryAnswer, assignments_basics) { + + Assignment mapping0; + + // Tests assign() + Assignment mapping1; + EXPECT_TRUE(mapping1.assign("v1", "1")); + EXPECT_TRUE(mapping1.assign("v2", "2")); + EXPECT_TRUE(mapping1.assign("v2", "2")); + EXPECT_FALSE(mapping1.assign("v2", "3")); + Assignment mapping2; + EXPECT_TRUE(mapping1.assign("v1", "1")); + EXPECT_TRUE(mapping1.assign("v3", "3")); + Assignment mapping3; + EXPECT_TRUE(mapping3.assign("v1", "1")); + EXPECT_TRUE(mapping3.assign("v2", "3")); + + // Tests get() + EXPECT_TRUE(strcmp(mapping1.get("v1"), "1") == 0); + EXPECT_TRUE(strcmp(mapping1.get("v2"), "2") == 0); + EXPECT_TRUE(mapping1.get("blah") == NULL); + EXPECT_TRUE(mapping1.get("v11") == NULL); + EXPECT_TRUE(mapping1.get("v") == NULL); + EXPECT_TRUE(mapping1.get("") == NULL); + + // Tests is_compatible() + EXPECT_TRUE(mapping1.is_compatible(mapping0)); + EXPECT_TRUE(mapping2.is_compatible(mapping0)); + EXPECT_TRUE(mapping3.is_compatible(mapping0)); + EXPECT_TRUE(mapping0.is_compatible(mapping1)); + EXPECT_TRUE(mapping0.is_compatible(mapping2)); + EXPECT_TRUE(mapping0.is_compatible(mapping3)); + EXPECT_TRUE(mapping1.is_compatible(mapping2)); + EXPECT_TRUE(mapping2.is_compatible(mapping1)); + EXPECT_TRUE(mapping2.is_compatible(mapping3)); + EXPECT_TRUE(mapping3.is_compatible(mapping2)); + EXPECT_FALSE(mapping1.is_compatible(mapping3)); + EXPECT_FALSE(mapping3.is_compatible(mapping1)); + + // Tests copy_from() + Assignment mapping4; + mapping4.copy_from(mapping1); + EXPECT_TRUE(strcmp(mapping4.get("v1"), "1") == 0); + EXPECT_TRUE(strcmp(mapping4.get("v2"), "2") == 0); + EXPECT_TRUE(mapping4.is_compatible(mapping2)); + EXPECT_TRUE(mapping2.is_compatible(mapping4)); + EXPECT_FALSE(mapping4.is_compatible(mapping3)); + EXPECT_FALSE(mapping3.is_compatible(mapping4)); + + // Tests add_assignments() + mapping4.add_assignments(mapping1); + mapping4.add_assignments(mapping2); + EXPECT_TRUE(strcmp(mapping4.get("v1"), "1") == 0); + EXPECT_TRUE(strcmp(mapping4.get("v2"), "2") == 0); + EXPECT_TRUE(strcmp(mapping4.get("v3"), "3") == 0); + EXPECT_TRUE(mapping1.is_compatible(mapping4)); + EXPECT_TRUE(mapping2.is_compatible(mapping4)); + EXPECT_FALSE(mapping3.is_compatible(mapping4)); + EXPECT_TRUE(mapping4.is_compatible(mapping1)); + EXPECT_TRUE(mapping4.is_compatible(mapping2)); + EXPECT_FALSE(mapping4.is_compatible(mapping3)); + + // Tests to_string(): + EXPECT_TRUE(mapping0.to_string() != ""); + EXPECT_TRUE(mapping1.to_string() != ""); + EXPECT_TRUE(mapping4.to_string() != ""); +} + +TEST(QueryAnswer, query_answer_basics) { + + // Tests add_handle() + QueryAnswer query_answer1("h1", 0); + query_answer1.assignment.assign("v1", "1"); + EXPECT_EQ(query_answer1.handles_size, 1); + EXPECT_TRUE(strcmp(query_answer1.handles[0], "h1") == 0); + query_answer1.add_handle("hx"); + EXPECT_EQ(query_answer1.handles_size, 2); + EXPECT_TRUE(strcmp(query_answer1.handles[0], "h1") == 0); + EXPECT_TRUE(strcmp(query_answer1.handles[1], "hx") == 0); + + // Tests merge() + QueryAnswer query_answer2("h2", 0); + query_answer2.assignment.assign("v2", "2"); + query_answer2.add_handle("hx"); + query_answer2.merge(&query_answer1); + EXPECT_EQ(query_answer2.handles_size, 3); + EXPECT_TRUE(strcmp(query_answer2.handles[0], "h2") == 0); + EXPECT_TRUE(strcmp(query_answer2.handles[1], "hx") == 0); + EXPECT_TRUE(strcmp(query_answer2.handles[2], "h1") == 0); + EXPECT_FALSE(query_answer2.assignment.assign("v1", "x")); + EXPECT_FALSE(query_answer2.assignment.assign("v2", "x")); + EXPECT_TRUE(query_answer2.assignment.assign("v3", "x")); + + // Tests copy() + QueryAnswer *query_answer3 = QueryAnswer::copy(&query_answer2); + EXPECT_EQ(query_answer3->handles_size, 3); + EXPECT_TRUE(strcmp(query_answer3->handles[0], "h2") == 0); + EXPECT_TRUE(strcmp(query_answer3->handles[1], "hx") == 0); + EXPECT_TRUE(strcmp(query_answer3->handles[2], "h1") == 0); + EXPECT_FALSE(query_answer3->assignment.assign("v1", "x")); + EXPECT_FALSE(query_answer3->assignment.assign("v2", "x")); + EXPECT_FALSE(query_answer3->assignment.assign("v3", "y")); + EXPECT_TRUE(query_answer3->assignment.assign("v4", "x")); +} + +void query_answers_equal(QueryAnswer *qa1, QueryAnswer *qa2) { + EXPECT_TRUE(double_equals(qa1->importance, qa2->importance)); + EXPECT_EQ(qa1->to_string(), qa2->to_string()); +} + +TEST(QueryAnswer, tokenization) { + + unsigned int NUM_TESTS = 100000; + unsigned int MAX_HANDLES = 5; + unsigned int MAX_ASSIGNMENTS = 10; + + for (unsigned int test = 0; test < NUM_TESTS; test++) { + unsigned int num_handles = (rand() % MAX_HANDLES) + 1; + unsigned int num_assignments = (rand() % MAX_ASSIGNMENTS); + QueryAnswer input(Utils::flip_coin() ? 1 : 0); + for (unsigned int i = 0; i < num_handles; i++) { + input.add_handle(strdup(random_handle().c_str())); + } + unsigned int label_count = 0; + for (unsigned int i = 0; i < num_assignments; i++) { + input.assignment.assign( + strdup(sequential_label(label_count).c_str()), + strdup(random_handle().c_str())); + } + + query_answers_equal(&input, QueryAnswer::copy(&input)); + string token_string = input.tokenize(); + QueryAnswer output(0.0); + output.untokenize(token_string); + query_answers_equal(&input, &output); + } +} diff --git a/src/cpp/tests/query_node_test.cc b/src/cpp/tests/query_node_test.cc new file mode 100644 index 0000000..5c08c4f --- /dev/null +++ b/src/cpp/tests/query_node_test.cc @@ -0,0 +1,69 @@ +#include "gtest/gtest.h" + +#include "Utils.h" +#include "QueryNode.h" +#include "QueryAnswer.h" + +using namespace commons; +using namespace query_node; + +TEST(QueryNode, basics) { + + string server_id = "server"; + string client1_id = "client1"; + string client2_id = "client2"; + + QueryNodeServer server(server_id); + QueryNodeClient client1(client1_id, server_id); + QueryNodeClient client2(client2_id, server_id); + + EXPECT_TRUE(server.is_query_answers_empty()); + EXPECT_FALSE(server.is_query_answers_finished()); + EXPECT_TRUE(client1.is_query_answers_empty()); + EXPECT_FALSE(client1.is_query_answers_finished()); + EXPECT_TRUE(client2.is_query_answers_empty()); + EXPECT_FALSE(client2.is_query_answers_finished()); + + ASSERT_TRUE(server.pop_query_answer() == (QueryAnswer *) 0); + + client1.add_query_answer((QueryAnswer *) 1); + client1.add_query_answer((QueryAnswer *) 2); + Utils::sleep(1000); + client2.add_query_answer((QueryAnswer *) 3); + client2.add_query_answer((QueryAnswer *) 4); + Utils::sleep(1000); + client1.add_query_answer((QueryAnswer *) 5); + Utils::sleep(1000); + + EXPECT_FALSE(server.is_query_answers_empty()); + EXPECT_FALSE(server.is_query_answers_finished()); + EXPECT_TRUE(client1.is_query_answers_empty()); + EXPECT_FALSE(client1.is_query_answers_finished()); + EXPECT_TRUE(client2.is_query_answers_empty()); + EXPECT_FALSE(client2.is_query_answers_finished()); + + client1.query_answers_finished(); + client2.query_answers_finished(); + Utils::sleep(1000); + + EXPECT_FALSE(server.is_query_answers_empty()); + EXPECT_TRUE(server.is_query_answers_finished()); + EXPECT_TRUE(client1.is_query_answers_empty()); + EXPECT_TRUE(client1.is_query_answers_finished()); + EXPECT_TRUE(client2.is_query_answers_empty()); + EXPECT_TRUE(client2.is_query_answers_finished()); + + ASSERT_TRUE(server.pop_query_answer() == (QueryAnswer *) 1); + ASSERT_TRUE(server.pop_query_answer() == (QueryAnswer *) 2); + ASSERT_TRUE(server.pop_query_answer() == (QueryAnswer *) 3); + ASSERT_TRUE(server.pop_query_answer() == (QueryAnswer *) 4); + ASSERT_TRUE(server.pop_query_answer() == (QueryAnswer *) 5); + ASSERT_TRUE(server.pop_query_answer() == (QueryAnswer *) 0); + + EXPECT_TRUE(server.is_query_answers_empty()); + EXPECT_TRUE(server.is_query_answers_finished()); + EXPECT_TRUE(client1.is_query_answers_empty()); + EXPECT_TRUE(client1.is_query_answers_finished()); + EXPECT_TRUE(client2.is_query_answers_empty()); + EXPECT_TRUE(client2.is_query_answers_finished()); +} diff --git a/src/cpp/tests/remote_sink_iterator_test.cc b/src/cpp/tests/remote_sink_iterator_test.cc new file mode 100644 index 0000000..311169e --- /dev/null +++ b/src/cpp/tests/remote_sink_iterator_test.cc @@ -0,0 +1,72 @@ +#include +#include "gtest/gtest.h" + +#include "RemoteSink.h" +#include "Source.h" +#include "RemoteIterator.h" +#include "AtomDBSingleton.h" +#include "test_utils.h" + +using namespace query_engine; +using namespace query_element; +using namespace query_node; + +class TestSource : public Source { + public: + TestSource(const string &id) { + this->id = id; + } + void add(QueryAnswer *qa) { + this->output_buffer->add_query_answer(qa); + Utils::sleep(1000); + } + void finished() { + this->output_buffer->query_answers_finished(); + Utils::sleep(1000); + } +}; + +TEST(RemoteSinkIterator, basics) { + string consumer_id = "localhost:30700"; + string producer_id = "localhost:30701"; + + string input_element_id = "test_source"; + TestSource input(input_element_id); + RemoteIterator consumer(consumer_id); + RemoteSink producer(&input, producer_id, consumer_id); + Utils::sleep(1000); + + EXPECT_FALSE(consumer.finished()); + + QueryAnswer *qa; + QueryAnswer qa0("h0", 0.0); + QueryAnswer qa1("h1", 0.1); + QueryAnswer qa2("h2", 0.2); + + input.add(&qa0); + input.add(&qa1); + + EXPECT_FALSE(consumer.finished()); + EXPECT_FALSE((qa = consumer.pop()) == NULL); + EXPECT_TRUE(strcmp(qa->handles[0], "h0") == 0); + EXPECT_TRUE(double_equals(qa->importance, 0.0)); + + EXPECT_FALSE(consumer.finished()); + EXPECT_FALSE((qa = consumer.pop()) == NULL); + EXPECT_TRUE(strcmp(qa->handles[0], "h1") == 0); + EXPECT_TRUE(double_equals(qa->importance, 0.1)); + + EXPECT_TRUE((qa = consumer.pop()) == NULL); + EXPECT_FALSE(consumer.finished()); + + input.add(&qa2); + input.finished(); + EXPECT_FALSE(consumer.finished()); + + EXPECT_FALSE(consumer.finished()); + EXPECT_FALSE((qa = consumer.pop()) == NULL); + EXPECT_TRUE(strcmp(qa->handles[0], "h2") == 0); + EXPECT_TRUE(double_equals(qa->importance, 0.2)); + Utils::sleep(5000); // XXXXXXXXXXXXXXXXXXXXXXXXXXX + EXPECT_TRUE(consumer.finished()); +} diff --git a/src/cpp/tests/request_selector_test.cc b/src/cpp/tests/request_selector_test.cc new file mode 100644 index 0000000..2376d5d --- /dev/null +++ b/src/cpp/tests/request_selector_test.cc @@ -0,0 +1,46 @@ +#include "gtest/gtest.h" + +#include "Utils.h" +#include "SharedQueue.h" +#include "RequestSelector.h" + +using namespace attention_broker_server; + +class TestMessage { + public: + int message; + TestMessage(int n) { + message = n; + } +}; + +TEST(RequestSelectorTest, even_thread_count) { + + SharedQueue *stimulus = new SharedQueue(1); + SharedQueue *correlation = new SharedQueue(1); + + RequestSelector *selector0 = RequestSelector::factory( + SelectorType::EVEN_THREAD_COUNT, + 0, + stimulus, + correlation); + RequestSelector *selector1 = RequestSelector::factory( + SelectorType::EVEN_THREAD_COUNT, + 1, + stimulus, + correlation); + + pair request; + for (int i = 0; i < 1000; i++) { + if (Utils::flip_coin()) { + request = selector0->next(); + EXPECT_EQ(request.first, RequestType::STIMULUS); + } else { + request = selector1->next(); + EXPECT_EQ(request.first, RequestType::CORRELATION); + } + } + + delete stimulus; + delete correlation; +} diff --git a/src/cpp/tests/shared_queue_test.cc b/src/cpp/tests/shared_queue_test.cc new file mode 100644 index 0000000..651e37d --- /dev/null +++ b/src/cpp/tests/shared_queue_test.cc @@ -0,0 +1,74 @@ +#include "gtest/gtest.h" + +#include "AttentionBrokerServer.h" + +using namespace attention_broker_server; + +class TestSharedQueue: public SharedQueue { + public: + TestSharedQueue(unsigned int n) : SharedQueue(n) { + } + unsigned int test_current_size() { + return current_size(); + } + unsigned int test_current_start() { + return current_start(); + } + unsigned int test_current_end() { + return current_end(); + } +}; + +class TestMessage { + public: + int message; + TestMessage(int n) { + message = n; + } +}; + +TEST(SharedQueueTest, basics) { + + dasproto::Empty empty; + dasproto::Ack ack; + + TestSharedQueue q1((unsigned int) 5); + EXPECT_TRUE(q1.test_current_size() == 5); + q1.enqueue((void *) "1"); + EXPECT_EQ((char *) q1.dequeue(), "1"); + q1.enqueue((void *) "2"); + q1.enqueue((void *) "3"); + q1.enqueue((void *) "4"); + q1.enqueue((void *) "5"); + EXPECT_TRUE(q1.test_current_size() == 5); + EXPECT_EQ((char *) q1.dequeue(), "2"); + q1.enqueue((void *) "6"); + q1.enqueue((void *) "7"); + EXPECT_TRUE(q1.test_current_size() == 5); + EXPECT_EQ((char *) q1.dequeue(), "3"); + q1.enqueue((void *) "8"); + EXPECT_EQ((char *) q1.dequeue(), "4"); + q1.enqueue((void *) "9"); + EXPECT_EQ((char *) q1.dequeue(), "5"); + q1.enqueue((void *) "10"); + EXPECT_TRUE(q1.test_current_size() == 5); + EXPECT_EQ((char *) q1.dequeue(), "6"); + EXPECT_EQ((char *) q1.dequeue(), "7"); + q1.enqueue((void *) "11"); + q1.enqueue((void *) "12"); + EXPECT_TRUE(q1.test_current_size() == 5); + EXPECT_TRUE(q1.test_current_start() == 2); + EXPECT_TRUE(q1.test_current_end() == 2); + q1.enqueue((void *) "13"); + EXPECT_TRUE(q1.test_current_size() == 10); + EXPECT_TRUE(q1.test_current_start() == 0); + EXPECT_TRUE(q1.test_current_end() == 6); + q1.enqueue((void *) "14"); + EXPECT_EQ((char *) q1.dequeue(), "8"); + EXPECT_EQ((char *) q1.dequeue(), "9"); + EXPECT_EQ((char *) q1.dequeue(), "10"); + EXPECT_EQ((char *) q1.dequeue(), "11"); + EXPECT_EQ((char *) q1.dequeue(), "12"); + EXPECT_EQ((char *) q1.dequeue(), "13"); + EXPECT_EQ((char *) q1.dequeue(), "14"); +} diff --git a/src/cpp/tests/stimulus_spreader_test.cc b/src/cpp/tests/stimulus_spreader_test.cc new file mode 100644 index 0000000..6e99f89 --- /dev/null +++ b/src/cpp/tests/stimulus_spreader_test.cc @@ -0,0 +1,229 @@ +#include +#include + +#include "gtest/gtest.h" +#include "common.pb.h" +#include "attention_broker.grpc.pb.h" +#include "attention_broker.pb.h" +#include "test_utils.h" +#include "expression_hasher.h" +#include "AttentionBrokerServer.h" +#include "HebbianNetwork.h" +#include "HebbianNetworkUpdater.h" +#include "StimulusSpreader.h" + +using namespace attention_broker_server; + +bool importance_equals(ImportanceType importance, double v2) { + double v1 = (double) importance; + return fabs(v2 - v1) < 0.001; +} + +TEST(TokenSpreader, distribute_wages) { + + unsigned int num_tests = 10000; + unsigned int total_nodes = 100; + + TokenSpreader *spreader; + ImportanceType tokens_to_spread; + dasproto::HandleCount *request; + TokenSpreader::StimuliData data; + + for (unsigned int i = 0; i < num_tests; i++) { + string *handles = build_handle_space(total_nodes); + spreader = (TokenSpreader *) StimulusSpreader::factory(StimulusSpreaderType::TOKEN); + + tokens_to_spread = 1.0; + request = new dasproto::HandleCount(); + (*request->mutable_map())[handles[0]] = 2; + (*request->mutable_map())[handles[1]] = 1; + (*request->mutable_map())[handles[2]] = 2; + (*request->mutable_map())[handles[3]] = 1; + (*request->mutable_map())[handles[4]] = 2; + (*request->mutable_map())["SUM"] = 8; + data.importance_changes = new HandleTrie(HANDLE_HASH_SIZE - 1); + spreader->distribute_wages(request, tokens_to_spread, &data); + + EXPECT_TRUE(importance_equals(((TokenSpreader::ImportanceChanges *) data.importance_changes->lookup(handles[0]))->wages, 0.250)); + EXPECT_TRUE(importance_equals(((TokenSpreader::ImportanceChanges *) data.importance_changes->lookup(handles[1]))->wages, 0.125)); + EXPECT_TRUE(importance_equals(((TokenSpreader::ImportanceChanges *) data.importance_changes->lookup(handles[2]))->wages, 0.250)); + EXPECT_TRUE(importance_equals(((TokenSpreader::ImportanceChanges *) data.importance_changes->lookup(handles[3]))->wages, 0.125)); + EXPECT_TRUE(importance_equals(((TokenSpreader::ImportanceChanges *) data.importance_changes->lookup(handles[4]))->wages, 0.250)); + EXPECT_TRUE(data.importance_changes->lookup(handles[5]) == NULL); + EXPECT_TRUE(data.importance_changes->lookup(handles[6]) == NULL); + EXPECT_TRUE(data.importance_changes->lookup(handles[7]) == NULL); + EXPECT_TRUE(data.importance_changes->lookup(handles[8]) == NULL); + EXPECT_TRUE(data.importance_changes->lookup(handles[9]) == NULL); + + delete spreader; + } +} + +static HebbianNetwork *build_test_network(string *handles) { + + HebbianNetwork *network = new HebbianNetwork(); + dasproto::HandleList *request; + ExactCountHebbianUpdater *updater = \ + (ExactCountHebbianUpdater *) HebbianNetworkUpdater::factory(HebbianNetworkUpdaterType::EXACT_COUNT); + + request = new dasproto::HandleList(); + request->set_hebbian_network((unsigned long) network); + request->add_list(handles[0]); + request->add_list(handles[1]); + request->add_list(handles[2]); + request->add_list(handles[3]); + updater->correlation(request); + + request = new dasproto::HandleList(); + request->set_hebbian_network((unsigned long) network); + request->add_list(handles[1]); + request->add_list(handles[2]); + request->add_list(handles[4]); + request->add_list(handles[5]); + updater->correlation(request); + + return network; +} + +TEST(TokenSpreader, spread_stimuli) { + + // -------------------------------------------------------------------- + // NOTE TO REVIEWER: I left debug messages because this code extremely + // error prone and difficult to debug. Probably we'll + // need to return to this test to make it pass when + // we make changes in the tested code. + // -------------------------------------------------------------------- + // Build and check network + + string *handles = build_handle_space(6, true); + for (unsigned int i = 0; i < 6; i++) { + cout << i << ": " << handles[i] << endl; + } + + HebbianNetwork *network = build_test_network(handles); + + unsigned int expected[6][6] = { + {0, 1, 1, 1, 0, 0}, + {1, 0, 2, 1, 1, 1}, + {1, 2, 0, 1, 1, 1}, + {1, 1, 1, 0, 0, 0}, + {0, 1, 1, 0, 0, 1}, + {0, 1, 1, 0, 1, 0}, + }; + for (unsigned int i = 0; i < 6; i++) { + EXPECT_TRUE(importance_equals(network->get_node_importance(handles[i]), 0.0000)); + if (i == 1 || i == 2) { + EXPECT_TRUE(network->get_node_count(handles[i]) == 2); + } else { + EXPECT_TRUE(network->get_node_count(handles[i]) == 1); + } + for (unsigned int j = 0; j < 6; j++) { + cout << i << ", " << j << ": " << expected[i][j] << " " << network->get_asymmetric_edge_count(handles[i], handles[j]) << endl; + EXPECT_TRUE(network->get_asymmetric_edge_count(handles[i], handles[j]) == expected[i][j]); + } + } + + // ---------------------------------------------------------- + // Build and process simulus spreading request + + dasproto::HandleCount *request; + TokenSpreader *spreader = \ + (TokenSpreader *) StimulusSpreader::factory(StimulusSpreaderType::TOKEN); + + request = new dasproto::HandleCount(); + request->set_hebbian_network((unsigned long) network); + (*request->mutable_map())[handles[0]] = 1; + (*request->mutable_map())[handles[1]] = 1; + (*request->mutable_map())[handles[2]] = 1; + (*request->mutable_map())[handles[3]] = 1; + (*request->mutable_map())[handles[4]] = 1; + (*request->mutable_map())[handles[5]] = 1; + (*request->mutable_map())["SUM"] = 6; + unsigned int SUM = (*request->mutable_map())["SUM"]; + spreader->spread_stimuli(request); + + // ---------------------------------------------------------- + // Compute expected value for importance of each node + + unsigned int arity[6]; + arity[0] = 3; + arity[1] = 5; + arity[2] = 5; + arity[3] = 3; + arity[4] = 3; + arity[5] = 3; + unsigned int max_arity = 5; + + double base_importance = (double) 1 / 6; + double rent = base_importance * AttentionBrokerServer::RENT_RATE; + double total_rent = rent * 6; + double total_wages = total_rent; + + cout << "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" << endl; + cout << "XXX Expected total rent: " << total_rent << endl; + cout << "XXX Expected rent rate: " << AttentionBrokerServer::RENT_RATE << endl; + + double wages[6]; + for (unsigned int i = 0; i < 6; i++) { + wages[i] = ((double) ((*request->mutable_map())[handles[i]])) / SUM * total_wages; + } + + double updated[6]; + for (unsigned int i = 0; i < 6; i++) { + updated[i] = base_importance + wages[i] - rent; + } + + double to_spread[6]; + for (unsigned int i = 0; i < 6; i++) { + double arity_ratio = (double) arity[i] / max_arity; + double lb = AttentionBrokerServer::SPREADING_RATE_LOWERBOUND; + double ub = AttentionBrokerServer::SPREADING_RATE_UPPERBOUND; + double spreading_rate = lb + (arity_ratio * (ub - lb)); + to_spread[i] = updated[i] * spreading_rate; + cout << "XXX Total to spread: " << to_spread[i] << endl; + } + + double sum_weight[6] = {3.0, 3.0, 3.0, 3.0, 3.0, 3.0}; + double weight[6][6] = { + {0.0, 1.0, 1.0, 1.0, 0.0, 0.0}, + {0.5, 0.0, 1.0, 0.5, 0.5, 0.5}, + {0.5, 1.0, 0.0, 0.5, 0.5, 0.5}, + {1.0, 1.0, 1.0, 0.0, 0.0, 0.0}, + {0.0, 1.0, 1.0, 0.0, 0.0, 1.0}, + {0.0, 1.0, 1.0, 0.0, 1.0, 0.0}, + }; + for (unsigned int i = 0; i < 6; i++) { + for (unsigned int j = 0; j < 6; j++) { + if (i != j) { + cout << "XXX weight[" << i << "][" << j << "]: " << weight[i][j] << endl; + } + } + cout << "XXX sum_weight[" << i << "]: " << sum_weight[i] << endl; + } + + double received[6] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + for (unsigned int i = 0; i < 6; i++) { + for (unsigned int j = 0; j < 6; j++) { + if (i != j) { + double weight_ratio = weight[i][j] / sum_weight[i]; + double stimulus = weight_ratio * to_spread[i]; + cout << "XXX stimulus[" << i << "][" << j << "]: " << stimulus << endl; + received[j] += stimulus; + } + } + } + + double expected_importance[6]; + for (unsigned int i = 0; i < 6; i++) { + expected_importance[i] = updated[i] - to_spread[i] + received[i]; + } + cout << "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" << endl; + + // ---------------------------------------------------------- + // Compare result with expected result + + for (unsigned int i = 0; i < 6; i++) { + cout << expected_importance[i] << " " << network->get_node_importance(handles[i]) << endl; + EXPECT_TRUE(importance_equals(network->get_node_importance(handles[i]), expected_importance[i])); + } +} diff --git a/src/cpp/tests/test_utils.cc b/src/cpp/tests/test_utils.cc new file mode 100644 index 0000000..4c734e9 --- /dev/null +++ b/src/cpp/tests/test_utils.cc @@ -0,0 +1,77 @@ +#include "expression_hasher.h" +#include "Utils.h" +#include "test_utils.h" +#include +#include + +static char REVERSE_TLB[16] = { + '0', + '1', + '2', + '3', + '4', + '5', + '6', + '7', + '8', + '9', + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', +}; + +double random_importance() { + return ((double) rand()) / RAND_MAX; +} + +string random_handle() { + char buffer[HANDLE_HASH_SIZE]; + unsigned int key_size = HANDLE_HASH_SIZE - 1; + for (unsigned int i = 0; i < key_size; i++) { + buffer[i] = REVERSE_TLB[(rand() % 16)]; + } + buffer[key_size] = 0; + string s = buffer; + return s; +} + +string sequential_label(unsigned int &count, string prefix) { + return prefix + std::to_string(count++); +} + +string prefixed_random_handle(string prefix) { + char buffer[HANDLE_HASH_SIZE]; + unsigned int key_size = HANDLE_HASH_SIZE - 1; + for (unsigned int i = 0; i < key_size; i++) { + if (i < prefix.size()) { + buffer[i] = prefix[i]; + } else { + buffer[i] = REVERSE_TLB[(rand() % 16)]; + } + } + buffer[key_size] = 0; + string s = buffer; + return s; +} + +static bool str_comp(string &a, string &b) { + return a.compare(b) < 0; +} + +string *build_handle_space(unsigned int size, bool sort) { + string *answer = new string[size]; + for (unsigned int i = 0; i < size; i++) { + answer[i] = random_handle(); + } + if (sort) { + std::sort(answer, answer + size, str_comp); + } + return answer; +} + +bool double_equals(double v1, double v2) { + return fabs(v2 - v1) < 0.001; +} diff --git a/src/cpp/tests/test_utils.h b/src/cpp/tests/test_utils.h new file mode 100644 index 0000000..2ba9d15 --- /dev/null +++ b/src/cpp/tests/test_utils.h @@ -0,0 +1,16 @@ +#ifndef _ATTENTION_BROKER_SERVER_TESTS_TESTUTILS +#define _ATTENTION_BROKER_SERVER_TESTS_TESTUTILS + +#include +#include + +using namespace std; + +double random_importance(); +string random_handle(); +string sequential_label(unsigned int &count, string prefix = "v"); +string prefixed_random_handle(string prefix); +string *build_handle_space(unsigned int size, bool sort=false); +bool double_equals(double v1, double v2); + +#endif diff --git a/src/cpp/tests/worker_threads_test.cc b/src/cpp/tests/worker_threads_test.cc new file mode 100644 index 0000000..e850b89 --- /dev/null +++ b/src/cpp/tests/worker_threads_test.cc @@ -0,0 +1,221 @@ +#include +#include + +#include "gtest/gtest.h" +#include "common.pb.h" +#include "attention_broker.grpc.pb.h" +#include "attention_broker.pb.h" + +#include "Utils.h" +#include "test_utils.h" +#include "SharedQueue.h" +#include "WorkerThreads.h" +#include "HebbianNetwork.h" + +using namespace attention_broker_server; + +class TestSharedQueue: public SharedQueue { + public: + TestSharedQueue() : SharedQueue() { + } + unsigned int test_current_count() { + return current_count(); + } +}; + + +TEST(WorkerThreads, basics) { + + dasproto::HandleCount *handle_count; + dasproto::HandleList *handle_list; + + unsigned int num_requests = 1000000; + unsigned int wait_for_threads_ms = 500; + + for (double stimulus_prob: {0.0, 0.25, 0.5, 0.75, 1.0}) { + TestSharedQueue *stimulus = new TestSharedQueue(); + TestSharedQueue *correlation = new TestSharedQueue(); + WorkerThreads *pool = new WorkerThreads(stimulus, correlation); + for (unsigned int i = 0; i < num_requests; i++) { + if (Utils::flip_coin(stimulus_prob)) { + handle_count = new dasproto::HandleCount(); + stimulus->enqueue(handle_count); + } else { + handle_list = new dasproto::HandleList(); + handle_list->set_hebbian_network((long) NULL); + correlation->enqueue(handle_list); + } + } + this_thread::sleep_for(chrono::milliseconds(wait_for_threads_ms)); + EXPECT_TRUE(stimulus->test_current_count() == 0); + EXPECT_TRUE(correlation->test_current_count() == 0); + pool->graceful_stop(); + delete pool; + delete stimulus; + delete correlation; + } +} + +TEST(WorkerThreads, hebbian_network_updater_basics) { + + dasproto::HandleList *handle_list; + map node_count; + map edge_count; + HebbianNetwork network; + + TestSharedQueue *stimulus = new TestSharedQueue(); + TestSharedQueue *correlation = new TestSharedQueue(); + WorkerThreads *pool = new WorkerThreads(stimulus, correlation); + + handle_list = new dasproto::HandleList(); + string h1 = random_handle(); + string h2 = random_handle(); + string h3 = random_handle(); + string h4 = random_handle(); + handle_list->add_list(h1); + handle_list->add_list(h2); + handle_list->add_list(h3); + handle_list->add_list(h4); + handle_list->set_hebbian_network((long) &network); + correlation->enqueue(handle_list); + + handle_list = new dasproto::HandleList(); + string h5 = random_handle(); + handle_list->add_list(h1); + handle_list->add_list(h2); + handle_list->add_list(h5); + handle_list->set_hebbian_network((long) &network); + correlation->enqueue(handle_list); + + handle_list = new dasproto::HandleList(); + handle_list->add_list(h2); + handle_list->add_list(h5); + handle_list->set_hebbian_network((long) &network); + correlation->enqueue(handle_list); + + string h6 = random_handle(); + handle_list = new dasproto::HandleList(); + handle_list->add_list(h6); + handle_list->add_list(h6); + handle_list->set_hebbian_network((long) &network); + correlation->enqueue(handle_list); + + handle_list = new dasproto::HandleList(); + handle_list->add_list(h1); + handle_list->add_list(h1); + handle_list->set_hebbian_network((long) &network); + correlation->enqueue(handle_list); + + this_thread::sleep_for(chrono::milliseconds(1000)); + EXPECT_TRUE(correlation->test_current_count() == 0); + + EXPECT_TRUE(network.get_node_count(h1) == 4); + EXPECT_TRUE(network.get_node_count(h2) == 3); + EXPECT_TRUE(network.get_node_count(h3) == 1); + EXPECT_TRUE(network.get_node_count(h4) == 1); + EXPECT_TRUE(network.get_node_count(h5) == 2); + EXPECT_TRUE(network.get_node_count(h6) == 2); + + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h2) == 2); + EXPECT_TRUE(network.get_asymmetric_edge_count(h2, h1) == 2); + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h3) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h3, h1) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h4) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h4, h1) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h5) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h5, h1) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h2, h5) == 2); + EXPECT_TRUE(network.get_asymmetric_edge_count(h5, h2) == 2); + EXPECT_TRUE(network.get_asymmetric_edge_count(h2, h3) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h3, h2) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h2, h4) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h4, h2) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h3, h4) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h4, h3) == 1); + EXPECT_TRUE(network.get_asymmetric_edge_count(h1, h1) == 0); + EXPECT_TRUE(network.get_asymmetric_edge_count(h2, h2) == 0); + EXPECT_TRUE(network.get_asymmetric_edge_count(h5, h3) == 0); + EXPECT_TRUE(network.get_asymmetric_edge_count(h3, h5) == 0); + EXPECT_TRUE(network.get_asymmetric_edge_count(h6, h6) == 0); + + pool->graceful_stop(); + + delete pool; + delete stimulus; + delete correlation; +} + +TEST(WorkerThreads, hebbian_network_updater_stress) { + + #define HANDLE_SPACE_SIZE ((unsigned int) 100) + unsigned int num_requests = 10; + unsigned int max_handles_per_request = 10; + unsigned int wait_for_worker_threads_ms = 3000; + + + dasproto::HandleList *handle_list; + string handles[HANDLE_SPACE_SIZE]; + map node_count; + map edge_count; + HebbianNetwork *network = new HebbianNetwork(); + for (unsigned int i = 0; i < HANDLE_SPACE_SIZE; i++) { + handles[i] = random_handle(); + } + + TestSharedQueue *stimulus = new TestSharedQueue(); + TestSharedQueue *correlation = new TestSharedQueue(); + WorkerThreads *pool = new WorkerThreads(stimulus, correlation); + for (unsigned int i = 0; i < num_requests; i++) { + handle_list = new dasproto::HandleList(); + unsigned int num_handles = (rand() % (max_handles_per_request - 1)) + 2; + for (unsigned int j = 0; j < num_handles; j++) { + string h = handles[rand() % HANDLE_SPACE_SIZE]; + handle_list->add_list(h); + if (node_count.find(h) == node_count.end()) { + node_count[h] = 0; + } + node_count[h] = node_count[h] + 1; + } + for (const string &h1: handle_list->list()) { + for (const string &h2: handle_list->list()) { + string composite; + if (h1.compare(h2) < 0) { + composite = h1 + h2; + if (edge_count.find(composite) == edge_count.end()) { + edge_count[composite] = 0; + } + edge_count[composite] = edge_count[composite] + 1; + } + } + } + handle_list->set_hebbian_network((long) network); + correlation->enqueue(handle_list); + } + + this_thread::sleep_for(chrono::milliseconds(wait_for_worker_threads_ms)); + EXPECT_TRUE(stimulus->test_current_count() == 0); + EXPECT_TRUE(correlation->test_current_count() == 0); + return; + + for (unsigned int i = 0; i < HANDLE_SPACE_SIZE; i++) { + for (unsigned int j = 0; j < HANDLE_SPACE_SIZE; j++) { + string h1 = handles[i]; + string h2 = handles[j]; + string composite; + if (h1.compare(h2) < 0) { + composite = h1 + h2; + } else { + composite = h2 + h1; + } + EXPECT_TRUE(network->get_node_count(h1) == node_count[h1]); + EXPECT_TRUE(network->get_node_count(h2) == node_count[h2]); + EXPECT_TRUE(network->get_asymmetric_edge_count(h1, h2) == edge_count[composite]); + EXPECT_TRUE(network->get_asymmetric_edge_count(h2, h1) == edge_count[composite]); + } + } + pool->graceful_stop(); + delete network; + delete pool; + delete stimulus; + delete correlation; +} diff --git a/src/cpp/utils/BUILD b/src/cpp/utils/BUILD new file mode 100644 index 0000000..191db94 --- /dev/null +++ b/src/cpp/utils/BUILD @@ -0,0 +1,10 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "utils_lib", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + includes = ["."], + deps = [ + ], +) diff --git a/src/cpp/utils/SharedQueue.cc b/src/cpp/utils/SharedQueue.cc new file mode 100644 index 0000000..8383f17 --- /dev/null +++ b/src/cpp/utils/SharedQueue.cc @@ -0,0 +1,88 @@ +#include "SharedQueue.h" + +using namespace commons; + +// -------------------------------------------------------------------------------- +// Public methods + +SharedQueue::SharedQueue(unsigned int initial_size) { + size = initial_size; + requests = new void*[size]; + count = 0; + start = 0; + end = 0; +} + +SharedQueue::~SharedQueue() { + delete [] requests; +} + +bool SharedQueue::empty() { + bool answer; + request_queue_mutex.lock(); + answer = (count == 0); + request_queue_mutex.unlock(); + return answer; +} + +void SharedQueue::enqueue(void *request) { + request_queue_mutex.lock(); + if (count == size) { + enlarge_request_queue(); + } + requests[end] = request; + end = (end + 1) % size; + count++; + request_queue_mutex.unlock(); +} + +void *SharedQueue::dequeue() { + void *answer = NULL; + request_queue_mutex.lock(); + if (count > 0) { + answer = requests[start]; + start = (start + 1) % size; + count--; + } + request_queue_mutex.unlock(); + return answer; +} + +// -------------------------------------------------------------------------------- +// Protected methods + +unsigned int SharedQueue::current_size() { + return size; +} + +unsigned int SharedQueue::current_start() { + return start; +} + +unsigned int SharedQueue::current_end() { + return end; +} + +unsigned int SharedQueue::current_count() { + return count; +} + +// -------------------------------------------------------------------------------- +// Private methods + +void SharedQueue::enlarge_request_queue() { + unsigned int _new_size = size * 2; + void **_new_queue = new void*[_new_size]; + unsigned int _cursor = start; + unsigned int _new_cursor = 0; + do { + _new_queue[_new_cursor++] = requests[_cursor]; + _cursor = (_cursor + 1) % size; + } while (_cursor != end); + size = _new_size; + start = 0; + end = _new_cursor; + delete [] requests; + requests = _new_queue; + // count remains unchanged +} diff --git a/src/cpp/utils/SharedQueue.h b/src/cpp/utils/SharedQueue.h new file mode 100644 index 0000000..f17f943 --- /dev/null +++ b/src/cpp/utils/SharedQueue.h @@ -0,0 +1,65 @@ +#ifndef _COMMONS_SHAREDQUEUE_H +#define _COMMONS_SHAREDQUEUE_H + +#include + +namespace commons { + +/** + * Data abstraction of a synchronized (thread-safe) queue for assynchronous requests. + * + * Internally, this abstraction uses an array of requests to avoid the need to create cell + * objects on every insertion. Because of this, on new insertions it's possible to reach queue + * size limit during an insertion. When that happens, the array is doubled in size. Initial size + * is passed as a constructor's parameter. + */ +class SharedQueue { + +public: + + SharedQueue(unsigned int initial_size = 1000); // Basic constructor + + ~SharedQueue(); /// Destructor. + + /** + * Enqueues a request. + * + * @param request Shared to be queued. + */ + void enqueue(void *request); + + /** + * Dequeues a request. + * + * @return The dequeued request. + */ + void *dequeue(); + + /** + * Returns true iff the queue is empty. + */ + bool empty(); + +protected: + + unsigned int current_size(); + unsigned int current_start(); + unsigned int current_end(); + unsigned int current_count(); + +private: + + std::mutex request_queue_mutex; + + void **requests; // GRPC documentation states that request types should not be inherited + unsigned int size; + unsigned int count; + unsigned int start; + unsigned int end; + + void enlarge_request_queue(); +}; + +} // namespace commons + +#endif // _COMMONS_SHAREDQUEUE_H diff --git a/src/cpp/utils/Utils.cc b/src/cpp/utils/Utils.cc new file mode 100644 index 0000000..2210002 --- /dev/null +++ b/src/cpp/utils/Utils.cc @@ -0,0 +1,104 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "Utils.h" + +using namespace commons; +using namespace std; + +// -------------------------------------------------------------------------------- +// Public methods + +Utils::Utils() { +} + +Utils::~Utils() { +} + +void Utils::error(string msg) { + throw runtime_error(msg); +} +void Utils::warning(string msg) { + cerr << msg << endl; +} + +bool Utils::flip_coin(double true_probability) { + long f = 1000; + return (rand() % f) < lround(true_probability * f); +} + +void Utils::sleep(unsigned int milliseconds) { + this_thread::sleep_for(chrono::milliseconds(milliseconds)); +} + +string Utils::get_environment(string const &key) { + char *value = getenv(key.c_str()); + string answer = (value == NULL ? "" : value); + return answer; +} + +StopWatch::StopWatch() { + reset(); +} + +StopWatch::~StopWatch() { +} + +void StopWatch::start() { + if (running) { + stop(); + } + start_time = chrono::steady_clock::now(); + running = true; +} + +void StopWatch::stop() { + if (running) { + chrono::steady_clock::time_point check = chrono::steady_clock::now(); + accumulator = accumulator + (check - start_time); + start_time = check; + running = false; + } +} + +void StopWatch::reset() { + running = false; + accumulator = chrono::steady_clock::duration::zero(); +} + +unsigned long StopWatch::milliseconds() { + return std::chrono::duration_cast(accumulator).count(); +} + +string StopWatch::str_time() { + + unsigned long millis = milliseconds(); + + unsigned long seconds = millis / 1000; + millis = millis % 60; + + unsigned long minutes = seconds / 60; + seconds = seconds % 60; + + unsigned long hours = minutes / 60; + minutes = minutes % 60; + + if (hours > 0) { + return to_string(hours) + " hours " + to_string(minutes) + " mins"; + } else if (minutes > 0) { + return to_string(minutes) + " mins " + to_string(seconds) + " secs"; + } else if (seconds > 0) { + //double s = ((double) ((seconds * 1000) + millis)) / 1000.0; + //std::stringstream stream; + //stream << std::fixed << std::setprecision(3) << s; + //return stream.str() + " secs"; + return to_string(seconds) + " secs " + to_string(millis) + " millis"; + } else { + return to_string(millis) + " millis"; + } +} diff --git a/src/cpp/utils/Utils.h b/src/cpp/utils/Utils.h new file mode 100644 index 0000000..4230dbc --- /dev/null +++ b/src/cpp/utils/Utils.h @@ -0,0 +1,42 @@ +#ifndef _COMMONS_UTILS_H +#define _COMMONS_UTILS_H + +#include +#include + +using namespace std; + +namespace commons { + +class StopWatch { + public: + StopWatch(); + ~StopWatch(); + void start(); + void stop(); + void reset(); + unsigned long milliseconds(); + string str_time(); + private: + bool running; + chrono::steady_clock::time_point start_time; + chrono::steady_clock::duration accumulator; +}; + +class Utils { + +public: + + Utils(); + ~Utils(); + + static void error(string msg); + static void warning(string msg); + static bool flip_coin(double true_probability = 0.5); + static void sleep(unsigned int milliseconds = 100); + static string get_environment(string const &key); +}; + +} // namespace commons + +#endif // _COMMONS_UTILS_H diff --git a/src/docker/Dockerfile b/src/docker/Dockerfile new file mode 100644 index 0000000..babfe48 --- /dev/null +++ b/src/docker/Dockerfile @@ -0,0 +1,98 @@ +FROM ubuntu:22.04 + +ARG BASE_DIR="/opt" +ARG TMP_DIR="/tmp" + +ARG ATTENTION_BROKER_DIR="${BASE_DIR}/das-attention-broker" +ARG DATA_DIR="${BASE_DIR}/data" +ARG GRPC_DIR="${BASE_DIR}/grpc" +ARG PROTO_DIR="${BASE_DIR}/proto" +ARG BAZEL_DIR="${BASE_DIR}/bazel" +ARG THIRDPARTY="${BASE_DIR}/3rd-party" + +RUN mkdir -p ${ATTENTION_BROKER_DIR} +RUN mkdir -p ${DATA_DIR} +RUN mkdir -p ${GRPC_DIR} +RUN mkdir -p ${BAZEL_DIR} +RUN mkdir -p ${PROTO_DIR} +RUN mkdir -p ${THIRDPARTY} +#VOLUME ${ATTENTION_BROKER_DIR} + +RUN apt-get update -y +RUN apt-get install -y git + +RUN cd ${GRPC_DIR} &&\ + git clone https://github.com/grpc/grpc &&\ + cd grpc &&\ + git submodule update --init + +RUN apt-get install -y build-essential +RUN apt-get install -y autoconf +RUN apt-get install -y libtool +RUN apt-get install -y pkg-config +RUN apt-get install -y curl +RUN apt-get install -y gcc +RUN apt-get install -y protobuf-compiler +#RUN apt-get install -y libmbedcrypto7 +RUN apt-get install -y libmbedtls14 +RUN apt-get install -y libmbedtls-dev +RUN apt-get install -y libevent-dev +RUN apt-get install -y libssl-dev + +COPY assets/3rd-party.tgz ${THIRDPARTY} +RUN cd ${THIRDPARTY} &&\ + tar xzvf 3rd-party.tgz &&\ + rm -f 3rd-party.tgz &&\ + mkdir -p ${ATTENTION_BROKER_DIR}/src/3rd-party &&\ + ln -s ${THIRDPARTY} ${ATTENTION_BROKER_DIR}/src/3rd-party &&\ + mv bazelisk ${BAZEL_DIR} + +ENV CPLUS_INCLUDE_PATH="/opt/3rd-party/mbedcrypto/include/" + +ENV CC=/usr/bin/gcc +RUN ln -s ${BAZEL_DIR}/bazelisk /usr/bin/bazel +RUN cd ${GRPC_DIR}/grpc &&\ + ${BAZEL_DIR}/bazelisk build :all + +ADD https://raw.githubusercontent.com/singnet/das-query-engine/master/proto/attention_broker.proto ${PROTO_DIR} +ADD https://raw.githubusercontent.com/singnet/das-query-engine/master/proto/common.proto ${PROTO_DIR} +ADD https://raw.githubusercontent.com/singnet/das-query-engine/master/proto/echo.proto ${PROTO_DIR} + +################################################################################ +# To be removed when AtomDB is properly integrated +# Redis client +RUN apt-get install -y cmake +RUN apt-get install -y libevent-dev +RUN apt-get install -y libssl-dev +RUN apt-get install -y pkg-config +RUN apt-get install -y cmake-data +COPY assets/hiredis-cluster.tgz /tmp +COPY assets/mongo-cxx-driver-r3.11.0.tar.gz /tmp +RUN cd /tmp &&\ + tar xzf hiredis-cluster.tgz &&\ + cd hiredis-cluster &&\ + mkdir build &&\ + cd build &&\ + cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo -DENABLE_SSL=ON ..&&\ + make &&\ + make install &&\ + echo "/usr/local/lib" > /etc/ld.so.conf.d/local.conf &&\ + ldconfig +# MongoDB client +RUN cd /tmp &&\ + tar xzvf mongo-cxx-driver-r3.11.0.tar.gz &&\ + cd /tmp/mongo-cxx-driver-r3.11.0/build/ &&\ + cmake .. -DCMAKE_BUILD_TYPE=Release -DMONGOCXX_OVERRIDE_DEFAULT_INSTALL_PREFIX=OFF &&\ + cmake --build . &&\ + cmake --build . --target install && \ +# mv install/include/* /usr/local/include &&\ + ln -s /usr/local/include/bsoncxx/v_noabi/bsoncxx/* /usr/local/include/bsoncxx &&\ + ln -s /usr/local/include/bsoncxx/v_noabi/bsoncxx/third_party/mnmlstc/core/ /usr/local/include/core &&\ + ln -s /usr/local/include/mongocxx/v_noabi/mongocxx/* /usr/local/include/mongocxx/ &&\ +# mv install/lib/* /usr/local/lib &&\ + ldconfig + +################################################################################ + + +WORKDIR /opt/das-attention-broker diff --git a/src/scripts/bazel_build.sh b/src/scripts/bazel_build.sh new file mode 100755 index 0000000..2ad9965 --- /dev/null +++ b/src/scripts/bazel_build.sh @@ -0,0 +1,18 @@ +#!/bin/bash -x + +(( JOBS=$(nproc)/2 )) +BAZELISK_CMD=/opt/bazel/bazelisk +BIN_FOLDER=/opt/das-attention-broker/bin +mkdir -p $BIN_FOLDER + +$BAZELISK_CMD build --jobs $JOBS --noenable_bzlmod //cpp:link_creation_engine \ +&& mv bazel-bin/cpp/link_creation_engine $BIN_FOLDER \ +&& $BAZELISK_CMD build --jobs $JOBS --noenable_bzlmod //cpp:word_query \ +&& mv bazel-bin/cpp/word_query $BIN_FOLDER \ +&& $BAZELISK_CMD build --jobs $JOBS --noenable_bzlmod //cpp:attention_broker_service \ +&& mv bazel-bin/cpp/attention_broker_service $BIN_FOLDER \ +&& $BAZELISK_CMD build --jobs $JOBS --noenable_bzlmod //cpp:query_broker \ +&& mv bazel-bin/cpp/query_broker $BIN_FOLDER \ +&& $BAZELISK_CMD build --jobs $JOBS --noenable_bzlmod //cpp:query \ +&& mv bazel-bin/cpp/query $BIN_FOLDER + diff --git a/src/scripts/bazel_build_command_line.sh b/src/scripts/bazel_build_command_line.sh new file mode 100755 index 0000000..f59a837 --- /dev/null +++ b/src/scripts/bazel_build_command_line.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +echo '============================================================================================================================'; +/opt/bazel/bazelisk build --jobs 6 --noenable_bzlmod :attention_broker diff --git a/src/scripts/bazel_clean.sh b/src/scripts/bazel_clean.sh new file mode 100755 index 0000000..48a1365 --- /dev/null +++ b/src/scripts/bazel_clean.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +/opt/bazel/bazelisk clean +rm -f bazel-src bazel-out bazel-testlogs bazel-bin diff --git a/src/scripts/bazel_test.sh b/src/scripts/bazel_test.sh new file mode 100755 index 0000000..525d54e --- /dev/null +++ b/src/scripts/bazel_test.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +bazel test --noenable_bzlmod --cache_test_results=no //... diff --git a/src/scripts/bazel_test_command_line.sh b/src/scripts/bazel_test_command_line.sh new file mode 100755 index 0000000..6f1dcb9 --- /dev/null +++ b/src/scripts/bazel_test_command_line.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +echo '============================================================================================================================'; +if [ -z "$2" ] + then + echo "Running all tests in $1" + /opt/bazel/bazelisk test --jobs 6 --noenable_bzlmod tests:$1 + else + echo "Running $2 in $1" + /opt/bazel/bazelisk test --jobs 6 --noenable_bzlmod --test_arg=--gtest_filter=$2 tests:$1 +fi diff --git a/src/scripts/build.sh b/src/scripts/build.sh new file mode 100755 index 0000000..8e222d9 --- /dev/null +++ b/src/scripts/build.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +CONTAINER_NAME="das-attention-broker-build" + +mkdir -p bin +docker run --rm \ + --name=$CONTAINER_NAME \ + --volume .:/opt/das-attention-broker \ + --workdir /opt/das-attention-broker \ + das-attention-broker-builder \ + ./scripts/bazel_build.sh + diff --git a/src/scripts/compile_protos.sh b/src/scripts/compile_protos.sh new file mode 100755 index 0000000..5edf49c --- /dev/null +++ b/src/scripts/compile_protos.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +PROTOS=("echo" "common" "attention_broker") + +for proto in ${PROTOS[@]}; do + protoc -I/opt/proto --cpp_out=/opt/grpc /opt/proto/${proto}.proto + cd /opt/grpc + gcc -c ${proto}.pb.cc +done diff --git a/src/scripts/container_tty.sh b/src/scripts/container_tty.sh new file mode 100755 index 0000000..128ea00 --- /dev/null +++ b/src/scripts/container_tty.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +CONTAINER_NAME="das-attention-broker-bash" + +docker run \ + --net="host" \ + --name=$CONTAINER_NAME \ + --volume /tmp:/tmp \ + --volume .:/opt/das-attention-broker \ + -it das-attention-broker-builder \ + bash + +sleep 1 diff --git a/src/scripts/docker_image_build.sh b/src/scripts/docker_image_build.sh new file mode 100755 index 0000000..b6f736e --- /dev/null +++ b/src/scripts/docker_image_build.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +docker buildx build -t das-attention-broker-builder --load -f docker/Dockerfile . diff --git a/src/scripts/run.sh b/src/scripts/run.sh new file mode 100755 index 0000000..3d44d64 --- /dev/null +++ b/src/scripts/run.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +CONTAINER_NAME="das-attention-broker-build" + +mkdir -p bin +docker run \ + --name=$CONTAINER_NAME \ + --volume .:/opt/das-attention-broker \ + --workdir /opt/das-attention-broker \ + das-attention-broker-builder \ + ./bin/attention_broker $1 + +sleep 1 +docker rm $CONTAINER_NAME diff --git a/src/scripts/unit_tests.sh b/src/scripts/unit_tests.sh new file mode 100755 index 0000000..12ba71b --- /dev/null +++ b/src/scripts/unit_tests.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +CONTAINER_NAME="das-attention-broker-build" + +mkdir -p bin +docker run \ + --name=$CONTAINER_NAME \ + --volume .:/opt/das-attention-broker \ + --workdir /opt/das-attention-broker/src \ + das-attention-broker-builder \ + ../scripts/bazel_test.sh + +sleep 1 +docker rm $CONTAINER_NAME