Skip to content

Commit

Permalink
feat: replace OnceLocks with axum's State
Browse files Browse the repository at this point in the history
  • Loading branch information
uku3lig committed Oct 25, 2024
1 parent ccbe318 commit 1619264
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 44 deletions.
22 changes: 9 additions & 13 deletions src/discord.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::sync::{Arc, OnceLock};
use std::sync::Arc;

use anyhow::Context;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::Redirect;
Expand All @@ -9,10 +8,10 @@ use serde::{Deserialize, Serialize};
use serenity::all::{CreateInvite, Http};

use crate::config::EnvCfg;
use crate::AppState;
use crate::{util::IntoAppError, RouteResponse};

const VERIF_URL: &str = "https://challenges.cloudflare.com/turnstile/v0/siteverify";
static SERENITY_HTTP: OnceLock<Http> = OnceLock::new();

#[derive(Debug, Serialize, Deserialize)]
pub struct TurnstileData {
Expand All @@ -26,25 +25,21 @@ pub struct TurnstileResponse {
error_codes: Vec<String>,
}

pub async fn init_bot(config: &EnvCfg) -> anyhow::Result<()> {
pub async fn init_bot(config: &EnvCfg) -> anyhow::Result<Http> {
let http = Http::new(&config.bot_token);

let user = http.get_current_user().await?;
SERENITY_HTTP.set(http).unwrap();

tracing::info!("successfully logged in to discord bot {}!", user.name);

Ok(())
Ok(http)
}

pub async fn generate_invite(
Query(data): Query<TurnstileData>,
State(config): State<Arc<EnvCfg>>,
State(state): State<Arc<AppState>>,
) -> RouteResponse<impl IntoResponse> {
let http = SERENITY_HTTP.get().context("bot token not set")?;

let body = [
("secret", &config.turnstile_secret),
("secret", &state.config.turnstile_secret),
("response", &data.token),
];
let request = crate::CLIENT.post(VERIF_URL).form(&body).build()?;
Expand All @@ -61,9 +56,10 @@ pub async fn generate_invite(
return (StatusCode::BAD_REQUEST, message.as_str()).into_app_err();
}

let invite = config
let invite = state
.config
.channel_id
.create_invite(http, CreateInvite::new().max_uses(1))
.create_invite(&state.http, CreateInvite::new().max_uses(1))
.await?;

let link = format!("https://discord.com/invite/{}", invite.code);
Expand Down
39 changes: 22 additions & 17 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod metrics;
mod tiers;
mod util;

use std::sync::{Arc, LazyLock, OnceLock};
use std::sync::{Arc, LazyLock};

use axum::routing::get;
use axum::{middleware, Router};
Expand Down Expand Up @@ -36,7 +36,11 @@ static CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
.unwrap()
});

static CACHE: OnceLock<cache::Storage> = OnceLock::new();
struct AppState {
config: EnvCfg,
cache: cache::Storage,
http: serenity::http::Http,
}

type RouteResponse<T> = Result<T, AppError>;

Expand All @@ -48,34 +52,39 @@ async fn main() -> anyhow::Result<()> {

tracing_subscriber::fmt::init();

let config = Arc::new(envy::from_env::<EnvCfg>()?);
let config = envy::from_env::<EnvCfg>()?;
let metrics_addr = config.metrics_socket_addr.clone();

tokio::try_join!(
start_main_app(config.clone()),
metrics::start_metrics_app(&config.metrics_socket_addr)
start_main_app(config),
metrics::start_metrics_app(metrics_addr)
)?;

tracing::info!("shutting down!");

Ok(())
}

async fn start_main_app(config: Arc<EnvCfg>) -> anyhow::Result<()> {
async fn start_main_app(config: EnvCfg) -> anyhow::Result<()> {
let http = discord::init_bot(&config).await?;
let cache = Storage::new(&config.redis_url).await?;

let state = Arc::new(AppState {
config,
cache,
http,
});

let app = Router::new()
.merge(downloads::router())
.merge(tiers::router())
.route("/generate_invite", get(discord::generate_invite))
.fallback(|| async { (StatusCode::NOT_FOUND, "Not Found") })
.layer(TraceLayer::new_for_http().on_request(|_: &_, _: &_| {}))
.layer(middleware::from_fn(metrics::track))
.with_state(config.clone());

discord::init_bot(&config).await?;
.with_state(state.clone());

let storage = Storage::new(&config.redis_url).await?;
CACHE.set(storage).unwrap();

let listener = tokio::net::TcpListener::bind(&config.socket_addr).await?;
let listener = tokio::net::TcpListener::bind(&state.config.socket_addr).await?;
tracing::info!("main app listening on {}", listener.local_addr()?);

axum::serve(listener, app)
Expand All @@ -85,7 +94,3 @@ async fn start_main_app(config: Arc<EnvCfg>) -> anyhow::Result<()> {
.await
.map_err(anyhow::Error::from)
}

fn get_cache() -> &'static Storage {
CACHE.get().unwrap()
}
2 changes: 1 addition & 1 deletion src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const EXPONENTIAL_SECONDS: &[f64] = &[
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
];

pub async fn start_metrics_app(socket_addr: &str) -> anyhow::Result<()> {
pub async fn start_metrics_app(socket_addr: String) -> anyhow::Result<()> {
let handle = PrometheusBuilder::new()
.set_buckets_for_metric(
Matcher::Suffix("duration_seconds".to_string()),
Expand Down
37 changes: 24 additions & 13 deletions src/tiers.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use std::{collections::HashMap, fmt::Display, time::Instant};

use axum::{extract::Path, response::IntoResponse, routing::get, Json, Router};
use std::{collections::HashMap, fmt::Display, sync::Arc, time::Instant};

use axum::{
extract::{Path, State},
response::IntoResponse,
routing::get,
Json, Router,
};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::{get_cache, RouteResponse};
use crate::{AppState, RouteResponse};

const MCTIERS_REQS_KEY: &str = "api_rs_mctiers_reqs_total";
const MCTIERS_REQ_DURATION_KEY: &str = "api_rs_mctiers_req_duration_seconds";
Expand Down Expand Up @@ -58,7 +63,7 @@ struct MojangUUID {

// === Routes ===

pub fn router<S: Clone + Send + Sync + 'static>() -> Router<S> {
pub fn router() -> Router<Arc<AppState>> {
let router = Router::new()
.route("/all", get(get_all))
.route("/profile/:uuid", get(get_tier))
Expand All @@ -67,19 +72,22 @@ pub fn router<S: Clone + Send + Sync + 'static>() -> Router<S> {
Router::new().nest("/tiers", router)
}

pub async fn get_tier(Path(uuid): Path<Uuid>) -> RouteResponse<impl IntoResponse> {
pub async fn get_tier(
Path(uuid): Path<Uuid>,
State(state): State<Arc<AppState>>,
) -> RouteResponse<impl IntoResponse> {
// uuid version 4 and ietf variant, used by UUID#randomUUID
if uuid.get_version() != Some(uuid::Version::Random)
|| uuid.get_variant() != uuid::Variant::RFC4122
{
return Ok(StatusCode::NOT_FOUND.into_response());
}

let profile = if get_cache().has_player_info(uuid).await? {
get_cache().get_player_info(uuid).await?
let profile = if state.cache.has_player_info(uuid).await? {
state.cache.get_player_info(uuid).await?
} else {
let p = fetch_tier(&uuid).await;
get_cache().set_player_info(uuid, p.clone()).await?;
state.cache.set_player_info(uuid, p.clone()).await?;
p
};

Expand All @@ -91,11 +99,11 @@ pub async fn get_tier(Path(uuid): Path<Uuid>) -> RouteResponse<impl IntoResponse
Ok(res)
}

pub async fn get_all() -> RouteResponse<Json<AllPlayerInfo>> {
pub async fn get_all(State(state): State<Arc<AppState>>) -> RouteResponse<Json<AllPlayerInfo>> {
let mut players = Vec::new();
let mut unknown = Vec::new();

for (uuid, profile) in get_cache().get_all_players().await? {
for (uuid, profile) in state.cache.get_all_players().await? {
match profile {
Some(p) => players.push(p),
None => unknown.push(uuid),
Expand All @@ -110,7 +118,10 @@ pub async fn get_all() -> RouteResponse<Json<AllPlayerInfo>> {
}

/// mctiers `search_profile` is not used here because their username cache can be outdated
pub async fn search_profile(Path(name): Path<String>) -> RouteResponse<impl IntoResponse> {
pub async fn search_profile(
Path(name): Path<String>,
State(state): State<Arc<AppState>>,
) -> RouteResponse<impl IntoResponse> {
let url = format!("https://api.mojang.com/users/profiles/minecraft/{name}");

let response = crate::CLIENT.get(url).send().await?;
Expand All @@ -120,7 +131,7 @@ pub async fn search_profile(Path(name): Path<String>) -> RouteResponse<impl Into

let response: MojangUUID = response.json().await?;

get_tier(Path(response.id))
get_tier(Path(response.id), State(state))
.await
.map(IntoResponse::into_response)
}
Expand Down

0 comments on commit 1619264

Please sign in to comment.