diff --git a/README.md b/README.md index 7818a3c..eac80f6 100644 --- a/README.md +++ b/README.md @@ -89,15 +89,6 @@ It is also possible to run a web server that provides this tool's functionality ```bash ./web_server.py ``` -The web server uses mariadb as database. To change it to sqlite, simply change the CONFIG_FILE variable at top of the web_server.py script to `config.json`. -It is recommend to change the following values in `/etc/my.cnf` to make some useful changes: -``` -query_cache_type = 1 -query_cache_size = 192M -innodb_buffer_pool_size = 8G -thread_handling=pool-of-threads -``` -innodb_buffer_pool_size should be set to approximately 80% of available memory (https://mariadb.com/kb/en/innodb-system-variables/#innodb_buffer_pool_size) ```bash gunicorn --worker-class=gevent --worker-connections=50 --workers=3 --bind '0.0.0.0:8000' wsgi:app ``` @@ -108,5 +99,22 @@ Finally, you can also use Nginx as a reverse proxy. A sample configuration file gunicorn --worker-class=gevent --worker-connections=50 --workers=3 --bind 'unix:/tmp/gunicorn.sock' wsgi:app ``` +## MariaDB as Second Database Option +As alternative to the preconfigured SQLite, you can use *MariaDB* as database. A sample configuration file for MariaDB is provided in [``config_mariadb.json``](https://github.com/ra1nb0rn/search_vulns/blob/master/config_mariadb.json). + +Make sure that you adjust the values for MariaDB in the configuration file to your MariaDB configuration (*user*, *password*, *host*, *port*). + +To use MariaDB instead of *SQLite* for the webserver, simply change the CONFIG_FILE variable in ``web_server.py`` to your config file (e.g. ``config_mariadb.json``). +It is recommend to change the following values in ``/etc/my.cnf`` to improve the performance of MariaDB: +``` +[mariadb] + +query_cache_type = 1 +query_cache_size = 192M +innodb_buffer_pool_size = 8G +thread_handling = pool-of-threads +``` +`innodb_buffer_pool_size` should be set to approximately 80% of available memory (see [the official documentation](https://mariadb.com/kb/en/innodb-system-variables/#innodb_buffer_pool_size)). + ## License *search_vulns* is licensed under the MIT license, see [here](https://github.com/ra1nb0rn/search_vulns/blob/master/LICENSE). diff --git a/config.json b/config.json index f0adf17..5bb2b6e 100644 --- a/config.json +++ b/config.json @@ -2,12 +2,10 @@ "DATABASE_NAME": "vulndb.db3", "MAN_EQUIVALENT_CPES_FILE": "man_equiv_cpes.json", "CVE_EDB_MAP_FILE": "cveid_to_edbid.json", - "CREATE_SQL_STATEMENTS_FILE": "create_sql_statements.json", "NVD_API_KEY": "", "cpe_search": { "DATABASE_NAME": "cpe_search/cpe-search-dictionary.db3", "DEPRECATED_CPES_FILE": "cpe_search/deprecated-cpes.json", - "CREATE_SQL_STATEMENTS_FILE": "cpe_search/create_sql_statements.json", "NVD_API_KEY": "" }, "DATABASE": { diff --git a/config_mariadb.json b/config_mariadb.json index ea726fc..e5d18de 100644 --- a/config_mariadb.json +++ b/config_mariadb.json @@ -2,12 +2,10 @@ "DATABASE_NAME": "vulndb", "MAN_EQUIVALENT_CPES_FILE": "man_equiv_cpes.json", "CVE_EDB_MAP_FILE": "cveid_to_edbid.json", - "CREATE_SQL_STATEMENTS_FILE": "create_sql_statements.json", "NVD_API_KEY": "", "cpe_search": { "DATABASE_NAME": "cpe_search_dictionary", "DEPRECATED_CPES_FILE": "cpe_search/deprecated-cpes.json", - "CREATE_SQL_STATEMENTS_FILE": "cpe_search/create_sql_statements.json", "NVD_API_KEY": "" }, "DATABASE": { diff --git a/db_creation_src/CMakeLists.txt b/db_creation_src/CMakeLists.txt index a27c8d3..9fbc2c1 100644 --- a/db_creation_src/CMakeLists.txt +++ b/db_creation_src/CMakeLists.txt @@ -9,9 +9,9 @@ include_directories(SQLiteCpp/include) add_subdirectory(SQLiteCpp) # Include mariadb-connector-cpp library -set(CONC_WITH_UNIT_TESTS "Off") +set(CONC_WITH_UNIT_TESTS OFF) set(CMAKE_BUILD_TYPE "RelWithDebInfo") -set(WITH_UNIT_TESTS "Off") +set(WITH_UNIT_TESTS OFF CACHE INTERNAL "") include_directories(mariadb-connector-cpp/include) # workaround until mariadb fix issue in test/CMakeLists.txt include_directories("${CMAKE_BINARY_DIR}/mariadb-connector-cpp/test") diff --git a/db_creation_src/create_db.cpp b/db_creation_src/create_db.cpp index 70359bf..1406083 100644 --- a/db_creation_src/create_db.cpp +++ b/db_creation_src/create_db.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "database_wrapper.h" #include "prepared_statement.h" @@ -61,6 +62,18 @@ void handle_exception(T &e) { } } +bool is_safe_database_name(std::string dbName) { + // Check if database name contains any special characters or keywords + std::regex pattern("[^a-zA-Z0-9_-]"); + + if (std::regex_search(dbName, pattern)) { + return false; // Database name is malicious + } + + return true; // Database name is safe +} + + int add_to_db(DatabaseWrapper *db, const std::string &filepath) { // Begin transaction db->start_transaction(); @@ -253,7 +266,7 @@ int add_to_db(DatabaseWrapper *db, const std::string &filepath) { else cve_cpe_query->bind(6, false); - try{ + try { cve_cpe_query->execute(); } catch (SQLite::Exception& e) { @@ -314,9 +327,15 @@ int main(int argc, char *argv[]) { std::unique_ptr db; + // validate given database name + if (database_type != "sqlite" && !is_safe_database_name(config["DATABASE_NAME"])) { + std::cout << "Potentially malicious database name detected. Abort creation of database" << std::endl; + return EXIT_FAILURE; + } + try{ // create database connection - if ( database_type == "sqlite") + if (database_type == "sqlite") db = std::make_unique(outfile); else{ db = std::make_unique(config); @@ -360,4 +379,15 @@ int main(int argc, char *argv[]) { } // close database connection db->close_connection(); + + // print duration of building process + auto time = std::chrono::high_resolution_clock::now() - start_time; + + char *db_abs_path = realpath(outfile.c_str(), NULL); + std::cout << "Database creation took " << + (float) (std::chrono::duration_cast(time).count()) / (1e6) << "s .\n"; + std::cout << "Local copy of NVD created as " << db_abs_path << " ." << std::endl; + free(db_abs_path); + return EXIT_SUCCESS; + } diff --git a/db_creation_src/database_wrapper.cpp b/db_creation_src/database_wrapper.cpp index 46a3335..2315db4 100644 --- a/db_creation_src/database_wrapper.cpp +++ b/db_creation_src/database_wrapper.cpp @@ -31,10 +31,10 @@ void SQLiteDB::execute_query(std::string query) { void SQLiteDB::create_prepared_statements() { // init prepared statements - cve_query = std::make_unique(*db, CVE_QUERY_FRAGMENT); - cve_cpe_query = std::make_unique(*db, CVE_CPE_QUERY_FRAGMENT); - add_exploit_ref_query = std::make_unique(*db, NVD_EXPLOIT_REFS_FRAGMENT); - add_cveid_exploit_ref_query = std::make_unique(*db, CVE_NVD_EXPLOITS_REFS_FRAGMENT); + cve_query = std::make_unique(*db, CVE_QUERY_FRAGMENT); + cve_cpe_query = std::make_unique(*db, CVE_CPE_QUERY_FRAGMENT); + add_exploit_ref_query = std::make_unique(*db, NVD_EXPLOIT_REFS_FRAGMENT); + add_cveid_exploit_ref_query = std::make_unique(*db, CVE_NVD_EXPLOITS_REFS_FRAGMENT); } void SQLiteDB::commit() { @@ -63,7 +63,8 @@ MariaDB::MariaDB(nlohmann::json config) { sql::Properties properties({{"user", user}, {"password", password}}); conn = std::unique_ptr(driver->connect(url, properties)); - + + // set up database std::string create_db_query = "CREATE OR REPLACE DATABASE "+database+";"; std::unique_ptr stmnt(conn->createStatement()); stmnt->executeQuery(create_db_query); @@ -77,10 +78,10 @@ void MariaDB::execute_query(std::string query) { void MariaDB::create_prepared_statements() { // init prepared statements - cve_query = std::make_unique(conn, CVE_QUERY_FRAGMENT); - cve_cpe_query = std::make_unique(conn, CVE_CPE_QUERY_FRAGMENT); - add_exploit_ref_query = std::make_unique(conn, NVD_EXPLOIT_REFS_FRAGMENT); - add_cveid_exploit_ref_query = std::make_unique(conn, CVE_NVD_EXPLOITS_REFS_FRAGMENT); + cve_query = std::make_unique(conn, CVE_QUERY_FRAGMENT); + cve_cpe_query = std::make_unique(conn, CVE_CPE_QUERY_FRAGMENT); + add_exploit_ref_query = std::make_unique(conn, NVD_EXPLOIT_REFS_FRAGMENT); + add_cveid_exploit_ref_query = std::make_unique(conn, CVE_NVD_EXPLOITS_REFS_FRAGMENT); } void MariaDB::commit() { diff --git a/export_database_as_csv.py b/export_database_as_csv.py index b7e20dd..68892fd 100644 --- a/export_database_as_csv.py +++ b/export_database_as_csv.py @@ -1,12 +1,17 @@ import sqlite3 try: # only use mariadb module if installed import mariadb -except: +except ImportError: pass import csv import sys import os + +def is_safe_table_name(table_name): + return all([c.isalnum() or c in ('-', '_') for c in table_name]) + + def export_tables_to_csv(database_file): # Connect to the SQLite database conn = sqlite3.connect(database_file) @@ -19,8 +24,12 @@ def export_tables_to_csv(database_file): # Export each table to a separate CSV file for table in tables: table_name = table[0] - csv_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), f"{table_name}.csv") - cursor.execute(f"SELECT * FROM {table_name};") + # check if table_name is not malicious + if not is_safe_table_name(table_name): + print('Malicious table name detected') + return 1 + csv_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), f'{table_name}.csv') + cursor.execute('SELECT * FROM %s;' % table_name) rows = cursor.fetchall() with open(csv_file, 'w', newline='\n') as file: writer = csv.writer(file, dialect='unix', escapechar='\\') @@ -44,14 +53,18 @@ def export_tables_mariadb_to_csv(config): cursor = conn.cursor() # Get the list of tables in the database - cursor.execute("SHOW TABLES;") + cursor.execute('SHOW TABLES;') tables = cursor.fetchall() # Export each table to a separate CSV file for table in tables: table_name = table[0] - csv_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), f"{table_name}.csv.mariadb") - cursor.execute(f"SELECT * FROM {table_name};") + # check if table_name is not malicious + if not is_safe_table_name(table_name): + print('Malicious table name detected') + return 1 + csv_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), f'{table_name}.csv.mariadb') + cursor.execute('SELECT * FROM %s;' % table_name) rows = cursor.fetchall() with open(csv_file, 'w', newline='\n') as file: diff --git a/install.sh b/install.sh index 2ef4a04..4be7add 100755 --- a/install.sh +++ b/install.sh @@ -7,7 +7,7 @@ LINUX_PACKAGE_MANAGER="apt-get" install_linux_packages() { # Install required packages - PACKAGES="python3 python3-pip wget curl sqlite3 libsqlite3-dev cmake gcc sudo apt-get install libmariadb-dev jq" + PACKAGES="python3 python3-pip wget curl sqlite3 libsqlite3-dev cmake gcc libmariadb-dev jq" which ${LINUX_PACKAGE_MANAGER} &> /dev/null if [ $? != 0 ]; then printf "${RED}Could not find ${LINUX_PACKAGE_MANAGER} command.\\nPlease specify your package manager at the start of the script.\\n${SANE}" @@ -50,7 +50,7 @@ setup_create_db() { fi cd ".." - ## configure submodules of SQLiteCpp for create_db + ## configure submodules of mariadb-connector-cpp for create_db cd "mariadb-connector-cpp" if [ $QUIET != 1 ]; then git submodule init diff --git a/migrate_sqlite_to_mariadb.sh b/migrate_sqlite_to_mariadb.sh index f7ef805..b211669 100755 --- a/migrate_sqlite_to_mariadb.sh +++ b/migrate_sqlite_to_mariadb.sh @@ -57,17 +57,17 @@ PASSWORD=$(jq -r '.DATABASE.PASSWORD' $CONFIG_FILE) PORT=$(jq -r '.DATABASE.PORT' $CONFIG_FILE) DATABASE_NAME=$(jq -r '.DATABASE_NAME' $CONFIG_FILE) CPE_DATABASE_NAME=$(jq -r '.cpe_search.DATABASE_NAME' $CONFIG_FILE) -CREATE_TABLES_QUERIES_VULNDB=$ABS_PATH/$(jq -r '.CREATE_SQL_STATEMENTS_FILE' $CONFIG_FILE) -CREATE_TABLES_QUERIES_CPE_SEARCH=$ABS_PATH/$(jq -r '.cpe_search.CREATE_SQL_STATEMENTS_FILE' $CONFIG_FILE) +CREATE_TABLES_QUERIES_VULNDB=$ABS_PATH/create_sql_statements.json +CREATE_TABLES_QUERIES_CPE_SEARCH=$ABS_PATH/cpe_search/create_sql_statements.json # Export sqlite databases -echo "[+] Export sqlite as csv" -python3 $ABS_PATH/export_database_as_csv.py sqlite $DATABASE_FILE -python3 $ABS_PATH/export_database_as_csv.py sqlite $CPE_DATABASE_FILE +echo "[+] Export SQLite as csv" +python3 $ABS_PATH/export_database_as_csv.py sqlite $DATABASE_FILE || { rm $ABS_PATH/*.csv; echo "[-] Migration failed"; exit 1; } +python3 $ABS_PATH/export_database_as_csv.py sqlite $CPE_DATABASE_FILE || { rm $ABS_PATH/*.csv; echo "[-] Migration failed"; exit 1; } # Create databases -echo "[+] Add data to mariadb" +echo "[+] Add data to MariaDB" # get queries from file vulndb_create_tables_queries=$(cat $CREATE_TABLES_QUERIES_VULNDB | jq '.TABLES | .[] | select(.mariadb) | .mariadb'| tr '\n' ' ' | sed 's/"//g') cpe_search_create_tables_queries=$(cat $CREATE_TABLES_QUERIES_CPE_SEARCH | jq '.TABLES | .[] | select(.mariadb) | .mariadb' | tr '\n' ' ' | sed 's/"//g') @@ -85,9 +85,9 @@ vulndb_create_views=$(cat $CREATE_TABLES_QUERIES_VULNDB | jq '.VIEWS | .[] | sel mariadb -u $USER --password=$PASSWORD -h $HOST -P $PORT -D "$DATABASE_NAME" -e "$vulndb_create_views" # Export mariadb databases -echo "[+] Export mariadb as csv" -python3 $ABS_PATH/export_database_as_csv.py mariadb $DATABASE_NAME,$USER,$PASSWORD,$HOST,$PORT -python3 $ABS_PATH/export_database_as_csv.py mariadb $CPE_DATABASE_NAME,$USER,$PASSWORD,$HOST,$PORT +echo "[+] Export MariaDB as csv" +python3 $ABS_PATH/export_database_as_csv.py mariadb $DATABASE_NAME,$USER,$PASSWORD,$HOST,$PORT || { rm $ABS_PATH/*.csv $ABS_PATH/*.csv.mariadb; echo "[-] Migration failed"; exit 1; } +python3 $ABS_PATH/export_database_as_csv.py mariadb $CPE_DATABASE_NAME,$USER,$PASSWORD,$HOST,$PORT|| { rm $ABS_PATH/*.csv $ABS_PATH/*.csv.mariadb; echo "[-] Migration failed"; exit 1; } # check whether everything migrated correctly # Loop through each CSV file in the current folder diff --git a/search_vulns.py b/search_vulns.py index f6c6a64..f5b482a 100755 --- a/search_vulns.py +++ b/search_vulns.py @@ -125,9 +125,9 @@ def get_vuln_details(db_cursor, vulns, add_other_exploit_refs): query = 'SELECT edb_ids, description, published, last_modified, cvss_version, base_score, vector FROM cve WHERE cve_id = ?' db_cursor.execute(query, (cve_id,)) edb_ids, descr, publ, last_mod, cvss_ver, score, vector = db_cursor.fetchone() - detailed_vulns[cve_id] = {"id": cve_id, "description": descr, "published": publ, "modified": last_mod, - "href": "https://nvd.nist.gov/vuln/detail/%s" % cve_id, "cvss_ver": cvss_ver, - "cvss": score, "cvss_vec": vector, 'vuln_match_reason': match_reason} + detailed_vulns[cve_id] = {"id": cve_id, "description": descr, "published": str(publ), "modified": str(last_mod), + "href": "https://nvd.nist.gov/vuln/detail/%s" % cve_id, "cvss_ver": str(float(cvss_ver)), + "cvss": str(float(score)), "cvss_vec": vector, 'vuln_match_reason': match_reason} edb_ids = edb_ids.strip() if edb_ids: @@ -290,7 +290,7 @@ def print_vulns(vulns, to_string=False): print_str += len("Exploits: ") * " " + edb_link + "\n" print_str += "Reference: " + vuln_node["href"] - print_str += ", " + str(vuln_node["published"]).split(" ")[0] + print_str += ", " + vuln_node["published"].split(" ")[0] if not to_string: printit(print_str) else: @@ -660,4 +660,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_related_queries.py b/tests/test_related_queries.py index 476f5ad..9696894 100755 --- a/tests/test_related_queries.py +++ b/tests/test_related_queries.py @@ -59,7 +59,7 @@ def test_search_dell_omsa_9_4_0_2(self): self.maxDiff = None query = 'dell omsa 9.4.0.2' result = search_vulns.search_vulns_return_cpe(query) - expected_related_cpes = [('cpe:2.3:a:dell:openmanage_server_administrator:9.4.0.2:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:dell:openmanage_server_administrator:5.2.0:*:*:*:*:*:*:*', 0.9356286465015572), ('cpe:2.3:a:dell:openmanage_server_administrator:1.00.0000:*:*:*:*:*:*:*', 0.8845604348848455), ('cpe:2.3:a:dell:openmanage_server_administrator_installer:9.4.0.2:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:dell:openmanage_server_administrator_installer:1.0.0:*:*:*:*:*:*:*', 0.8355902246901327), ('cpe:2.3:a:dell:openmanage_server_administrator_lite:9.4.0.2:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:dell:openmanage_server_administrator_lite:5.4.1:*:*:*:*:*:*:*', 0.8355902246901327), ('cpe:2.3:a:dell:openmanage:9.4.0.2:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:dell:openmanage:-:*:*:*:*:*:*:*', 0.826938997039739)] + expected_related_cpes = [('cpe:2.3:a:dell:openmanage_server_administrator:9.4.0.2:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:dell:openmanage_server_administrator:5.2.0:*:*:*:*:*:*:*', 0.9356286465015572), ('cpe:2.3:a:dell:openmanage_server_administrator:1.00.0000:*:*:*:*:*:*:*', 0.8845604348848455), ('cpe:2.3:a:dell:openmanage_server_administrator_installer:9.4.0.2:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:dell:openmanage_server_administrator_installer:1.0.0:*:*:*:*:*:*:*', 0.8355902246901327)] for i, (expected_related_cpe, match_score) in enumerate(expected_related_cpes): self.assertEqual(expected_related_cpe, result[query]['pot_cpes'][i][0]) self.assertAlmostEqual(match_score, result[query]['pot_cpes'][i][1]) @@ -68,7 +68,7 @@ def test_search_citrix_adc_13_1_42_47(self): self.maxDiff = None query = 'citrix adc 13.1-42.47' result = search_vulns.search_vulns_return_cpe(query) - expected_related_cpes = [('cpe:2.3:a:citrix:application_delivery_controller:13.1:42.47:*:*:-:*:*:*', -1), ('cpe:2.3:a:citrix:application_delivery_controller:13.1-42.47:*:*:*:-:*:*:*', -1), ('cpe:2.3:a:citrix:application_delivery_controller:13.1:*:*:*:-:*:*:*', 0.9443356111798747), ('cpe:2.3:h:citrix:application_delivery_controller:13.1:*:*:*:*:*:*:*', -1), ('cpe:2.3:h:citrix:application_delivery_controller:13.1:42.47:*:*:*:*:*:*', -1), ('cpe:2.3:h:citrix:application_delivery_controller:13.1-42.47:*:*:*:*:*:*:*', -1), ('cpe:2.3:h:citrix:application_delivery_controller:-:*:*:*:*:*:*:*', 0.9195900759823716), ('cpe:2.3:a:citrix:application_delivery_controller:13.1-21.50:*:*:*:*:*:*:*', 0.8959514540639271), ('cpe:2.3:a:citrix:application_delivery_controller:12.1:*:*:*:-:*:*:*', 0.8588472122359783), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:13.1:*:*:*:*:*:*:*', -1), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:13.1:42.47:*:*:*:*:*:*', -1), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:13.1-42.47:*:*:*:*:*:*:*', -1), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:10.1:*:*:*:*:*:*:*', 0.8212665153916354)] + expected_related_cpes = [('cpe:2.3:a:citrix:application_delivery_controller:13.1:42.47:*:*:-:*:*:*', -1), ('cpe:2.3:a:citrix:application_delivery_controller:13.1-42.47:*:*:*:-:*:*:*', -1), ('cpe:2.3:a:citrix:application_delivery_controller:13.1:*:*:*:-:*:*:*', 0.9443356111798746), ('cpe:2.3:h:citrix:application_delivery_controller:13.1:*:*:*:*:*:*:*', -1), ('cpe:2.3:h:citrix:application_delivery_controller:13.1:42.47:*:*:*:*:*:*', -1), ('cpe:2.3:h:citrix:application_delivery_controller:13.1-42.47:*:*:*:*:*:*:*', -1), ('cpe:2.3:h:citrix:application_delivery_controller:-:*:*:*:*:*:*:*', 0.9195900759823715)] for i, (expected_related_cpe, match_score) in enumerate(expected_related_cpes): self.assertEqual(expected_related_cpe, result[query]['pot_cpes'][i][0]) self.assertAlmostEqual(match_score, result[query]['pot_cpes'][i][1]) @@ -77,7 +77,7 @@ def test_search_citrix_adc_no_version(self): self.maxDiff = None query = 'citrix adc' result = search_vulns.search_vulns_return_cpe(query) - expected_related_cpes = [('cpe:2.3:h:citrix:application_delivery_controller:-:*:*:*:*:*:*:*', 0.9640085266327638), ('cpe:2.3:a:citrix:application_delivery_controller:*:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:citrix:application_delivery_controller:12.1:*:*:*:-:*:*:*', 0.9003316339465728), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:*:*:*:*:*:*:*:*', -1), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:10.1:*:*:*:*:*:*:*', 0.8609356974951642), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:10.5:*:*:*:*:*:*:*', 0.8609356974951642), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:10.5e:*:*:*:*:*:*:*', 0.8609356974951642)] + expected_related_cpes = [('cpe:2.3:h:citrix:application_delivery_controller:-:*:*:*:*:*:*:*', 0.9640085266327638), ('cpe:2.3:a:citrix:application_delivery_controller:*:*:*:*:*:*:*:*', -1), ('cpe:2.3:a:citrix:application_delivery_controller:12.1:*:*:*:-:*:*:*', 0.9003316339465728), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:*:*:*:*:*:*:*:*', -1), ('cpe:2.3:o:citrix:application_delivery_controller_firmware:10.1:*:*:*:*:*:*:*', 0.8609356974951642)] for i, (expected_related_cpe, match_score) in enumerate(expected_related_cpes): self.assertEqual(expected_related_cpe, result[query]['pot_cpes'][i][0]) self.assertAlmostEqual(match_score, result[query]['pot_cpes'][i][1]) @@ -93,4 +93,5 @@ def test_search_openssh_83_p4(self): if __name__ == '__main__': + os.environ['IS_CPE_SEARCH_TEST'] = 'true' unittest.main() diff --git a/updater.py b/updater.py index c9327f4..54c43da 100755 --- a/updater.py +++ b/updater.py @@ -33,7 +33,8 @@ REQUEST_HEADERS = {"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:60.0) Gecko/20100101 Firefox/62.0"} NVD_UPDATE_SUCCESS = None CVE_API_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0" -MARIADB_BACKUP_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mariadb_dump.sql") +MARIADB_BACKUP_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mariadb_dump.sql') +CREATE_SQL_STATEMENTS_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'create_sql_statements.json') QUIET = False DEBUG = False API_RESULTS_PER_PAGE = 2000 @@ -98,10 +99,17 @@ def backup_mariadb_database(database): except: pass else: - # backup mariadb - return_code = subprocess.call(['mariadb-dump', '-u', CONFIG['DATABASE']['USER'], f"--password={CONFIG['DATABASE']['PASSWORD']}", - '-h', CONFIG['DATABASE']['HOST'], '-P', str(CONFIG['DATABASE']['PORT']), - '--add-drop-database', '--add-locks', '-B', database, '-r', MARIADB_BACKUP_FILE], stderr=subprocess.DEVNULL) + # backup MariaDB + backup_call = ['mariadb-dump', + '-u', CONFIG['DATABASE']['USER'], + '-h', CONFIG['DATABASE']['HOST'], + '-P', str(CONFIG['DATABASE']['PORT']), + '--add-drop-database', '--add-locks', + '-B', database, '-r', MARIADB_BACKUP_FILE] + if CONFIG['DATABASE']['PASSWORD']: + backup_call.append('-p') + backup_call.append(CONFIG['DATABASE']['PASSWORD']) + return_code = subprocess.call(backup_call, stderr=subprocess.DEVNULL) if return_code != 0: print(f'[-] MariaDB BackUp of {database} failed') @@ -111,7 +119,7 @@ async def update_vuln_db(nvd_api_key=None): global NVD_UPDATE_SUCCESS - # backup mariadb + # backup MariaDB if CONFIG['DATABASE']['TYPE'] == 'mariadb': backup_mariadb_database(CONFIG['DATABASE_NAME']) elif os.path.isfile(CONFIG['DATABASE_NAME']): @@ -186,7 +194,7 @@ async def update_vuln_db(nvd_api_key=None): # build local NVD copy with downloaded data feeds print('[+] Building vulnerability database') - create_db_call = ["./create_db", NVD_DATAFEED_DIR, CONFIG_FILE, CONFIG['DATABASE_NAME'], CONFIG['CREATE_SQL_STATEMENTS_FILE']] + create_db_call = ["./create_db", NVD_DATAFEED_DIR, CONFIG_FILE, CONFIG['DATABASE_NAME'], CREATE_SQL_STATEMENTS_FILE] with open(os.devnull, "w") as outfile: return_code = subprocess.call(create_db_call, stdout=outfile, stderr=subprocess.STDOUT) @@ -230,13 +238,29 @@ async def handle_cpes_update(nvd_api_key=None): if os.path.isfile(CONFIG['DEPRECATED_CPES_BACKUP_FILE']): shutil.move(CONFIG['DEPRECATED_CPES_BACKUP_FILE'], CONFIG['cpe_search']['DEPRECATED_CPES_FILE']) if os.path.isfile(MARIADB_BACKUP_FILE): - restore_call = f'''mariadb -u {CONFIG['DATABASE']['USER']} --password={CONFIG['DATABASE']['PASSWORD']} -h {CONFIG['DATABASE']['HOST']} -P {str(CONFIG['DATABASE']['PORT'])} < {MARIADB_BACKUP_FILE}''' - return_code = subprocess.call(restore_call, shell=True, stderr=subprocess.DEVNULL) - if return_code != 0: - print('[-] Failed to restore mariadb') + with open(MARIADB_BACKUP_FILE, 'rb') as f: + mariadb_backup_data = f.read() + restore_call = ['mariadb', '-u', CONFIG['DATABASE']['USER'], + '-h', CONFIG['DATABASE']['HOST'], + '-P', str(CONFIG['DATABASE']['PORT'])] + if CONFIG['DATABASE']['PASSWORD']: + restore_call.append('-p') + restore_call.append(CONFIG['DATABASE']['PASSWORD']) + restore_call_run = subprocess.run(restore_call, input=mariadb_backup_data) + if restore_call_run.returncode != 0: + print('[-] Failed to restore MariaDB') else: - print('[+] Restored mariadb from backup') - os.remove(MARIADB_BACKUP_FILE) + print('[+] Restored MariaDB from backup') + + # Restore failed b/c database is down -> not delete backup file + # check whether database is up by trying to get a connection + try: + get_database_connection(CONFIG['DATABASE'], CONFIG['cpe_search']['DATABASE_NAME']) + except: + print('[!] MariaDB seems to be down. The backup file wasn\'t deleted. To restore manually from the file, run the following command:') + print(' '.join(restore_call+['<', MARIADB_BACKUP_FILE, '&&', 'rm', MARIADB_BACKUP_FILE])) + else: + os.remove(MARIADB_BACKUP_FILE) else: if os.path.isfile(CONFIG['CPE_DATABASE_BACKUP_FILE']): os.remove(CONFIG['CPE_DATABASE_BACKUP_FILE']) @@ -252,20 +276,37 @@ def rollback(): """Rollback the DB / module update""" communicate_warning('An error occured, rolling back database update') - if os.path.isfile(CONFIG['DATABASE_NAME']): - os.remove(CONFIG['DATABASE_NAME']) + if CONFIG['DATABASE']['TYPE'] == 'sqlite': + if os.path.isfile(CONFIG['DATABASE_NAME']): + os.remove(CONFIG['DATABASE_NAME']) if os.path.isfile(CONFIG['DATABASE_BACKUP_FILE']): shutil.move(CONFIG['DATABASE_BACKUP_FILE'], CONFIG['DATABASE_NAME']) if os.path.isdir(NVD_DATAFEED_DIR): shutil.rmtree(NVD_DATAFEED_DIR) if os.path.isfile(MARIADB_BACKUP_FILE): - restore_call = f'''mariadb -u {CONFIG['DATABASE']['USER']} --password={CONFIG['DATABASE']['PASSWORD']} -h {CONFIG['DATABASE']['HOST']} -P {str(CONFIG['DATABASE']['PORT'])} < {MARIADB_BACKUP_FILE}''' - return_code = subprocess.call([restore_call], shell=True, stderr=subprocess.DEVNULL) - if return_code != 0: - print('[-] Failed to restore mariadb') + with open(MARIADB_BACKUP_FILE, 'rb') as f: + mariadb_backup_data = f.read() + restore_call = ['mariadb', '-u', CONFIG['DATABASE']['USER'], + '-h', CONFIG['DATABASE']['HOST'], + '-P', str(CONFIG['DATABASE']['PORT'])] + if CONFIG['DATABASE']['PASSWORD']: + restore_call.append('-p') + restore_call.append(CONFIG['DATABASE']['PASSWORD']) + restore_call_run = subprocess.run(restore_call, input=mariadb_backup_data) + if restore_call_run.returncode != 0: + print('[-] Failed to restore MariaDB') else: - print('[+] Restored mariadb from backup') - os.remove(MARIADB_BACKUP_FILE) + print('[+] Restored MariaDB from backup') + + # Restore failed b/c database is down -> not delete backup file + # check whether database is up by trying to get a connection + try: + get_database_connection(CONFIG['DATABASE'], CONFIG['cpe_search']['DATABASE_NAME']) + except: + print('[!] MariaDB seems to be down. The backup file wasn\'t deleted. To restore manually from the file, run the following command:') + print(' '.join(restore_call+['<', MARIADB_BACKUP_FILE, '&&', 'rm', MARIADB_BACKUP_FILE])) + else: + os.remove(MARIADB_BACKUP_FILE) def communicate_warning(msg: str): @@ -417,7 +458,7 @@ def create_poc_in_github_table(): db_conn = get_database_connection(CONFIG['DATABASE'], CONFIG['DATABASE_NAME']) db_cursor = db_conn.cursor() create_poc_in_github_table = CREATE_SQL_STATEMENTS['TABLES']['CVE_POC_IN_GITHUB_MAP'][CONFIG['DATABASE']['TYPE']] - # necessary because sqlite can't handle more than one query a time + # necessary because SQLite can't handle more than one query a time for query in create_poc_in_github_table[:-1].split(';'): db_cursor.execute(query+';') db_conn.commit() @@ -461,7 +502,7 @@ def run(full=False, nvd_api_key=None, config_file=''): CONFIG['CPE_DATABASE_BACKUP_FILE'] = CONFIG['cpe_search']['DATABASE_NAME'] + '.bak' CONFIG['DEPRECATED_CPES_BACKUP_FILE'] = CONFIG['cpe_search']['DEPRECATED_CPES_FILE'] + '.bak' - with open(CONFIG['CREATE_SQL_STATEMENTS_FILE']) as f: + with open(CREATE_SQL_STATEMENTS_FILE) as f: CREATE_SQL_STATEMENTS = json.loads(f.read()) # create file dirs as needed @@ -471,7 +512,7 @@ def run(full=False, nvd_api_key=None, config_file=''): update_files += [CONFIG['DATABASE_NAME'], CONFIG['cpe_search']['DATABASE_NAME']] for file in update_files: os.makedirs(os.path.dirname(file), exist_ok=True) - + if full: if not nvd_api_key: nvd_api_key = os.getenv('NVD_API_KEY') @@ -509,7 +550,7 @@ def run(full=False, nvd_api_key=None, config_file=''): shutil.move(CONFIG['cpe_search']['DEPRECATED_CPES_FILE'], CONFIG['DEPRECATED_CPES_BACKUP_FILE']) if os.path.isfile(CONFIG['DATABASE_NAME']): shutil.move(CONFIG['DATABASE_NAME'], CONFIG['DATABASE_BACKUP_FILE']) - # backup mariadb + # backup MariaDB if CONFIG['DATABASE']['TYPE'] == 'mariadb': try: # check whether database exists @@ -518,14 +559,25 @@ def run(full=False, nvd_api_key=None, config_file=''): except: pass else: - # backup mariadb - return_code = subprocess.call(['mariadb-dump', '-u', CONFIG['DATABASE']['USER'], f"--password={CONFIG['DATABASE']['PASSWORD']}", '-h', CONFIG['DATABASE']['HOST'], '-P', str(CONFIG['DATABASE']['PORT']), - '--add-drop-database', '--add-locks', '-B', CONFIG['DATABASE_NAME'], '-B', CONFIG['cpe_search']['DATABASE_NAME'], '-r', MARIADB_BACKUP_FILE], stderr=subprocess.DEVNULL) + # backup MariaDB + backup_call = ['mariadb-dump', + '-u', CONFIG['DATABASE']['USER'], + '-h', CONFIG['DATABASE']['HOST'], + '-P', str(CONFIG['DATABASE']['PORT']), + '--add-drop-database', '--add-locks', + '-B', CONFIG['DATABASE_NAME'], + '-B', CONFIG['cpe_search']['DATABASE_NAME'], + '-r', MARIADB_BACKUP_FILE] + if CONFIG['DATABASE']['PASSWORD']: + backup_call.append('-p') + backup_call.append(CONFIG['DATABASE']['PASSWORD']) + return_code = subprocess.call(backup_call, stderr=subprocess.DEVNULL) if return_code != 0: print(f'[-] MariaDB backup failed') # expand paths CONFIG['DATABASE_NAME'] = os.path.join(os.path.dirname(os.path.abspath(config_file)), CONFIG['DATABASE_NAME']) CONFIG['cpe_search']['DATABASE_NAME'] = os.path.join(os.path.dirname(os.path.abspath(config_file)), CONFIG['cpe_search']['DATABASE_NAME']) + try: quiet_flag = "" if QUIET: @@ -546,6 +598,7 @@ def run(full=False, nvd_api_key=None, config_file=''): shlex.quote(CONFIG['DATABASE_NAME'])), shell=True) if return_code != 0: raise(Exception("Could not download latest resource files")) + if os.path.isfile(CONFIG['CPE_DATABASE_BACKUP_FILE']): os.remove(CONFIG['CPE_DATABASE_BACKUP_FILE']) if os.path.isfile(CONFIG['DEPRECATED_CPES_BACKUP_FILE']): @@ -553,12 +606,11 @@ def run(full=False, nvd_api_key=None, config_file=''): if os.path.isfile(CONFIG['DATABASE_BACKUP_FILE']): os.remove(CONFIG['DATABASE_BACKUP_FILE']) - # migrate sqlite to mariadb if specified database type is mariadb + # migrate SQLite to MariaDB if specified database type is mariadb if CONFIG['DATABASE']['TYPE'] == 'mariadb': - print('[+] Migrating from sqlite to mariadb (takes around 2 minutes)...') - path_of_migration_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'migrate_sqlite_to_mariadb.sh') - return_code = subprocess.call(['bash', path_of_migration_script, CONFIG['DATABASE_NAME'], CONFIG['cpe_search']['DATABASE_NAME'], CONFIG_FILE], stderr=subprocess.DEVNULL) - if return_code != 0 and False: + print('[+] Migrating from SQLite to MariaDB (takes around 2 minutes)...') + return_code = subprocess.call('./migrate_sqlite_to_mariadb.sh %s %s %s' % (shlex.quote(CONFIG['DATABASE_NAME']), shlex.quote(CONFIG['cpe_search']['DATABASE_NAME']), CONFIG_FILE), shell=True, stderr=subprocess.DEVNULL) + if return_code != 0: raise(Exception('Migration of database failed')) os.remove(MARIADB_BACKUP_FILE) os.remove(CONFIG['DATABASE_NAME']) @@ -575,16 +627,29 @@ def run(full=False, nvd_api_key=None, config_file=''): shutil.move(CONFIG['DATABASE_BACKUP_FILE'], CONFIG['DATABASE_NAME']) print("[+] Restored vulnerability infos from backup") if os.path.isfile(MARIADB_BACKUP_FILE): - restore_call = f'''mariadb -u {CONFIG['DATABASE']['USER']} --password={CONFIG['DATABASE']['PASSWORD']} -h {CONFIG['DATABASE']['HOST']} -P {str(CONFIG['DATABASE']['PORT'])} < {MARIADB_BACKUP_FILE}''' - if QUIET: - return_code = subprocess.call([restore_call], shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + with open(MARIADB_BACKUP_FILE, 'rb') as f: + mariadb_backup_data = f.read() + restore_call = ['mariadb', '-u', CONFIG['DATABASE']['USER'], + '-h', CONFIG['DATABASE']['HOST'], + '-P', str(CONFIG['DATABASE']['PORT'])] + if CONFIG['DATABASE']['PASSWORD']: + restore_call.append('-p') + restore_call.append(CONFIG['DATABASE']['PASSWORD']) + restore_call_run = subprocess.run(restore_call, input=mariadb_backup_data) + if restore_call_run.returncode != 0: + print('[-] Failed to restore MariaDB') else: - return_code = subprocess.call([restore_call], shell=True, stderr=subprocess.DEVNULL) - if return_code != 0: - print('[-] Failed to restore mariadb') + print('[+] Restored MariaDB from backup') + + # Restore failed b/c database is down -> not delete backup file + # check whether database is up by trying to get a connection + try: + get_database_connection(CONFIG['DATABASE'], CONFIG['cpe_search']['DATABASE_NAME']) + except: + print('[!] MariaDB seems to be down. The backup file wasn\'t deleted. To restore manually from the file, run the following command:') + print(' '.join(restore_call+['<', MARIADB_BACKUP_FILE, '&&', 'rm', MARIADB_BACKUP_FILE])) else: - print('[+] Restored mariadb from backup') - os.remove(MARIADB_BACKUP_FILE) + os.remove(MARIADB_BACKUP_FILE) if __name__ == "__main__": diff --git a/web_server.py b/web_server.py index 3b2ef5f..7dea58a 100755 --- a/web_server.py +++ b/web_server.py @@ -3,11 +3,6 @@ import datetime import os import sqlite3 -try: # only use mariadb module if installed - import mariadb -except: - pass - from flask import Flask, request from flask import render_template from cpe_search.database_wrapper_functions import get_database_connection @@ -16,7 +11,7 @@ PROJECT_DIR = os.path.dirname(os.path.realpath(__file__)) STATIC_FOLDER = os.path.join(PROJECT_DIR, os.path.join("web_server_files", "static")) TEMPLATE_FOLDER = os.path.join(PROJECT_DIR, os.path.join("web_server_files", "templates")) -CONFIG_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config_mariadb.json') +CONFIG_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config.json') DB_URI = 'file:vuln_db?mode=memory&cache=shared' CONNECTION_POOL_SIZE = os.cpu_count() # should be equal to number of cpu cores? (https://dba.stackexchange.com/a/305726) RESULTS_CACHE = {} @@ -78,7 +73,7 @@ def version(): with open('version.txt') as f: search_vulns_version = f.read() - db_modified_ts = os.path.getmtime(config['DATABASE_NAME']) + db_modified_ts = os.path.getmtime(config['cpe_search']['DEPRECATED_CPES_FILE']) db_modified_datetime = datetime.datetime.fromtimestamp(db_modified_ts) result = {'version': search_vulns_version, @@ -106,6 +101,7 @@ def index(): # trigger putting of CPE data into memory with some query conn = sqlite3.connect(DB_URI, uri=True) else: + import mariadb conn_params = { 'user': config['DATABASE']['USER'], 'password': config['DATABASE']['PASSWORD'], @@ -113,8 +109,9 @@ def index(): 'port': config['DATABASE']['PORT'], 'database': config['DATABASE_NAME'] } - pool= mariadb.ConnectionPool(pool_name="search_vulns_pool", pool_size=CONNECTION_POOL_SIZE, **conn_params) - conn=pool.get_connection() + pool = mariadb.ConnectionPool(pool_name="search_vulns_pool", pool_size=CONNECTION_POOL_SIZE, **conn_params) + conn = pool.get_connection() + db_cursor = conn.cursor() search_vulns_call('Sudo 1.8.2', db_cursor=db_cursor, keep_data_in_memory=True, config=config) db_cursor.close() @@ -124,4 +121,6 @@ def index(): if __name__ == '__main__': print('[+] Starting webserver') app.run() - pool.close() + # close pool if exists + if 'pool' in locals(): + pool.close()