diff --git a/.gitignore b/.gitignore index b23046e..630b7aa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ # .gitignore /target .env +.env.test diff --git a/Cargo.toml b/Cargo.toml index 9147744..e7f17e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,11 +10,11 @@ dotenvy = "0.15" rocket = { version = "0.5.1", features = ["tls",] } sqlx = { version = "0.8.2", features = ["runtime-tokio-native-tls", "mysql"] } tokio = { version = "1.42.0", features = ["full"] } +url = "2.5" [lib] name = "bville_recycle" path = "src/lib.rs" [dev-dependencies] -glob = "0.3.1" -url = "2.5" \ No newline at end of file +glob = "0.3.1" \ No newline at end of file diff --git a/src/db_utils.rs b/src/db_utils.rs new file mode 100644 index 0000000..6a1d89d --- /dev/null +++ b/src/db_utils.rs @@ -0,0 +1,64 @@ +// src/db_utils.rs + +use dotenvy::dotenv; +use sqlx::MySqlPool; +use std::process::Command; +use url::Url; + +pub async fn initialize_database(database_url: &str) -> Result<(), String> { + dotenv().ok(); + + let parsed_url = Url::parse(database_url).map_err(|e| format!("Invalid DATABASE_URL: {}", e))?; + + let username: &str = parsed_url.username(); + let password: &str = parsed_url.password().ok_or("Password not found in DATABASE_URL")?; + let host: &str = parsed_url + .host_str() + .ok_or("Host not found in DATABASE_URL")?; + let database_name = parsed_url.path().trim_start_matches('/'); + if database_name.is_empty() { + return Err("Database name not found in DATABASE_URL".into()); + } + + // Commands to configure the database + let commands = [ + &format!("sudo mariadb -e \"CREATE DATABASE IF NOT EXISTS {database_name}\""), + &format!( + "sudo mariadb -e \"DROP USER IF EXISTS '{username}'@'{host}'\"" + ), + &format!( + "sudo mariadb -e \"CREATE USER IF NOT EXISTS '{username}'@'{host}' IDENTIFIED BY '{password}'\"" + ), + &format!( + "sudo mariadb -e \"GRANT ALL PRIVILEGES ON {database_name}.* TO '{username}'@'{host}'\"" + ), + "sudo mariadb -e \"FLUSH PRIVILEGES\"", + ]; + + for cmd in &commands { + let output = Command::new("sh") + .arg("-c") + .arg(cmd) + .output() + .map_err(|e| format!("Failed to execute command: {}. Error: {}", cmd, e))?; + + if !output.status.success() { + return Err(format!( + "Command failed: {}. Stderr: {}", + cmd, + String::from_utf8_lossy(&output.stderr) + )); + } + } + + Ok(()) +} + + +pub async fn verify_database_connection(pool: &MySqlPool) -> Result<(), String> { + sqlx::query("SELECT 1") + .fetch_one(pool) + .await + .map_err(|e| format!("Database connection test failed: {}", e))?; + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index f4952be..595db9e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ use sqlx::{MySql, Pool}; use std::env; pub mod db_initializer; +pub mod db_utils; #[get("/")] pub fn map_root() -> &'static str { @@ -29,15 +30,12 @@ pub fn about() -> &'static str { #[get("/db_test")] pub async fn db_test(pool: &State>) -> &'static str { - let row: (i32,) = sqlx::query_as("SELECT 1") - .fetch_one(pool.inner()) - .await - .expect("Failed to execute query"); - - if row.0 == 1 { - "Database is working" - } else { - "Database test failed" + match db_utils::verify_database_connection(pool.inner()).await { + Ok(_) => "Database is working", + Err(err) => { + eprintln!("Database test failed: {}", err); + "Database test failed" + } } } @@ -45,14 +43,20 @@ pub async fn rocket() -> Rocket { dotenv().ok(); let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + // Initialize the database if needed + if let Err(err) = db_utils::initialize_database(&database_url).await { + panic!("Failed to initialize database: {}", err); + } + let pool = Pool::::connect(&database_url) .await .expect("Failed to connect to database"); - // Initialize the database if needed - db_initializer::initialize_database(&pool, "sql/init.sql") - .await - .expect("Failed to initialize database"); + // Verify database connection + if let Err(err) = db_utils::verify_database_connection(&pool).await { + panic!("Database connection verification failed: {}", err); + } rocket::build() .manage(pool) diff --git a/tests/crud.rs b/tests/crud.rs index 11e1165..5756865 100644 --- a/tests/crud.rs +++ b/tests/crud.rs @@ -10,81 +10,18 @@ #![forbid(unsafe_code)] -// use bville_recycle::{db_initializer, rocket}; use dotenvy::dotenv; -// use rocket::local::asynchronous::Client as AsyncClient; -// use sqlx::MySqlPool; use std::env; -use std::process::Command; -use url::Url; +use bville_recycle::db_utils; #[tokio::test] async fn test_database_initialization() { - dotenv().ok(); - - // Retrieve the DATABASE_URL from the environment - let database_url: String = env::var("DATABASE_URL").expect("DATABASE_URL must be set"); - - // Parse the DATABASE_URL using the `url` crate - let pared_url: Url = Url::parse(&database_url).expect("Invalid DATABASE_URL format. Failed to parse DATABASE_URL"); - - // Extract components from the parsed URL - let scheme = pared_url.scheme(); - if scheme.is_empty() { - panic!("Scheme not found in DATABASE_URL"); - } - - let username = pared_url.username(); - if username.is_empty() { - panic!("Username not found in DATABASE_URL"); - } - - let password: &str = pared_url.password().expect("Password not found in DATABASE_URL"); - - let host: &str = pared_url.host_str().expect("Host not found in DATABASE_URL"); - - let path: &str = pared_url.path().trim_start_matches("/"); - if path.is_empty() { - panic!("Path (database name) not found in DATABASE_URL"); - } - - // Configure test database - let commands = [ - { - println!("Updating package lists..."); - "sudo apt-get update -y" - }, - { - println!("Installing MariaDB server..."); - "sudo apt-get install -y mariadb-server" - }, - { - println!("Starting MariaDB service..."); - "sudo service mariadb start" - }, - &format!("sudo mariadb -e \"CREATE DATABASE IF NOT EXISTS {path}\""), - &format!("sudo mariadb -e \"DROP USER IF EXISTS '{username}'@'{host}'\""), - &format!("sudo mariadb -e \"CREATE USER '{username}'@'{host}'\""), - &format!("sudo mariadb -e \"SET PASSWORD FOR '{username}'@'{host}' = PASSWORD('{password}')\""), - &format!("sudo mariadb -e \"GRANT ALL PRIVILEGES ON {path}.* TO '{username}'@'{host}'\""), - &format!("sudo mariadb -e \"FLUSH PRIVILEGES\""), - ]; - - for cmd in commands { - let output: std::process::Output = Command::new("sh") - .arg("-c") - .arg(cmd) - .output() - .expect(&format!("Failed to execute command: {}", cmd)); - - if !output.status.success() { - panic!("Command failed: {}\nError: {}", cmd, - String::from_utf8_lossy(&output.stderr)); - } - } - - - + dotenvy::from_filename(".env.test").ok(); + let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + db_utils::initialize_database(&database_url) + .await + .expect("Database initialization failed in test"); } +