diff --git a/README.md b/README.md index f58ed18..eae7171 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ homeserver_url: https://matrix.org username: "headjack" password: "" # Optional, if not given it will ask for it on first run allow_list: "" # Regex for allowed accounts. +aichat_config_dir: "$AICHAT_CONFIG_DIR" # Optional, for using a separate aichat config ``` ## Running diff --git a/src/aichat.rs b/src/aichat.rs index 3ff972e..6e67732 100644 --- a/src/aichat.rs +++ b/src/aichat.rs @@ -2,25 +2,34 @@ use std::process::Command; pub struct AiChat { binary_location: String, + config_dir: Option, } impl Default for AiChat { fn default() -> Self { - AiChat::new("aichat".to_string()) + AiChat::new("aichat".to_string(), None) } } impl AiChat { - pub fn new(binary_location: String) -> Self { - AiChat { binary_location } + pub fn new(binary_location: String, config_dir: Option) -> Self { + AiChat { + binary_location, + config_dir, + } } + /// List the models available to the aichat binary pub fn list_models(&self) -> Vec { - // Run the binary with the `list` argument - let output = Command::new(&self.binary_location) - .arg("--list-models") - .output() - .expect("Failed to execute command"); + let mut command = Command::new(self.binary_location.clone()); + command.arg("--list-models"); + + // Add the config dir if it exists + if let Some(config_dir) = &self.config_dir { + command.env("AICHAT_CONFIG_DIR", config_dir); + } + + let output = command.output().expect("Failed to execute command"); // split each line of the output into it's own string and return output @@ -36,6 +45,9 @@ impl AiChat { if let Some(model) = model { command.arg("--model").arg(model); } + if let Some(config_dir) = &self.config_dir { + command.env("AICHAT_CONFIG_DIR", config_dir); + } command.arg("--").arg(prompt); eprintln!("Running command: {:?}", command); diff --git a/src/main.rs b/src/main.rs index fa7379f..5762dfe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -77,6 +77,9 @@ struct Config { password: Option, /// Allow list of which accounts we will respond to allow_list: Option, + /// Set the config directory for aichat + /// Allows for multiple instances setups of aichat + aichat_config_dir: Option, } lazy_static! { @@ -435,7 +438,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) { // But we need to read the context to figure out the model to use let (_, model) = get_context(&room).await.unwrap(); - if let Ok(result) = AiChat::default().execute(model, input.to_string()) { + if let Ok(result) = get_backend().execute(model, input.to_string()) { // Add the prefix ".response:\n" to the result // That way we can identify our own responses and ignore them for context let result = format!(".response:\n{}", result); @@ -473,8 +476,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) { // Get the second word in the command let model = text.split_whitespace().nth(1); if let Some(model) = model { - // Verify this model is available - let models = AiChat::new("aichat".to_string()).list_models(); + let models = get_backend().list_models(); if models.contains(&model.to_string()) { // Set the model let response = format!(".model set to {}", model); @@ -502,7 +504,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) { "list" => { let response = format!( ".models available:\n\n{}", - AiChat::new("aichat".to_string()).list_models().join("\n") + get_backend().list_models().join("\n") ); room.send(RoomMessageEventContent::text_plain(response)) .await @@ -520,7 +522,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) { let prefix = format!("Here is the full text of our ongoing conversation. Your name is {}, and your messages are prefixed by {}:. My name is {}, and my messages are prefixed by {}:. Send the next response in this conversation. Do not prefix your response with your name or any other text. Do not greet me again if you've already done so. Send only the text of your response.\n", room.client().user_id().unwrap(), room.client().user_id().unwrap(), event.sender, event.sender); context.insert_str(0, &prefix); - if let Ok(result) = AiChat::default().execute(model, context) { + if let Ok(result) = get_backend().execute(model, context) { let content = RoomMessageEventContent::text_plain(result); room.send(content).await.unwrap(); } @@ -532,6 +534,12 @@ fn is_command(text: &str) -> bool { text.starts_with('.') && !text.starts_with("..") } +/// Returns the backend based on the global config +fn get_backend() -> AiChat { + let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap(); + AiChat::new("aichat".to_string(), config.aichat_config_dir.clone()) +} + /// Gets the context of the current conversation /// Returns a model if it was ever entered async fn get_context(room: &Room) -> Result<(String, Option), ()> { @@ -559,7 +567,8 @@ async fn get_context(room: &Room) -> Result<(String, Option), ()> { if text_content.body.starts_with(".model") { let model = text_content.body.split_whitespace().nth(1); if let Some(model) = model { - let models = AiChat::new("aichat".to_string()).list_models(); + // Add the config_dir from the global config + let models = get_backend().list_models(); if models.contains(&model.to_string()) { model_response = Some(model.to_string()); }