Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API Usability enhancements #16

Merged
merged 5 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .config/nextest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[profile.ci]
# Do not cancel the test run on the first failure.
fail-fast = false

[profile.ci.junit]
path = "junit.xml"
33 changes: 25 additions & 8 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,38 @@ name: Rust

on:
push:
branches: [ "main" ]
branches: ["main"]
pull_request:
branches: [ "main" ]
branches: ["main"]

env:
CARGO_TERM_COLOR: always

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose
- uses: actions/checkout@v4

- uses: dtolnay/rust-toolchain@stable
with:
components: llvm-tools-preview, rustfmt, clippy

- uses: taiki-e/install-action@cargo-llvm-cov
- uses: taiki-e/install-action@nextest

- name: Lint (clippy)
run: cargo clippy --all-features --all-targets

- name: Run tests (with coverage)
run: cargo llvm-cov nextest --profile ci

- name: Generate coverage report
run: cargo llvm-cov report --lcov --output-path lcov.info

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
/Cargo.lock
lcov.info
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ Pull requests are welcome. For major changes, please open an issue first to disc

Please make sure to update tests as appropriate.

### Useful commands

```bash
cargo fmt
cargo clippy --all-features --all-targets
# if you have nextest installed, you can run tests with:
cargo nextest run --profile ci
# otherwise regular cargo test will work
cargo test
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
6 changes: 5 additions & 1 deletion examples/local/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ async fn main() {
let keys = vec![DecodingKey::from_rsa_pem(include_bytes!("jwt.key.pub")).unwrap()];
let mut validation = Validation::new(Algorithm::RS256);
validation.set_audience(&["https://example.com"]);
let decoder = LocalDecoder::new(keys, validation);
let decoder = LocalDecoder::builder()
.keys(keys)
.validation(validation)
.build()
.unwrap();
let state = AppState {
decoder: JwtDecoderState {
decoder: Arc::new(decoder),
Expand Down
15 changes: 7 additions & 8 deletions examples/remote/remote.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use axum::{extract::FromRef, routing::get, Json, Router};
use axum_jwt_auth::{Claims, JwtDecoderState, RemoteJwksDecoderBuilder};
use axum_jwt_auth::{Claims, JwtDecoderState, RemoteJwksDecoder};
use jsonwebtoken::{Algorithm, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
Expand Down Expand Up @@ -73,13 +73,12 @@ async fn main() {
validation.set_issuer(&["your-issuer"]);

// Create a decoder pointing to the JWKS endpoint
let decoder = Arc::new(
RemoteJwksDecoderBuilder::default()
.jwks_url("http://127.0.0.1:3000/.well-known/jwks.json".to_string())
.validation(validation)
.build()
.expect("Failed to build JWKS decoder"),
);
let decoder = RemoteJwksDecoder::builder()
.jwks_url("http://127.0.0.1:3000/.well-known/jwks.json".to_string())
.validation(validation)
.build()
.expect("Failed to build JWKS decoder");
let decoder = Arc::new(decoder);

// Start background task to periodically refresh JWKS
let decoder_clone = decoder.clone();
Expand Down
179 changes: 150 additions & 29 deletions src/axum.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use async_trait::async_trait;
use axum::extract::FromRef;
use axum::http::StatusCode;
use axum::response::Response;
Expand All @@ -6,6 +7,7 @@ use axum::{http::request::Parts, response::IntoResponse};
use axum_extra::headers::authorization::Bearer;
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader;
use jsonwebtoken::errors::ErrorKind;
use serde::de::DeserializeOwned;
use serde::Deserialize;

Expand All @@ -15,6 +17,25 @@ use crate::Decoder;
#[derive(Debug, Deserialize)]
pub struct Claims<T>(pub T);

/// Trait for extracting tokens from request parts
#[async_trait]
pub trait TokenExtractor {
async fn extract_token(parts: &mut Parts) -> Result<String, AuthError>;
}

/// Default implementation using Bearer token
pub struct BearerTokenExtractor;

#[async_trait]
impl TokenExtractor for BearerTokenExtractor {
async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
let auth: TypedHeader<Authorization<Bearer>> =
parts.extract().await.map_err(|_| AuthError::MissingToken)?;

Ok(auth.token().to_string())
}
}

impl<S, T> axum::extract::FromRequestParts<S> for Claims<T>
where
JwtDecoderState<T>: FromRef<S>,
Expand All @@ -24,51 +45,97 @@ where
type Rejection = AuthError;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// `TypedHeader<Authorization<Bearer>>` extracts the auth token
let auth: TypedHeader<Authorization<Bearer>> = parts
.extract()
.await
.map_err(|_| Self::Rejection::MissingToken)?;
// TODO: Allow for custom token extractors?
let token = BearerTokenExtractor::extract_token(parts).await?;

let state = JwtDecoderState::from_ref(state);
// `JwtDecoder::decode` decodes the token
let token_data = state
.decoder
.clone()
.decode(auth.token())
.decode(&token)
.await
.map_err(|e| match e {
crate::Error::Jwt(e) => match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
Self::Rejection::ExpiredToken
}
jsonwebtoken::errors::ErrorKind::InvalidSignature => {
Self::Rejection::InvalidSignature
}
jsonwebtoken::errors::ErrorKind::InvalidAudience => {
Self::Rejection::InvalidAudience
}
_ => Self::Rejection::InvalidToken,
},
_ => Self::Rejection::InternalError,
})?;
.map_err(map_jwt_error)?;

Ok(Claims(token_data.claims))
}
}

#[derive(Debug, thiserror::Error)]
/// Maps JWT errors to AuthError
fn map_jwt_error(err: crate::Error) -> AuthError {
match err {
crate::Error::Jwt(e) => match e.kind() {
ErrorKind::ExpiredSignature => AuthError::ExpiredSignature,
ErrorKind::InvalidSignature => AuthError::InvalidSignature,
ErrorKind::InvalidAudience => AuthError::InvalidAudience,
ErrorKind::InvalidAlgorithm => AuthError::InvalidAlgorithm,
ErrorKind::InvalidToken => AuthError::InvalidToken,
ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
ErrorKind::InvalidSubject => AuthError::InvalidSubject,
ErrorKind::ImmatureSignature => AuthError::ImmatureSignature,
ErrorKind::MissingAlgorithm => AuthError::MissingAlgorithm,
ErrorKind::MissingRequiredClaim(claim) => {
AuthError::MissingRequiredClaim(claim.to_string())
}
_ => AuthError::InternalError,
},
_ => AuthError::InternalError,
}
}

/// An enum representing the possible errors that can occur when authenticating a request.
/// These are sourced from the `jsonwebtoken` crate and defined here to implement `IntoResponse` for
/// use in the `axum` framework.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum AuthError {
/// When the token is invalid
#[error("Invalid token")]
InvalidToken,
#[error("Missing token")]
MissingToken,
#[error("Expired token")]
ExpiredToken,

/// When the signature is invalid
#[error("Invalid signature")]
InvalidSignature,

// Validation errors
/// When a claim required by the validation is not present
#[error("Missing required claim: {0}")]
MissingRequiredClaim(String),

/// When a token's `exp` claim indicates that it has expired
#[error("Expired signature")]
ExpiredSignature,

/// When a token's `iss` claim does not match the expected issuer
#[error("Invalid issuer")]
InvalidIssuer,

/// When a token's `aud` claim does not match one of the expected audience values
#[error("Invalid audience")]
InvalidAudience,

/// When a token's `sub` claim does not match one of the expected subject values
#[error("Invalid subject")]
InvalidSubject,

/// When a token's `nbf` claim represents a time in the future
#[error("Immature signature")]
ImmatureSignature,

/// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key
/// used doesn't match the alg requested
#[error("Invalid algorithm")]
InvalidAlgorithm,

/// When the Validation struct does not contain at least 1 algorithm
#[error("Missing algorithm")]
MissingAlgorithm,

/// When the request is missing a token
#[error("Missing token")]
MissingToken,

/// When an internal error occurs that doesn't fit into the other categories.
/// This is a catch-all error for any unexpected errors that occur such as
/// network errors, decoding errors, and cryptographic errors.
#[error("Internal error")]
InternalError,
}
Expand All @@ -77,10 +144,18 @@ impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, msg) = match self {
AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "Expired token"),
AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
AuthError::MissingRequiredClaim(_) => {
(StatusCode::UNAUTHORIZED, "Missing required claim")
}
AuthError::ExpiredSignature => (StatusCode::UNAUTHORIZED, "Expired signature"),
AuthError::InvalidIssuer => (StatusCode::UNAUTHORIZED, "Invalid issuer"),
AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
AuthError::InvalidSubject => (StatusCode::UNAUTHORIZED, "Invalid subject"),
AuthError::ImmatureSignature => (StatusCode::UNAUTHORIZED, "Immature signature"),
AuthError::InvalidAlgorithm => (StatusCode::UNAUTHORIZED, "Invalid algorithm"),
AuthError::MissingAlgorithm => (StatusCode::UNAUTHORIZED, "Missing algorithm"),
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
};

Expand All @@ -98,3 +173,49 @@ impl<T> FromRef<JwtDecoderState<T>> for Decoder<T> {
state.decoder.clone()
}
}

#[cfg(test)]
mod tests {

use super::*;
use axum::body::Body;
use axum::extract::Request;

#[tokio::test]
async fn test_map_jwt_error() {
use jsonwebtoken::errors::Error as JwtError;

let jwt_error = JwtError::from(ErrorKind::ExpiredSignature);
let auth_error = map_jwt_error(crate::Error::Jwt(jwt_error));
assert!(matches!(auth_error, AuthError::ExpiredSignature));
}

#[tokio::test]
async fn test_bearer_token_extractor() {
// Valid token
let req = Request::builder()
.header("Authorization", "Bearer test_token")
.body(Body::empty())
.unwrap();

let token = BearerTokenExtractor::extract_token(&mut req.into_parts().0).await;
assert!(token.is_ok());
assert_eq!(token.unwrap(), "test_token");

// Invalid token
let req = Request::builder()
.header("Authorization", "Not a bearer token")
.body(Body::empty())
.unwrap();

let token = BearerTokenExtractor::extract_token(&mut req.into_parts().0).await;
assert!(token.is_err());
assert_eq!(token.unwrap_err(), AuthError::MissingToken);

// Missing token
let req = Request::builder().body(Body::empty()).unwrap();
let token = BearerTokenExtractor::extract_token(&mut req.into_parts().0).await;
assert!(token.is_err());
assert_eq!(token.unwrap_err(), AuthError::MissingToken);
}
}
15 changes: 13 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,23 @@ pub use crate::remote::{
pub enum Error {
#[error("JWT key not found (kid: {0:?})")]
KeyNotFound(Option<String>),

#[error("Configuration error: {0}")]
Configuration(String),

#[error("JWT error: {0}")]
Jwt(#[from] jsonwebtoken::errors::Error),

#[error("HTTP request error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("JWKS refresh failed: {0}")]
JwksRefresh(String),

#[error("JWKS refresh failed after {retry_count} attempts: {message}")]
JwksRefresh {
message: String,
retry_count: usize,
#[source]
source: Option<Box<dyn std::error::Error + Send + Sync>>,
},
}

/// A generic trait for decoding JWT tokens.
Expand Down
Loading
Loading