Skip to content

Commit

Permalink
Merge pull request #85 from Mytherin/issue71
Browse files Browse the repository at this point in the history
Implement #71: add SSL connection parameters
  • Loading branch information
Mytherin authored Sep 5, 2024
2 parents 6cc9282 + 15ab269 commit 383f192
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/include/mysql_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ struct MySQLConnectionParameters {
uint32_t port = 0;
string unix_socket;
idx_t client_flag = CLIENT_COMPRESS | CLIENT_IGNORE_SIGPIPE | CLIENT_MULTI_STATEMENTS;
unsigned int ssl_mode = SSL_MODE_PREFERRED;
string ssl_ca;
string ssl_ca_path;
string ssl_cert;
string ssl_cipher;
string ssl_crl;
string ssl_crl_path;
string ssl_key;
};

class MySQLUtils {
Expand Down
26 changes: 26 additions & 0 deletions src/mysql_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ unique_ptr<BaseSecret> CreateMySQLSecretFunction(ClientContext &, CreateSecretIn
result->secret_map["port"] = named_param.second.ToString();
} else if (lower_name == "socket") {
result->secret_map["socket"] = named_param.second.ToString();
} else if (lower_name == "ssl_mode") {
result->secret_map["ssl_mode"] = named_param.second.ToString();
} else if (lower_name == "ssl_ca") {
result->secret_map["ssl_ca"] = named_param.second.ToString();
} else if (lower_name == "ssl_capath") {
result->secret_map["ssl_capath"] = named_param.second.ToString();
} else if (lower_name == "ssl_capath") {
result->secret_map["ssl_capath"] = named_param.second.ToString();
} else if (lower_name == "ssl_cert") {
result->secret_map["ssl_cert"] = named_param.second.ToString();
} else if (lower_name == "ssl_cipher") {
result->secret_map["ssl_cipher"] = named_param.second.ToString();
} else if (lower_name == "ssl_crl") {
result->secret_map["ssl_crl"] = named_param.second.ToString();
} else if (lower_name == "ssl_crlpath") {
result->secret_map["ssl_crlpath"] = named_param.second.ToString();
} else if (lower_name == "ssl_key") {
result->secret_map["ssl_key"] = named_param.second.ToString();
} else {
throw InternalException("Unknown named parameter passed to CreateMySQLSecretFunction: " + lower_name);
}
Expand All @@ -55,6 +73,14 @@ void SetMySQLSecretParameters(CreateSecretFunction &function) {
function.named_parameters["user"] = LogicalType::VARCHAR;
function.named_parameters["database"] = LogicalType::VARCHAR;
function.named_parameters["socket"] = LogicalType::VARCHAR;
function.named_parameters["ssl_mode"] = LogicalType::VARCHAR;
function.named_parameters["ssl_ca"] = LogicalType::VARCHAR;
function.named_parameters["ssl_capath"] = LogicalType::VARCHAR;
function.named_parameters["ssl_cert"] = LogicalType::VARCHAR;
function.named_parameters["ssl_cipher"] = LogicalType::VARCHAR;
function.named_parameters["ssl_crl"] = LogicalType::VARCHAR;
function.named_parameters["ssl_crlpath"] = LogicalType::VARCHAR;
function.named_parameters["ssl_key"] = LogicalType::VARCHAR;
}

static void LoadInternal(DatabaseInstance &db) {
Expand Down
9 changes: 8 additions & 1 deletion src/mysql_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,14 @@ static unique_ptr<Catalog> MySQLAttach(StorageExtensionInfo *storage_info, Clien
new_connection_info += AddConnectionOption(kv_secret, "port");
new_connection_info += AddConnectionOption(kv_secret, "database");
new_connection_info += AddConnectionOption(kv_secret, "socket");

new_connection_info += AddConnectionOption(kv_secret, "ssl_mode");
new_connection_info += AddConnectionOption(kv_secret, "ssl_ca");
new_connection_info += AddConnectionOption(kv_secret, "ssl_capath");
new_connection_info += AddConnectionOption(kv_secret, "ssl_cert");
new_connection_info += AddConnectionOption(kv_secret, "ssl_cipher");
new_connection_info += AddConnectionOption(kv_secret, "ssl_crl");
new_connection_info += AddConnectionOption(kv_secret, "ssl_crlpath");
new_connection_info += AddConnectionOption(kv_secret, "ssl_key");
connection_string = new_connection_info + connection_string;
} else if (explicit_secret) {
// secret not found and one was explicitly provided - throw an error
Expand Down
60 changes: 60 additions & 0 deletions src/mysql_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,43 @@ MySQLConnectionParameters MySQLUtils::ParseConnectionParameters(const string &ds
} else {
result.client_flag &= ~CLIENT_COMPRESS;
}
} else if (key == "ssl_mode") {
set_options.insert("ssl_mode");
auto val = StringUtil::Lower(value);
if (val == "disabled") {
result.ssl_mode = SSL_MODE_DISABLED;
} else if (val == "required") {
result.ssl_mode = SSL_MODE_REQUIRED;
} else if (val == "verify_ca") {
result.ssl_mode = SSL_MODE_VERIFY_CA;
} else if (val == "verify_identity") {
result.ssl_mode = SSL_MODE_VERIFY_IDENTITY;
} else if (val == "preferred") {
result.ssl_mode = SSL_MODE_PREFERRED;
} else {
throw InvalidInputException("Invalid dsn - ssl mode must be either disabled, required, verify_ca, verify_identity or preferred - got %s", value);
}
} else if (key == "ssl_ca") {
set_options.insert("ssl_ca");
result.ssl_ca = value;
} else if (key == "ssl_capath") {
set_options.insert("ssl_capath");
result.ssl_ca_path = value;
} else if (key == "ssl_cert") {
set_options.insert("ssl_cert");
result.ssl_cert = value;
} else if (key == "ssl_cipher") {
set_options.insert("ssl_cipher");
result.ssl_cipher = value;
} else if (key == "ssl_crl") {
set_options.insert("ssl_crl");
result.ssl_crl = value;
} else if (key == "ssl_crlpath") {
set_options.insert("ssl_crlpath");
result.ssl_crl_path = value;
} else if (key == "ssl_key") {
set_options.insert("ssl_key");
result.ssl_key = value;
} else {
throw InvalidInputException("Unrecognized configuration parameter \"%s\" "
"- expected options are host, "
Expand Down Expand Up @@ -167,13 +204,36 @@ MySQLConnectionParameters MySQLUtils::ParseConnectionParameters(const string &ds
return result;
}

void SetMySQLOption(MYSQL *mysql, enum mysql_option option, const string &value) {
if (value.empty()) {
return;
}
int rc = mysql_options(mysql, option, value.c_str());
if (rc != 0) {
throw InternalException("Failed to set MySQL option");
}
}

MYSQL *MySQLUtils::Connect(const string &dsn) {
MYSQL *mysql = mysql_init(NULL);
if (!mysql) {
throw IOException("Failure in mysql_init");
}
MYSQL *result;
auto config = ParseConnectionParameters(dsn);
// set SSL options (if any)
if (config.ssl_mode != SSL_MODE_PREFERRED) {
mysql_options(mysql, MYSQL_OPT_SSL_MODE, &config.ssl_mode);
}
SetMySQLOption(mysql, MYSQL_OPT_SSL_CA, config.ssl_ca);
SetMySQLOption(mysql, MYSQL_OPT_SSL_CAPATH, config.ssl_ca_path);
SetMySQLOption(mysql, MYSQL_OPT_SSL_CERT, config.ssl_cert);
SetMySQLOption(mysql, MYSQL_OPT_SSL_CIPHER, config.ssl_cipher);
SetMySQLOption(mysql, MYSQL_OPT_SSL_CRL, config.ssl_crl);
SetMySQLOption(mysql, MYSQL_OPT_SSL_CRLPATH, config.ssl_crl_path);
SetMySQLOption(mysql, MYSQL_OPT_SSL_KEY, config.ssl_key);

// get connection options
const char *host = config.host.empty() ? nullptr : config.host.c_str();
const char *user = config.user.empty() ? nullptr : config.user.c_str();
const char *passwd = config.passwd.empty() ? nullptr : config.passwd.c_str();
Expand Down
17 changes: 17 additions & 0 deletions test/sql/attach_ssl.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# name: test/sql/attach_ssl.test
# description: ATTACH with SSL parameters
# group: [sql]

require mysql_scanner

require-env MYSQL_TEST_DATABASE_AVAILABLE

# invalid ssl mode
statement error
ATTACH 'host=localhost user=root port=0 database=mysql ssl_mode=xxx' AS simple (TYPE MYSQL_SCANNER)
----
ssl mode must be either

# don't ask me why this works
statement ok
ATTACH 'host=localhost user=root port=0 database=mysql ssl_mode=required ssl_ca=/xxx/ ssl_capath=/xxx/ ssl_cert=/xxx/ ssl_cipher=/xxx/ ssl_crl=/xxx/ ssl_crlpath=/xxx/ ssl_key=/xxx/' AS simple (TYPE MYSQL_SCANNER)

0 comments on commit 383f192

Please sign in to comment.