diff --git a/src/include/mysql_utils.hpp b/src/include/mysql_utils.hpp index 4aa9c73..e951613 100644 --- a/src/include/mysql_utils.hpp +++ b/src/include/mysql_utils.hpp @@ -38,6 +38,14 @@ struct MySQLConnectionParameters { uint32_t port = 0; string unix_socket; idx_t client_flag = CLIENT_COMPRESS | CLIENT_IGNORE_SIGPIPE | CLIENT_MULTI_STATEMENTS; + unsigned int ssl_mode = SSL_MODE_PREFERRED; + string ssl_ca; + string ssl_ca_path; + string ssl_cert; + string ssl_cipher; + string ssl_crl; + string ssl_crl_path; + string ssl_key; }; class MySQLUtils { diff --git a/src/mysql_extension.cpp b/src/mysql_extension.cpp index b1736c1..9f41388 100644 --- a/src/mysql_extension.cpp +++ b/src/mysql_extension.cpp @@ -38,6 +38,24 @@ unique_ptr CreateMySQLSecretFunction(ClientContext &, CreateSecretIn result->secret_map["port"] = named_param.second.ToString(); } else if (lower_name == "socket") { result->secret_map["socket"] = named_param.second.ToString(); + } else if (lower_name == "ssl_mode") { + result->secret_map["ssl_mode"] = named_param.second.ToString(); + } else if (lower_name == "ssl_ca") { + result->secret_map["ssl_ca"] = named_param.second.ToString(); + } else if (lower_name == "ssl_capath") { + result->secret_map["ssl_capath"] = named_param.second.ToString(); + } else if (lower_name == "ssl_capath") { + result->secret_map["ssl_capath"] = named_param.second.ToString(); + } else if (lower_name == "ssl_cert") { + result->secret_map["ssl_cert"] = named_param.second.ToString(); + } else if (lower_name == "ssl_cipher") { + result->secret_map["ssl_cipher"] = named_param.second.ToString(); + } else if (lower_name == "ssl_crl") { + result->secret_map["ssl_crl"] = named_param.second.ToString(); + } else if (lower_name == "ssl_crlpath") { + result->secret_map["ssl_crlpath"] = named_param.second.ToString(); + } else if (lower_name == "ssl_key") { + result->secret_map["ssl_key"] = named_param.second.ToString(); } else { throw InternalException("Unknown named parameter passed to CreateMySQLSecretFunction: " + lower_name); } @@ -55,6 +73,14 @@ void SetMySQLSecretParameters(CreateSecretFunction &function) { function.named_parameters["user"] = LogicalType::VARCHAR; function.named_parameters["database"] = LogicalType::VARCHAR; function.named_parameters["socket"] = LogicalType::VARCHAR; + function.named_parameters["ssl_mode"] = LogicalType::VARCHAR; + function.named_parameters["ssl_ca"] = LogicalType::VARCHAR; + function.named_parameters["ssl_capath"] = LogicalType::VARCHAR; + function.named_parameters["ssl_cert"] = LogicalType::VARCHAR; + function.named_parameters["ssl_cipher"] = LogicalType::VARCHAR; + function.named_parameters["ssl_crl"] = LogicalType::VARCHAR; + function.named_parameters["ssl_crlpath"] = LogicalType::VARCHAR; + function.named_parameters["ssl_key"] = LogicalType::VARCHAR; } static void LoadInternal(DatabaseInstance &db) { diff --git a/src/mysql_storage.cpp b/src/mysql_storage.cpp index fd3427e..caf70aa 100644 --- a/src/mysql_storage.cpp +++ b/src/mysql_storage.cpp @@ -90,7 +90,14 @@ static unique_ptr MySQLAttach(StorageExtensionInfo *storage_info, Clien new_connection_info += AddConnectionOption(kv_secret, "port"); new_connection_info += AddConnectionOption(kv_secret, "database"); new_connection_info += AddConnectionOption(kv_secret, "socket"); - + new_connection_info += AddConnectionOption(kv_secret, "ssl_mode"); + new_connection_info += AddConnectionOption(kv_secret, "ssl_ca"); + new_connection_info += AddConnectionOption(kv_secret, "ssl_capath"); + new_connection_info += AddConnectionOption(kv_secret, "ssl_cert"); + new_connection_info += AddConnectionOption(kv_secret, "ssl_cipher"); + new_connection_info += AddConnectionOption(kv_secret, "ssl_crl"); + new_connection_info += AddConnectionOption(kv_secret, "ssl_crlpath"); + new_connection_info += AddConnectionOption(kv_secret, "ssl_key"); connection_string = new_connection_info + connection_string; } else if (explicit_secret) { // secret not found and one was explicitly provided - throw an error diff --git a/src/mysql_utils.cpp b/src/mysql_utils.cpp index 843b484..b764b3b 100644 --- a/src/mysql_utils.cpp +++ b/src/mysql_utils.cpp @@ -125,6 +125,43 @@ MySQLConnectionParameters MySQLUtils::ParseConnectionParameters(const string &ds } else { result.client_flag &= ~CLIENT_COMPRESS; } + } else if (key == "ssl_mode") { + set_options.insert("ssl_mode"); + auto val = StringUtil::Lower(value); + if (val == "disabled") { + result.ssl_mode = SSL_MODE_DISABLED; + } else if (val == "required") { + result.ssl_mode = SSL_MODE_REQUIRED; + } else if (val == "verify_ca") { + result.ssl_mode = SSL_MODE_VERIFY_CA; + } else if (val == "verify_identity") { + result.ssl_mode = SSL_MODE_VERIFY_IDENTITY; + } else if (val == "preferred") { + result.ssl_mode = SSL_MODE_PREFERRED; + } else { + throw InvalidInputException("Invalid dsn - ssl mode must be either disabled, required, verify_ca, verify_identity or preferred - got %s", value); + } + } else if (key == "ssl_ca") { + set_options.insert("ssl_ca"); + result.ssl_ca = value; + } else if (key == "ssl_capath") { + set_options.insert("ssl_capath"); + result.ssl_ca_path = value; + } else if (key == "ssl_cert") { + set_options.insert("ssl_cert"); + result.ssl_cert = value; + } else if (key == "ssl_cipher") { + set_options.insert("ssl_cipher"); + result.ssl_cipher = value; + } else if (key == "ssl_crl") { + set_options.insert("ssl_crl"); + result.ssl_crl = value; + } else if (key == "ssl_crlpath") { + set_options.insert("ssl_crlpath"); + result.ssl_crl_path = value; + } else if (key == "ssl_key") { + set_options.insert("ssl_key"); + result.ssl_key = value; } else { throw InvalidInputException("Unrecognized configuration parameter \"%s\" " "- expected options are host, " @@ -167,6 +204,16 @@ MySQLConnectionParameters MySQLUtils::ParseConnectionParameters(const string &ds return result; } +void SetMySQLOption(MYSQL *mysql, enum mysql_option option, const string &value) { + if (value.empty()) { + return; + } + int rc = mysql_options(mysql, option, value.c_str()); + if (rc != 0) { + throw InternalException("Failed to set MySQL option"); + } +} + MYSQL *MySQLUtils::Connect(const string &dsn) { MYSQL *mysql = mysql_init(NULL); if (!mysql) { @@ -174,6 +221,19 @@ MYSQL *MySQLUtils::Connect(const string &dsn) { } MYSQL *result; auto config = ParseConnectionParameters(dsn); + // set SSL options (if any) + if (config.ssl_mode != SSL_MODE_PREFERRED) { + mysql_options(mysql, MYSQL_OPT_SSL_MODE, &config.ssl_mode); + } + SetMySQLOption(mysql, MYSQL_OPT_SSL_CA, config.ssl_ca); + SetMySQLOption(mysql, MYSQL_OPT_SSL_CAPATH, config.ssl_ca_path); + SetMySQLOption(mysql, MYSQL_OPT_SSL_CERT, config.ssl_cert); + SetMySQLOption(mysql, MYSQL_OPT_SSL_CIPHER, config.ssl_cipher); + SetMySQLOption(mysql, MYSQL_OPT_SSL_CRL, config.ssl_crl); + SetMySQLOption(mysql, MYSQL_OPT_SSL_CRLPATH, config.ssl_crl_path); + SetMySQLOption(mysql, MYSQL_OPT_SSL_KEY, config.ssl_key); + + // get connection options const char *host = config.host.empty() ? nullptr : config.host.c_str(); const char *user = config.user.empty() ? nullptr : config.user.c_str(); const char *passwd = config.passwd.empty() ? nullptr : config.passwd.c_str(); diff --git a/test/sql/attach_ssl.test b/test/sql/attach_ssl.test new file mode 100644 index 0000000..1446818 --- /dev/null +++ b/test/sql/attach_ssl.test @@ -0,0 +1,17 @@ +# name: test/sql/attach_ssl.test +# description: ATTACH with SSL parameters +# group: [sql] + +require mysql_scanner + +require-env MYSQL_TEST_DATABASE_AVAILABLE + +# invalid ssl mode +statement error +ATTACH 'host=localhost user=root port=0 database=mysql ssl_mode=xxx' AS simple (TYPE MYSQL_SCANNER) +---- +ssl mode must be either + +# don't ask me why this works +statement ok +ATTACH 'host=localhost user=root port=0 database=mysql ssl_mode=required ssl_ca=/xxx/ ssl_capath=/xxx/ ssl_cert=/xxx/ ssl_cipher=/xxx/ ssl_crl=/xxx/ ssl_crlpath=/xxx/ ssl_key=/xxx/' AS simple (TYPE MYSQL_SCANNER)