Skip to content

Commit

Permalink
build: allow to use rustls instead of native-tls
Browse files Browse the repository at this point in the history
* This is used in an effort to remove all dependencies to openssl.
  Which could be interesting in embedded system or on environment
  which is difficult to know on which openssl version the software
  will run it and breaks deployments.
* It introduces two compiler feature flags which are `tokio-rustls-runtime`
  and `async-std-rustls-runtime` that have the same meaning as
  `tokio-runtime` and `async-std-runtime` except that they use rustls.
* There is a safe guard, if we enable both runtimes, this is the
  native-tls ones that are used to keep consistent with the current
  behaviour.

Signed-off-by: Florentin Dubois <[email protected]>
  • Loading branch information
FlorentinDUBOIS committed Jul 31, 2023
1 parent 3964ed7 commit 0cf2fdc
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 51 deletions.
16 changes: 11 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@ regex = "^1.9.1"
bit-vec = "^0.6.3"
futures = "^0.3.28"
futures-io = "^0.3.28"
native-tls = "^0.2.11"
native-tls = { version = "^0.2.11", optional = true }
rustls = { version = "^0.21.5", optional = true }
webpki-roots = { version = "^0.24.0", optional = true }
pem = "^3.0.0"
tokio = { version = "^1.29.1", features = ["rt", "net", "time"], optional = true }
tokio-util = { version = "^0.7.8", features = ["codec"], optional = true }
tokio-rustls = { version = "^0.24.1", optional = true }
tokio-native-tls = { version = "^0.3.1", optional = true }
async-std = {version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
async-std = { version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
asynchronous-codec = { version = "^0.6.2", optional = true }
async-rustls = { version = "^0.4.0", optional = true }
async-native-tls = { version = "^0.5.0", optional = true }
lz4 = { version = "^1.24.0", optional = true }
flate2 = { version = "^1.0.26", optional = true }
Expand All @@ -49,7 +53,7 @@ serde_json = { version = "^1.0.103", optional = true }
tracing = { version = "^0.1.37", optional = true }
async-trait = "^0.1.72"
data-url = { version = "^0.3.0", optional = true }
uuid = {version = "^1.4.1", features = ["v4", "fast-rng"] }
uuid = { version = "^1.4.1", features = ["v4", "fast-rng"] }

[dev-dependencies]
serde = { version = "^1.0.175", features = ["derive"] }
Expand All @@ -64,8 +68,10 @@ protobuf-src = { version = "1.1.0", optional = true }
[features]
default = [ "compression", "tokio-runtime", "async-std-runtime", "auth-oauth2"]
compression = [ "lz4", "flate2", "zstd", "snap" ]
tokio-runtime = [ "tokio", "tokio-util", "tokio-native-tls" ]
async-std-runtime = [ "async-std", "asynchronous-codec", "async-native-tls" ]
tokio-runtime = [ "tokio", "tokio-util", "native-tls", "tokio-native-tls" ]
tokio-rustls-runtime = ["tokio", "tokio-util", "tokio-rustls", "rustls", "webpki-roots" ]
async-std-runtime = [ "async-std", "asynchronous-codec", "native-tls", "async-native-tls" ]
async-std-rustls-runtime = ["async-std", "asynchronous-codec", "async-rustls", "rustls", "webpki-roots" ]
auth-oauth2 = [ "openidconnect", "oauth2", "serde", "serde_json", "data-url" ]
telemetry = ["tracing"]
protobuf-src = ["dep:protobuf-src"]
150 changes: 145 additions & 5 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,18 @@ use futures::{
task::{Context, Poll},
Future, FutureExt, Sink, SinkExt, Stream, StreamExt,
};
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
use native_tls::Certificate;
use proto::MessageIdData;
use rand::{seq::SliceRandom, thread_rng};
#[cfg(all(
any(
feature = "tokio-rustls-runtime",
feature = "async-std-rustls-runtime"
),
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
))]
use rustls::Certificate;
use url::Url;
use uuid::Uuid;

Expand Down Expand Up @@ -934,7 +943,69 @@ impl<Exe: Executor> Connection<Exe> {
.await
}
}
#[cfg(not(feature = "tokio-runtime"))]
#[cfg(all(feature = "tokio-rustls-runtime", not(feature = "tokio-runtime")))]
ExecutorKind::Tokio => {
if tls {
let stream = tokio::net::TcpStream::connect(&address).await?;
let mut root_store = rustls::RootCertStore::empty();
for certificate in certificate_chain {
root_store.add(certificate)?;
}

let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(
vec![],
|mut acc, trust_anchor| {
acc.push(
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
trust_anchor.subject,
trust_anchor.spki,
trust_anchor.name_constraints,
),
);
acc
},
);

root_store.add_server_trust_anchors(trust_anchors.into_iter());
let config = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()?
.with_root_certificates(root_store)
.with_no_client_auth();

let cx = tokio_rustls::TlsConnector::from(Arc::new(config));
let stream = cx
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
.await
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;

Connection::connect(
connection_id,
stream,
auth,
proxy_to_broker_url,
executor,
operation_timeout,
)
.await
} else {
let stream = tokio::net::TcpStream::connect(&address)
.await
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;

Connection::connect(
connection_id,
stream,
auth,
proxy_to_broker_url,
executor,
operation_timeout,
)
.await
}
}
#[cfg(all(not(feature = "tokio-runtime"), not(feature = "tokio-rustls-runtime")))]
ExecutorKind::Tokio => {
unimplemented!("the tokio-runtime cargo feature is not active");
}
Expand Down Expand Up @@ -980,7 +1051,75 @@ impl<Exe: Executor> Connection<Exe> {
.await
}
}
#[cfg(not(feature = "async-std-runtime"))]
#[cfg(all(
feature = "async-std-rustls-runtime",
not(feature = "async-std-runtime")
))]
ExecutorKind::AsyncStd => {
if tls {
let stream = async_std::net::TcpStream::connect(&address).await?;
let mut root_store = rustls::RootCertStore::empty();
for certificate in certificate_chain {
root_store.add(certificate)?;
}

let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(
vec![],
|mut acc, trust_anchor| {
acc.push(
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
trust_anchor.subject,
trust_anchor.spki,
trust_anchor.name_constraints,
),
);
acc
},
);

root_store.add_server_trust_anchors(trust_anchors.into_iter());
let config = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()?
.with_root_certificates(root_store)
.with_no_client_auth();

let connector = async_rustls::TlsConnector::from(Arc::new(config));
let stream = connector
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
.await
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;

Connection::connect(
connection_id,
stream,
auth,
proxy_to_broker_url,
executor,
operation_timeout,
)
.await
} else {
let stream = async_std::net::TcpStream::connect(&address)
.await
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;

Connection::connect(
connection_id,
stream,
auth,
proxy_to_broker_url,
executor,
operation_timeout,
)
.await
}
}
#[cfg(all(
not(feature = "async-std-runtime"),
not(feature = "async-std-rustls-runtime")
))]
ExecutorKind::AsyncStd => {
unimplemented!("the async-std-runtime cargo feature is not active");
}
Expand Down Expand Up @@ -1623,16 +1762,17 @@ mod tests {
use uuid::Uuid;

use super::{Connection, Receiver};
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
use crate::TokioExecutor;
use crate::{
authentication::Authentication,
error::{AuthenticationError, SharedError},
message::{BaseCommand, Codec, Message},
proto::{AuthData, CommandAuthChallenge, CommandAuthResponse, CommandConnected},
TokioExecutor,
};

#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn receiver_auth_challenge_test() {
let (message_tx, message_rx) = mpsc::unbounded();
let (tx, _) = mpsc::unbounded();
Expand Down Expand Up @@ -1690,7 +1830,7 @@ mod tests {
}

#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn connection_auth_challenge_test() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

Expand Down
21 changes: 20 additions & 1 deletion src/connection_manager.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
use std::{collections::HashMap, sync::Arc, time::Duration};

use futures::{channel::oneshot, lock::Mutex};
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
use native_tls::Certificate;
use rand::Rng;
#[cfg(all(
any(
feature = "tokio-rustls-runtime",
feature = "async-std-rustls-runtime"
),
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
))]
use rustls::Certificate;
use url::Url;

use crate::{connection::Connection, error::ConnectionError, executor::Executor};
Expand Down Expand Up @@ -153,10 +162,20 @@ impl<Exe: Executor> ConnectionManager<Exe> {
.iter()
.rev()
{
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
v.push(
Certificate::from_der(&cert.contents())
Certificate::from_der(cert.contents())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
);

#[cfg(all(
any(
feature = "tokio-rustls-runtime",
feature = "async-std-rustls-runtime"
),
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
))]
v.push(Certificate(cert.contents().to_vec()));
}
v
}
Expand Down
16 changes: 8 additions & 8 deletions src/consumer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,11 @@ mod tests {
};
use log::LevelFilter;
use regex::Regex;
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
use tokio::time::timeout;

use super::*;
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
use crate::executor::TokioExecutor;
use crate::{
consumer::initial_position::InitialPosition, producer, proto, tests::TEST_LOGGER,
Expand Down Expand Up @@ -476,7 +476,7 @@ mod tests {
tag: "multi_consumer",
};
#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn multi_consumer() {
let _result = log::set_logger(&MULTI_LOGGER);
log::set_max_level(LevelFilter::Debug);
Expand Down Expand Up @@ -567,7 +567,7 @@ mod tests {
}

#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn consumer_dropped_with_lingering_acks() {
use rand::{distributions::Alphanumeric, Rng};
let _result = log::set_logger(&TEST_LOGGER);
Expand Down Expand Up @@ -664,7 +664,7 @@ mod tests {
}

#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn dead_letter_queue() {
let _result = log::set_logger(&TEST_LOGGER);
log::set_max_level(LevelFilter::Debug);
Expand Down Expand Up @@ -738,7 +738,7 @@ mod tests {
}

#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn failover() {
let _result = log::set_logger(&MULTI_LOGGER);
log::set_max_level(LevelFilter::Debug);
Expand Down Expand Up @@ -798,7 +798,7 @@ mod tests {
}

#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn seek_single_consumer() {
let _result = log::set_logger(&MULTI_LOGGER);
log::set_max_level(LevelFilter::Debug);
Expand Down Expand Up @@ -917,7 +917,7 @@ mod tests {
}

#[tokio::test]
#[cfg(feature = "tokio-runtime")]
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn schema_test() {
#[derive(Serialize, Deserialize)]
struct TestData {
Expand Down
Loading

0 comments on commit 0cf2fdc

Please sign in to comment.