From a8c1043e221683874de9ef93e3fc60604496aada Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=87a=C4=9Fla=20=C3=87elik?= Date: Wed, 25 Dec 2024 10:26:34 +0300 Subject: [PATCH] env config --- Cargo.lock | 318 ++++++++++++++++++- Cargo.toml | 4 +- compute/src/{main.rs => launch.rs} | 34 +- compute/src/lib.rs | 5 +- launcher/Cargo.toml | 38 +++ launcher/src/lib.rs | 5 + launcher/src/main.rs | 492 +++++++++++++++++++++++++++++ launcher/src/utils/gemini.rs | 50 +++ launcher/src/utils/mod.rs | 3 + launcher/src/utils/openai.rs | 49 +++ launcher/src/utils/openrouter.rs | 56 ++++ 11 files changed, 1014 insertions(+), 40 deletions(-) rename compute/src/{main.rs => launch.rs} (82%) create mode 100644 launcher/Cargo.toml create mode 100644 launcher/src/lib.rs create mode 100644 launcher/src/main.rs create mode 100644 launcher/src/utils/gemini.rs create mode 100644 launcher/src/utils/mod.rs create mode 100644 launcher/src/utils/openai.rs create mode 100644 launcher/src/utils/openrouter.rs diff --git a/Cargo.lock b/Cargo.lock index f17e336..048c2b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -540,6 +540,46 @@ dependencies = [ "zeroize", ] +[[package]] +name = "clap" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.90", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + [[package]] name = "colorchoice" version = "1.0.3" @@ -565,6 +605,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode 0.3.6", + "lazy_static", + "libc", + "unicode-width 0.1.13", + "windows-sys 0.52.0", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -636,6 +689,47 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crossterm" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67" +dependencies = [ + "bitflags 1.3.2", + "crossterm_winapi", + "libc", + "mio 0.8.11", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi 0.3.9", +] + +[[package]] +name = "crossterm" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" +dependencies = [ + "bitflags 2.6.0", + "crossterm_winapi", + "mio 1.0.3", + "parking_lot", + "rustix", + "signal-hook", + "signal-hook-mio", + "winapi 0.3.9", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi 0.3.9", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -1005,6 +1099,32 @@ dependencies = [ "uuid", ] +[[package]] +name = "dkn-launcher" +version = "0.2.30" +dependencies = [ + "async-trait", + "clap", + "crossterm 0.28.1", + "dkn-compute", + "dkn-workflows", + "dotenvy", + "env_logger 0.11.5", + "eyre", + "hex", + "hex-literal", + "inquire", + "log", + "self-replace", + "self_update", + "serde", + "serde_json", + "tempfile", + "tokio 1.42.0", + "tokio-util 0.7.13", + "which", +] + [[package]] name = "dkn-monitor" version = "0.2.30" @@ -1090,6 +1210,12 @@ dependencies = [ "dtoa", ] +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "ecies" version = "0.2.6" @@ -1145,6 +1271,12 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "encode_unicode" version = "1.0.0" @@ -1498,6 +1630,15 @@ dependencies = [ "slab", ] +[[package]] +name = "fuzzy-matcher" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54614a3312934d066701a80f20f15fa3b56d67ac7722b39eea5b4c9dd1d66c94" +dependencies = [ + "thread_local", +] + [[package]] name = "fxhash" version = "0.2.1" @@ -1542,7 +1683,7 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" dependencies = [ - "unicode-width", + "unicode-width 0.1.13", ] [[package]] @@ -1796,6 +1937,15 @@ dependencies = [ "hmac 0.8.1", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "hostname" version = "0.3.1" @@ -1817,7 +1967,7 @@ dependencies = [ "markup5ever 0.12.1", "tendril", "thiserror 1.0.69", - "unicode-width", + "unicode-width 0.1.13", ] [[package]] @@ -2318,6 +2468,19 @@ dependencies = [ "hashbrown 0.15.2", ] +[[package]] +name = "indicatif" +version = "0.17.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width 0.2.0", + "web-time", +] + [[package]] name = "inout" version = "0.1.3" @@ -2327,6 +2490,23 @@ dependencies = [ "generic-array", ] +[[package]] +name = "inquire" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fddf93031af70e75410a2511ec04d49e758ed2f26dad3404a934e0fb45cc12a" +dependencies = [ + "bitflags 2.6.0", + "crossterm 0.25.0", + "dyn-clone", + "fuzzy-matcher", + "fxhash", + "newline-converter", + "once_cell", + "unicode-segmentation", + "unicode-width 0.1.13", +] + [[package]] name = "instant" version = "0.1.13" @@ -3166,6 +3346,18 @@ dependencies = [ "winapi 0.2.8", ] +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys 0.48.0", +] + [[package]] name = "mio" version = "1.0.3" @@ -3173,6 +3365,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", + "log", "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -3341,6 +3534,15 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" +[[package]] +name = "newline-converter" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b6b097ecb1cbfed438542d16e84fd7ad9b0c76c8a65b7f9039212a3d14dc7f" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "nix" version = "0.26.4" @@ -3427,6 +3629,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.36.5" @@ -3865,6 +4073,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2110609fb863cdb367d4e69d6c43c81ba6a8c7d18e80082fe9f3ef16b23afeed" +[[package]] +name = "portable-atomic" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" + [[package]] name = "powerfmt" version = "0.2.0" @@ -3903,11 +4117,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46480520d1b77c9a3482d39939fcf96831537a250ec62d4fd8fbdf8e0302e781" dependencies = [ "csv", - "encode_unicode", + "encode_unicode 1.0.0", "is-terminal", "lazy_static", "term", - "unicode-width", + "unicode-width 0.1.13", ] [[package]] @@ -3975,6 +4189,15 @@ dependencies = [ "unsigned-varint", ] +[[package]] +name = "quick-xml" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11bafc859c6815fbaffbbbf4229ecb767ac913fecb27f9ad4343662e9ef099ea" +dependencies = [ + "memchr", +] + [[package]] name = "quinn" version = "0.11.6" @@ -4243,6 +4466,7 @@ dependencies = [ "base64 0.22.1", "bytes 1.9.0", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2 0.4.7", @@ -4611,6 +4835,36 @@ dependencies = [ "smallvec", ] +[[package]] +name = "self-replace" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ec815b5eab420ab893f63393878d89c90fdd94c0bcc44c07abb8ad95552fb7" +dependencies = [ + "fastrand", + "tempfile", + "windows-sys 0.52.0", +] + +[[package]] +name = "self_update" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469a3970061380c19852269f393e74c0fe607a4e23d85267382cf25486aa8de5" +dependencies = [ + "hyper 1.5.1", + "indicatif", + "log", + "quick-xml", + "regex", + "reqwest 0.12.9", + "self-replace", + "semver", + "serde_json", + "tempfile", + "urlencoding", +] + [[package]] name = "semver" version = "1.0.24" @@ -4733,6 +4987,28 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" +dependencies = [ + "libc", + "mio 0.8.11", + "mio 1.0.3", + "signal-hook", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -5097,6 +5373,16 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if 1.0.0", + "once_cell", +] + [[package]] name = "time" version = "0.1.45" @@ -5365,6 +5651,12 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "universal-hash" version = "0.5.1" @@ -5599,6 +5891,18 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9cad3279ade7346b96e38731a641d7343dd6a53d55083dd54eadfa5a1b38c6b" +dependencies = [ + "either", + "home", + "rustix", + "winsafe", +] + [[package]] name = "widestring" version = "1.1.0" @@ -5927,6 +6231,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "winsafe" +version = "0.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" + [[package]] name = "write16" version = "1.0.0" diff --git a/Cargo.toml b/Cargo.toml index bd8571a..bdb324a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,11 @@ [workspace] resolver = "2" -members = ["compute", "p2p", "workflows", "utils", "monitor"] +members = ["compute", "p2p", "workflows", "utils", "monitor", "launcher"] # FIXME: removing this breaks the workflows # compute node is the default member, until Oracle comes in # then, a Launcher will be the default member -default-members = ["compute"] +default-members = ["launcher"] [workspace.package] edition = "2021" diff --git a/compute/src/main.rs b/compute/src/launch.rs similarity index 82% rename from compute/src/main.rs rename to compute/src/launch.rs index e8478e6..f47afeb 100644 --- a/compute/src/main.rs +++ b/compute/src/launch.rs @@ -1,41 +1,11 @@ -use dkn_compute::*; +use crate::*; use dkn_workflows::DriaWorkflowsConfig; use eyre::Result; use std::env; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use workers::workflow::WorkflowsWorker; -#[tokio::main] -async fn main() -> Result<()> { - let dotenv_result = dotenvy::dotenv(); - - env_logger::builder() - .format_timestamp(Some(env_logger::TimestampPrecision::Millis)) - .filter(None, log::LevelFilter::Off) - .filter_module("dkn_compute", log::LevelFilter::Info) - .filter_module("dkn_p2p", log::LevelFilter::Info) - .filter_module("dkn_workflows", log::LevelFilter::Info) - .parse_default_env() // reads RUST_LOG variable - .init(); - - log::info!( - r#" - -██████╗ ██████╗ ██╗ █████╗ -██╔══██╗██╔══██╗██║██╔══██╗ Dria Compute Node -██║ ██║██████╔╝██║███████║ v{DRIA_COMPUTE_NODE_VERSION} -██║ ██║██╔══██╗██║██╔══██║ https://dria.co -██████╔╝██║ ██║██║██║ ██║ -╚═════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝ -"# - ); - - // log about env usage - match dotenv_result { - Ok(path) => log::info!("Loaded .env file at: {}", path.display()), - Err(e) => log::warn!("Could not load .env file: {}", e), - } - +pub async fn launch() -> Result<()> { // task tracker for multiple threads let task_tracker = TaskTracker::new(); let cancellation = CancellationToken::new(); diff --git a/compute/src/lib.rs b/compute/src/lib.rs index c399688..df080c0 100644 --- a/compute/src/lib.rs +++ b/compute/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; pub mod handlers; +pub mod launch; pub mod node; pub mod payloads; pub mod utils; @@ -9,7 +10,7 @@ pub mod workers; /// This value is attached within the published messages. pub const DRIA_COMPUTE_NODE_VERSION: &str = env!("CARGO_PKG_VERSION"); -pub use utils::refresh_dria_nodes; - pub use config::DriaComputeNodeConfig; +pub use launch::launch; pub use node::DriaComputeNode; +pub use utils::refresh_dria_nodes; diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml new file mode 100644 index 0000000..e6dc134 --- /dev/null +++ b/launcher/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "dkn-launcher" +version.workspace = true +edition.workspace = true +license.workspace = true +readme = "README.md" +authors = ["Erhan Tezcan "] + +[dependencies] +clap = { version = "4.5.20", features = ["derive"] } +crossterm = "0.28.1" +which = "7.0.0" +inquire = "0.7.5" +tempfile = "3.14.0" + +# async stuff +tokio-util.workspace = true +tokio.workspace = true +async-trait.workspace = true + +# serialize & deserialize +serde.workspace = true +serde_json.workspace = true + +# utilities +dotenvy.workspace = true +hex = "0.4.3" +hex-literal = "0.4.1" + +env_logger.workspace = true +log.workspace = true +eyre.workspace = true + +# dria subcrates +dkn-compute = { path = "../compute" } +dkn-workflows = { path = "../workflows" } +self-replace = "1.5.0" +self_update = "0.41.0" diff --git a/launcher/src/lib.rs b/launcher/src/lib.rs new file mode 100644 index 0000000..bb72121 --- /dev/null +++ b/launcher/src/lib.rs @@ -0,0 +1,5 @@ +pub mod utils; + +pub use utils::gemini; +pub use utils::openai; +pub use utils::openrouter; diff --git a/launcher/src/main.rs b/launcher/src/main.rs new file mode 100644 index 0000000..c24d4da --- /dev/null +++ b/launcher/src/main.rs @@ -0,0 +1,492 @@ +use dkn_compute::*; +use dkn_launcher::gemini::is_gemini_api_key_required; +use dkn_launcher::openai::is_openai_api_key_required; +use dkn_launcher::openrouter::is_openrouter_api_key_required; +use dkn_workflows::Model; + +use eyre::{Context, Result}; +use std::io::{self, Write}; +use std::path::PathBuf; +use std::process::Stdio; +use std::{ + env::{self, set_var}, + fs::OpenOptions, +}; +use tokio::process::{Child, Command}; +use which::which; + +use clap::{Parser, Subcommand}; +use crossterm::style::{Attribute, Color, SetAttribute, SetForegroundColor}; +use crossterm::terminal::{self, EnterAlternateScreen, LeaveAlternateScreen}; +use crossterm::{cursor, execute}; +use dotenvy::from_path_iter; +use inquire::{required, validator::MinLengthValidator, MultiSelect, Text}; +use self_update::cargo_crate_version; + +#[derive(Parser)] +#[command(name = "dkn-launcher", version, about = "Dria Knowledge Network launcher", long_about = None)] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + // Configure the environment variables for the node. + Configure { + #[arg( + short, + long, + help = "Path to the .env file", + default_value = ".env", + required = false + )] + path: PathBuf, + #[arg(short, long, help = "Edit all env variables", required = false)] + all: bool, + }, + // Launch the compute node. + Compute { + #[arg( + short, + long, + help = "Path to the .env file", + default_value = ".env", + required = false, + value_parser = clap::builder::FalseyValueParser::new() + )] + path: PathBuf, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + let cli = Cli::try_parse(); + + match cli { + Ok(cli) => { + match &cli.command { + Commands::Configure { path, all } => { + log::info!("Configuring the environment variables from: {:?}", path); + // terminal setup + terminal::enable_raw_mode()?; + let mut stdout = io::stdout(); + execute!(stdout, EnterAlternateScreen, cursor::MoveTo(0, 0))?; + execute!( + stdout, + SetForegroundColor(Color::Blue), + SetAttribute(Attribute::Italic) + )?; + + if *all { + configure_all(&path)?; + } else { + configure(&path)?; + } + + terminal::disable_raw_mode()?; + execute!(stdout, LeaveAlternateScreen)?; + } + Commands::Compute { path } => { + // update the launcher + // self_update().await?; + + // run models + // run_ollama().await?; + + // launch the node + // launch_compute_node(&path).await?; + } + }; + } + Err(e) => log::warn!("Failed to parse command line argument: {}", e), + } + + Ok(()) +} + +fn configure(path: &PathBuf) -> Result<()> { + // read env + let mut env_vars = from_path_iter(path) + .expect("Unable to read env") + .map(|values| values.expect("Unable to map env vars")) + .collect::>(); + + // holds selected models to set related api key later + let mut selected_models: Vec; + let mut selected_models_str = String::new(); + + // TODO: hold models in different arrays according to their providers + let models = vec![ + Model::NousTheta, + Model::Phi3Medium, + Model::Phi3Medium128k, + Model::Phi3_5Mini, + Model::Phi3_5MiniFp16, + Model::Gemma2_9B, + Model::Gemma2_9BFp16, + Model::Llama3_1_8B, + Model::Llama3_1_8Bq8, + Model::Llama3_1_8Bf16, + Model::Llama3_1_70B, + Model::Llama3_1_70Bq8, + Model::Llama3_2_1B, + Model::Llama3_2_3B, + Model::Qwen2_5_7B, + Model::Qwen2_5_7Bf16, + Model::Qwen2_5_32Bf16, + Model::Qwen2_5Coder1_5B, + Model::Qwen2_5coder7B, + Model::Qwen2_5oder7Bq8, + Model::Qwen2_5coder7Bf16, + Model::DeepSeekCoder6_7B, + Model::Mixtral8_7b, + Model::GPT4Turbo, + Model::GPT4o, + Model::GPT4oMini, + Model::O1Preview, + Model::O1Mini, + Model::Gemini15ProExp0827, + Model::Gemini15Pro, + Model::Gemini15Flash, + Model::Gemini10Pro, + Model::Gemma2_2bIt, + Model::Gemma2_9bIt, + Model::Gemma2_27bIt, + ]; + + // loop through env vars + for (key, val) in env_vars.iter_mut() { + println!("Key: '{}', Value: '{}'", key, val); + + // ask only for the empty values + if val.is_empty() { + match key.as_str() { + "OPENAI_API_KEY" | "GEMINI_API_KEY" | "OPENROUTER_API_KEY" => { + continue; + } + // ask for the wallet secret key + "DKN_WALLET_SECRET_KEY" => { + let new_secret_key = Text::new(key) + .with_validator(required!("Wallet secret key is required")) + .prompt(); + + // update the secret key in env_vars + if let Ok(new_secret_key) = new_secret_key { + *val = new_secret_key.clone(); + set_var(key, new_secret_key); + } + } + // at least one model must be selected + "DKN_MODELS" => { + let dkn_models = MultiSelect::new(key, models.clone()) + .with_validator(MinLengthValidator::new(1)) + .prompt(); + + if let Ok(dkn_models) = dkn_models { + selected_models = dkn_models + .into_iter() + .map(|model| model.to_string()) + .collect::>(); + + selected_models_str = selected_models.join(","); + + // set the selected models + *val = selected_models_str.clone(); + set_var(key, selected_models_str.clone()); + } + } + _ => { + // update value for other env vars + let new_value = Text::new(key).prompt(); + if let Ok(new_value) = new_value { + *val = new_value.clone(); + set_var(key, new_value); + } + } + } + } + } + + // api key setting according to the selected models + is_openrouter_api_key_required(&selected_models_str, &mut env_vars); + is_openai_api_key_required(&selected_models_str, &mut env_vars); + is_gemini_api_key_required(&selected_models_str, &mut env_vars); + + // open file for writing + let mut file = OpenOptions::new() + .write(true) + .open(path) + .expect("Unable to open file"); + + // write new values to the .env file + for (key, val) in &env_vars { + writeln!(file, "{}={}", key, val).expect("Unable to write to file"); + } + + Ok(()) +} + +fn configure_all(path: &PathBuf) -> Result<()> { + // read env + let mut env_vars = from_path_iter(path) + .expect("Unable to read env") + .map(|values| values.expect("Unable to map env vars")) + .collect::>(); + + // holds selected models to set related api key later + let mut selected_models: Vec; + let mut selected_models_str = String::new(); + + // models + let models = vec![ + Model::NousTheta, + Model::Phi3Medium, + Model::Phi3Medium128k, + Model::Phi3_5Mini, + Model::Phi3_5MiniFp16, + Model::Gemma2_9B, + Model::Gemma2_9BFp16, + Model::Llama3_1_8B, + Model::Llama3_1_8Bq8, + Model::Llama3_1_8Bf16, + Model::Llama3_1_70B, + Model::Llama3_1_70Bq8, + Model::Llama3_2_1B, + Model::Llama3_2_3B, + Model::Qwen2_5_7B, + Model::Qwen2_5_7Bf16, + Model::Qwen2_5_32Bf16, + Model::Qwen2_5Coder1_5B, + Model::Qwen2_5coder7B, + Model::Qwen2_5oder7Bq8, + Model::Qwen2_5coder7Bf16, + Model::DeepSeekCoder6_7B, + Model::Mixtral8_7b, + Model::GPT4Turbo, + Model::GPT4o, + Model::GPT4oMini, + Model::O1Preview, + Model::O1Mini, + Model::Gemini15ProExp0827, + Model::Gemini15Pro, + Model::Gemini15Flash, + Model::Gemini10Pro, + Model::Gemma2_2bIt, + Model::Gemma2_9bIt, + Model::Gemma2_27bIt, + ]; + + // loop through env vars + for (key, val) in env_vars.iter_mut() { + if val.is_empty() { + match key.as_str() { + // skip api keys + "OPENAI_API_KEY" | "GEMINI_API_KEY" | "OPENROUTER_API_KEY" => { + continue; + } + + // ask for the wallet secret key + "DKN_WALLET_SECRET_KEY" => { + let new_secret_key = Text::new(key) + .with_validator(required!("Wallet secret key is required")) + .prompt(); + + // update the secret key in env_vars + if let Ok(new_secret_key) = new_secret_key { + *val = new_secret_key.clone(); + set_var(key, new_secret_key); + } + } + // at least one model must be selected + "DKN_MODELS" => { + let dkn_models = MultiSelect::new(key, models.clone()) + .with_validator(MinLengthValidator::new(1)) + .prompt(); + + if let Ok(dkn_models) = dkn_models { + selected_models = dkn_models + .into_iter() + .map(|model| model.to_string()) + .collect::>(); + + selected_models_str = selected_models.join(","); + + // set the selected models + *val = selected_models_str.clone(); + set_var(key, selected_models_str.clone()); + } + } + _ => { + // for other values + let new_value = Text::new(key).prompt(); + if let Ok(new_value) = new_value { + *val = new_value.clone(); + set_var(key, new_value); + } + } + } + } + } + + // api key setting according to the selected models + is_openrouter_api_key_required(&selected_models_str, &mut env_vars); + is_openai_api_key_required(&selected_models_str, &mut env_vars); + is_gemini_api_key_required(&selected_models_str, &mut env_vars); + + // open file for writing + let mut file = OpenOptions::new() + .write(true) + .truncate(true) // clear the file before writing + .open(path) + .expect("Unable to open file"); + + // write new values to the .env file + for (key, val) in &env_vars { + writeln!(file, "{}={}", key, val).expect("Unable to write to file"); + } + + Ok(()) +} + +pub async fn run_ollama() -> Result { + // find the path to binary + let exe_path = which("ollama").wrap_err("could not find Ollama executable")?; + + log::debug!("Using Ollama executable at {:?}", exe_path); + + // ollama requires the OLLAMA_HOST environment variable to be set before launch + env::set_var("OLLAMA_HOST", "http://127.0.0.1:11434"); + let command = Command::new(exe_path) + .arg("serve") + .stdout(Stdio::null()) // Ignore the output for simplicity + // if ollama arent running in the background + .spawn() + .wrap_err("could not spawn Ollama")?; + // set host later to default value + + Ok(command) +} + +async fn launch_compute_node(path: &PathBuf) -> Result<()> { + let dotenv_result = dotenvy::from_path(path); + + env_logger::builder() + .format_timestamp(Some(env_logger::TimestampPrecision::Millis)) + .filter(None, log::LevelFilter::Off) + .filter_module("dkn_compute", log::LevelFilter::Info) + .filter_module("dkn_p2p", log::LevelFilter::Info) + .filter_module("dkn_workflows", log::LevelFilter::Info) + .parse_default_env() // reads RUST_LOG variable + .init(); + + log::info!( + r#" + +██████╗ ██████╗ ██╗ █████╗ +██╔══██╗██╔══██╗██║██╔══██╗ Dria Compute Node +██║ ██║██████╔╝██║███████║ v{DRIA_COMPUTE_NODE_VERSION} +██║ ██║██╔══██╗██║██╔══██║ https://dria.co +██████╔╝██║ ██║██║██║ ██║ +╚═════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝ +"# + ); + + // log about env usage + match dotenv_result { + Ok(path) => log::info!("Loaded .env file at: {:?}", path), + Err(e) => log::warn!("Could not load .env file: {}", e), + } + + // launch the compute node + let _ = launch().await; + + Ok(()) +} + +async fn self_update() -> Result<()> { + // TODO: get latest release from github + let status = self_update::backends::github::Update::configure() + .repo_owner("firstbatchxyz") + .repo_name("dkn-launcher") + .show_output(true) + .current_version(cargo_crate_version!()) + .build() + .expect("Unable to build update") + .update(); + + match status { + Ok(status) => log::info!("Launcher updated with status: {}", status.version()), + Err(e) => log::warn!("Failed to update: {}", e), + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::NamedTempFile; + use tokio::time::{sleep, Duration}; + + #[tokio::test] + async fn test_run_models() { + let mut child = run_ollama().await.unwrap(); + + // wait for 10 seconds + println!("Waiting for 10 seconds..."); + sleep(Duration::from_secs(10)).await; + + // kill the process + if let Err(e) = child.kill().await { + log::error!("Failed to kill Ollama process: {}", e); + } else { + log::info!("Ollama process killed."); + } + } + + #[test] + fn test_cli_parse_configure() { + let args = vec!["dkn-launcher", "configure"]; + let cli = Cli::try_parse_from(args).unwrap(); + + match cli.command { + Commands::Configure { path, all: _ } => { + assert_eq!(path, PathBuf::from(".env")); + } + _ => panic!("Expected Configure command"), + } + } + + #[test] + fn test_cli_parse_configure_all() { + let args = vec!["dkn-launcher", "configure", "--all"]; + let cli = Cli::try_parse_from(args).unwrap(); + + match cli.command { + Commands::Configure { path, all: _ } => { + assert_eq!(path, PathBuf::from(".env")); + } + _ => panic!("Expected Configure command"), + } + } + + #[test] + fn test_configure() -> Result<()> { + let temp_file = NamedTempFile::new()?; + let path = temp_file.path().to_path_buf(); + + // Write test env content + fs::write(&path, "DKN_WALLET_SECRET_KEY=\nDKN_MODELS=\n")?; + + // Mock terminal input + let _ = configure(&path)?; + + // Read updated env file + let content = fs::read_to_string(&path)?; + assert!(content.contains("DKN_WALLET_SECRET_KEY=")); + assert!(content.contains("DKN_MODELS=")); + + Ok(()) + } +} diff --git a/launcher/src/utils/gemini.rs b/launcher/src/utils/gemini.rs new file mode 100644 index 0000000..c015431 --- /dev/null +++ b/launcher/src/utils/gemini.rs @@ -0,0 +1,50 @@ +use inquire::{required, Text}; +use std::env::set_var; + +pub fn is_gemini_api_key_required( + selected_models: &str, + env_vars: &mut Vec<(String, String)>, +) -> bool { + // if Gemini model is selected, ask for the api key + if selected_models.contains("gemini") || selected_models.contains("gemma") { + let gemini_api_key = Text::new("GEMINI_API_KEY") + .with_validator(required!("Gemini API key is required")) + .prompt(); + + // set api key in env_vars + if let Ok(gemini_api_key) = gemini_api_key { + if let Some((_, new_value)) = env_vars.iter_mut().find(|(k, _)| k == "GEMINI_API_KEY") { + *new_value = gemini_api_key.clone(); + set_var("GEMINI_API_KEY", gemini_api_key); + } + } + true + } else { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::NamedTempFile; + + #[test] + #[ignore = "requires Gemini"] + fn test_is_gemini_api_key_required_with_gpt_model() { + // create a temp env file + let temp_file = NamedTempFile::new(); + if let Ok(path) = temp_file { + let path = path.path().to_path_buf(); + + // write test env content + fs::write(&path, "GEMINI_API_KEY=\n").expect("Unable to write temp file"); + + let selected_models = "gemini-1.5-pro,gemma-2-2b-it"; + let mut env_vars = vec![("GEMINI_API_KEY".to_string(), "".to_string())]; + let required = is_gemini_api_key_required(selected_models, &mut env_vars); + assert!(required); + } + } +} diff --git a/launcher/src/utils/mod.rs b/launcher/src/utils/mod.rs new file mode 100644 index 0000000..63bbde8 --- /dev/null +++ b/launcher/src/utils/mod.rs @@ -0,0 +1,3 @@ +pub mod gemini; +pub mod openai; +pub mod openrouter; diff --git a/launcher/src/utils/openai.rs b/launcher/src/utils/openai.rs new file mode 100644 index 0000000..dc85c7b --- /dev/null +++ b/launcher/src/utils/openai.rs @@ -0,0 +1,49 @@ +use inquire::{required, Text}; +use std::env::set_var; + +pub fn is_openai_api_key_required( + selected_models: &str, + env_vars: &mut Vec<(String, String)>, +) -> bool { + // if OPENAI model is selected, ask for the api key + if selected_models.contains("gpt") || selected_models.contains("o1") { + let openai_api_key = Text::new("OPENAI_API_KEY") + .with_validator(required!("OpenAI API key is required")) + .prompt(); + + if let Ok(openai_api_key) = openai_api_key { + if let Some((_, new_value)) = env_vars.iter_mut().find(|(k, _)| k == "OPENAI_API_KEY") { + *new_value = openai_api_key.clone(); + set_var("OPENAI_API_KEY", openai_api_key); + } + } + true + } else { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::NamedTempFile; + + #[test] + #[ignore = "requires OpenAI"] + fn test_is_openai_api_key_required_with_gpt_model() { + // create a temp env file + let temp_file = NamedTempFile::new(); + if let Ok(path) = temp_file { + let path = path.path().to_path_buf(); + + // write test env content + fs::write(&path, "OPENAI_API_KEY=\n").expect("Unable to write temp file"); + + let selected_models = "gpt-4o-mini"; + let mut env_vars = vec![("OPENAI_API_KEY".to_string(), "".to_string())]; + let required = is_openai_api_key_required(selected_models, &mut env_vars); + assert!(required); + } + } +} diff --git a/launcher/src/utils/openrouter.rs b/launcher/src/utils/openrouter.rs new file mode 100644 index 0000000..b9b3d0a --- /dev/null +++ b/launcher/src/utils/openrouter.rs @@ -0,0 +1,56 @@ +use inquire::{required, Text}; +use std::env::set_var; + +pub fn is_openrouter_api_key_required( + selected_models: &str, + env_vars: &mut Vec<(String, String)>, +) -> bool { + // if open router model is selected, ask for the api key + // TODO: check openrouter model names + if selected_models.contains("claude") + || selected_models.contains("qwen") + || selected_models.contains("deepseek") + || selected_models.contains("qwq") + { + let openrouter_api_key = Text::new("OPENROUTER_API_KEY") + .with_validator(required!("OpenRouter API key is required")) + .prompt(); + + if let Ok(openrouter_api_key) = openrouter_api_key { + if let Some((_, new_value)) = + env_vars.iter_mut().find(|(k, _)| k == "OPENROUTER_API_KEY") + { + *new_value = openrouter_api_key.clone(); + set_var("OPENROUTER_API_KEY", openrouter_api_key); + } + } + true + } else { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::NamedTempFile; + + #[test] + #[ignore = "requires Open Router"] + fn test_is_openai_api_key_required_with_gpt_model() { + // create a temp env file + let temp_file = NamedTempFile::new(); + if let Ok(path) = temp_file { + let path = path.path().to_path_buf(); + + // write test env content + fs::write(&path, "OPENROUTER_API_KEY=\n").expect("Unable to write temp file"); + + let selected_models = "qwen-2.5-72b-instruct,qwq-32b-preview"; + let mut env_vars = vec![("OPENROUTER_API_KEY".to_string(), "".to_string())]; + let required = is_openrouter_api_key_required(selected_models, &mut env_vars); + assert!(required); + } + } +}