Skip to content

Commit

Permalink
feat: support a separate aichat config directory
Browse files Browse the repository at this point in the history
Closes #2
  • Loading branch information
arcuru committed Mar 24, 2024
1 parent 2c43acb commit 6579ae6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ homeserver_url: https://matrix.org
username: "headjack"
password: "" # Optional, if not given it will ask for it on first run
allow_list: "" # Regex for allowed accounts.
aichat_config_dir: "$AICHAT_CONFIG_DIR" # Optional, for using a separate aichat config
```
## Running
Expand Down
28 changes: 20 additions & 8 deletions src/aichat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,34 @@ use std::process::Command;

/// Thin wrapper around the external `aichat` CLI binary.
pub struct AiChat {
    // Name or path of the aichat executable to invoke
    binary_location: String,
    // Optional override for the AICHAT_CONFIG_DIR environment variable,
    // allowing a separate aichat configuration for this bot
    config_dir: Option<String>,
}

/// Default backend: the `aichat` binary found on $PATH, with no
/// custom config directory (aichat uses its own default config).
impl Default for AiChat {
    fn default() -> Self {
        // Post-commit form: the two-argument constructor takes an
        // optional config dir; None means "use aichat's default".
        AiChat::new("aichat".to_string(), None)
    }
}

impl AiChat {
pub fn new(binary_location: String) -> Self {
AiChat { binary_location }
pub fn new(binary_location: String, config_dir: Option<String>) -> Self {
AiChat {
binary_location,
config_dir,
}
}

/// List the models available to the aichat binary
pub fn list_models(&self) -> Vec<String> {
// Run the binary with the `list` argument
let output = Command::new(&self.binary_location)
.arg("--list-models")
.output()
.expect("Failed to execute command");
let mut command = Command::new(self.binary_location.clone());
command.arg("--list-models");

// Add the config dir if it exists
if let Some(config_dir) = &self.config_dir {
command.env("AICHAT_CONFIG_DIR", config_dir);
}

let output = command.output().expect("Failed to execute command");

// split each line of the output into its own string and return
output
Expand All @@ -36,6 +45,9 @@ impl AiChat {
if let Some(model) = model {
command.arg("--model").arg(model);
}
if let Some(config_dir) = &self.config_dir {
command.env("AICHAT_CONFIG_DIR", config_dir);
}
command.arg("--").arg(prompt);
eprintln!("Running command: {:?}", command);

Expand Down
21 changes: 15 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ struct Config {
password: Option<String>,
/// Allow list of which accounts we will respond to
allow_list: Option<String>,
/// Set the config directory for aichat
/// Allows running multiple instances of aichat with separate setups
aichat_config_dir: Option<String>,
}

lazy_static! {
Expand Down Expand Up @@ -435,7 +438,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
// But we need to read the context to figure out the model to use
let (_, model) = get_context(&room).await.unwrap();

if let Ok(result) = AiChat::default().execute(model, input.to_string()) {
if let Ok(result) = get_backend().execute(model, input.to_string()) {
// Add the prefix ".response:\n" to the result
// That way we can identify our own responses and ignore them for context
let result = format!(".response:\n{}", result);
Expand Down Expand Up @@ -473,8 +476,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
// Get the second word in the command
let model = text.split_whitespace().nth(1);
if let Some(model) = model {
// Verify this model is available
let models = AiChat::new("aichat".to_string()).list_models();
let models = get_backend().list_models();
if models.contains(&model.to_string()) {
// Set the model
let response = format!(".model set to {}", model);
Expand Down Expand Up @@ -502,7 +504,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
"list" => {
let response = format!(
".models available:\n\n{}",
AiChat::new("aichat".to_string()).list_models().join("\n")
get_backend().list_models().join("\n")
);
room.send(RoomMessageEventContent::text_plain(response))
.await
Expand All @@ -520,7 +522,7 @@ async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) {
let prefix = format!("Here is the full text of our ongoing conversation. Your name is {}, and your messages are prefixed by {}:. My name is {}, and my messages are prefixed by {}:. Send the next response in this conversation. Do not prefix your response with your name or any other text. Do not greet me again if you've already done so. Send only the text of your response.\n",
room.client().user_id().unwrap(), room.client().user_id().unwrap(), event.sender, event.sender);
context.insert_str(0, &prefix);
if let Ok(result) = AiChat::default().execute(model, context) {
if let Ok(result) = get_backend().execute(model, context) {
let content = RoomMessageEventContent::text_plain(result);
room.send(content).await.unwrap();
}
/// A message is a command when it begins with a single '.'.
/// A leading ".." escapes the dot, so the message is treated as plain text.
fn is_command(text: &str) -> bool {
    match text.strip_prefix('.') {
        Some(rest) => !rest.starts_with('.'),
        None => false,
    }
}

/// Returns the aichat backend configured from the global config.
///
/// # Panics
/// Panics if the global config mutex is poisoned or the config has
/// not been initialized yet (same behavior as the previous `unwrap`s).
fn get_backend() -> AiChat {
    // Borrow the config under the lock and clone only the one field we
    // need, instead of deep-cloning the whole Config (password, allow
    // list, etc.) just to extract an Option<String>.
    let config_dir = GLOBAL_CONFIG
        .lock()
        .unwrap()
        .as_ref()
        .expect("global config must be initialized before use")
        .aichat_config_dir
        .clone();
    AiChat::new("aichat".to_string(), config_dir)
}

/// Gets the context of the current conversation
/// Returns a model if it was ever entered
async fn get_context(room: &Room) -> Result<(String, Option<String>), ()> {
Expand Down Expand Up @@ -559,7 +567,8 @@ async fn get_context(room: &Room) -> Result<(String, Option<String>), ()> {
if text_content.body.starts_with(".model") {
let model = text_content.body.split_whitespace().nth(1);
if let Some(model) = model {
let models = AiChat::new("aichat".to_string()).list_models();
// Add the config_dir from the global config
let models = get_backend().list_models();
if models.contains(&model.to_string()) {
model_response = Some(model.to_string());
}
Expand Down

0 comments on commit 6579ae6

Please sign in to comment.