Skip to content

Commit

Permalink
enforce top_n in postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
olegklimov committed Jun 2, 2024
1 parent 4f78e45 commit 577e643
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
9 changes: 5 additions & 4 deletions src/http/routers/v1/at_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,20 @@ pub async fn handle_v1_command_preview(
}
};

let top_n = 10; // sync with top_n in chats
let top_n = 7; // sync with top_n in chats

let at_context = AtCommandsContext::new(global_context.clone()).await;

let (messages_for_postprocessing, vec_highlights) = execute_at_commands_in_query(&mut query, &at_context, false, top_n).await;

let rag_n_ctx = max_tokens_for_rag_chat(recommended_model_record.n_ctx, 512); // real maxgen may be different -- comes from request
let processed = postprocess_at_results2(
global_context.clone(),
messages_for_postprocessing,
tokenizer_arc.clone(),
rag_n_ctx,
false,
top_n,
).await;
let mut preview: Vec<ChatMessage> = vec![];
if processed.len() > 0 {
Expand Down Expand Up @@ -184,7 +185,7 @@ async fn command_completion(
Some((x, idx)) => (x.clone(), idx),
None => return (vec![], false, -1, -1),
};

let cmd = match at_command_names.iter().find(|x|x == &&q_cmd.value).and_then(|x|context.at_commands.get(x)) {
Some(x) => x,
None => {
Expand Down
14 changes: 11 additions & 3 deletions src/scratchpads/chat_utils_rag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ pub struct PostprocessSettings {
pub comments_propogate_up_coef: f32, // mark comments above a symbol as useful, with this coef
pub close_small_gaps: bool,
pub take_floor: f32, // take/dont value
pub max_files_n: usize, // don't produce more than n files in output
}

impl PostprocessSettings {
Expand All @@ -206,6 +207,7 @@ impl PostprocessSettings {
close_small_gaps: true,
comments_propogate_up_coef: 0.99,
take_floor: 0.0,
max_files_n: 10,
}
}
}
Expand Down Expand Up @@ -501,10 +503,12 @@ pub async fn postprocess_at_results2(
tokenizer: Arc<RwLock<Tokenizer>>,
tokens_limit: usize,
single_file_mode: bool,
max_files_n: usize,
) -> Vec<ContextFile> {
let files_markup = postprocess_rag_load_ast_markup(global_context.clone(), &messages).await;

let settings = PostprocessSettings::new();
let mut settings = PostprocessSettings::new();
settings.max_files_n = max_files_n;
let (mut lines_in_files, mut lines_by_useful) = postprocess_rag_stage_3_6(
global_context.clone(),
messages,
Expand Down Expand Up @@ -542,11 +546,14 @@ pub async fn postprocess_rag_stage_7_9(
let filename = lineref.fref.cpath.to_string_lossy().to_string();

if !files_mentioned_set.contains(&filename) {
if files_mentioned_set.len() >= settings.max_files_n {
continue;
}
files_mentioned_set.insert(filename.clone());
files_mentioned_sequence.push(lineref.fref.cpath.clone());
if !single_file_mode {
ntokens += count_tokens(&tokenizer.read().unwrap(), &filename.as_str());
ntokens += 5; // any overhead: file_sep, new line, etc
ntokens += 5; // a margin for any overhead: file_sep, new line, etc
}
}
if tokens_count + ntokens > tokens_limit {
Expand Down Expand Up @@ -694,14 +701,15 @@ pub async fn run_at_commands(
);

let (messages_for_postprocessing, _) = execute_at_commands_in_query(&mut user_posted, &context, true, top_n).await;

let t0 = std::time::Instant::now();
let processed = postprocess_at_results2(
global_context.clone(),
messages_for_postprocessing,
tokenizer.clone(),
context_limit,
false,
top_n,
).await;
info!("postprocess_at_results2 {:.3}s", t0.elapsed().as_secs_f32());
if processed.len() > 0 {
Expand Down
2 changes: 2 additions & 0 deletions src/scratchpads/completion_single_file_fim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,14 @@ impl ScratchpadAbstract for SingleFileFIM {

info!(" -- post processing starts --");
let post_t0 = Instant::now();
let max_files_n = 10;
let postprocessed_messages = crate::scratchpads::chat_utils_rag::postprocess_at_results2(
self.global_context.clone(),
ast_messages,
self.t.tokenizer.clone(),
rag_tokens_n,
false,
max_files_n,
).await;

prompt = add_context_to_prompt(&self.t.context_format, &prompt, &self.fim_prefix, &postprocessed_messages, &language_id);
Expand Down

0 comments on commit 577e643

Please sign in to comment.