Skip to content

Commit

Permalink
zeta: Send staff edit predictions through Cloudflare Workers (#23847)
Browse files Browse the repository at this point in the history
This PR routes staff edit predictions through Cloudflare Workers
instead of the LLM service.

This will let us dogfood the new LLM worker and verify that it
works as expected.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Jan 29, 2025
1 parent e594397 commit 8603a90
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions crates/zeta/src/zeta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use anyhow::{anyhow, Context as _, Result};
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use collections::{HashMap, HashSet, VecDeque};
use feature_flags::FeatureFlagAppExt as _;
use futures::AsyncReadExt;
use gpui::{
actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
Expand Down Expand Up @@ -298,7 +299,7 @@ impl Zeta {
perform_predict_edits: F,
) -> Task<Result<Option<InlineCompletion>>>
where
F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
F: FnOnce(Arc<Client>, LlmApiToken, bool, PredictEditsParams) -> R + 'static,
R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
{
let snapshot = self.report_changes_for_buffer(buffer, cx);
Expand All @@ -313,6 +314,7 @@ impl Zeta {

let client = self.client.clone();
let llm_token = self.llm_token.clone();
let is_staff = cx.is_staff();

cx.spawn(|_, cx| async move {
let request_sent_at = Instant::now();
Expand Down Expand Up @@ -348,7 +350,7 @@ impl Zeta {
outline: Some(input_outline.clone()),
};

let response = perform_predict_edits(client, llm_token, body).await?;
let response = perform_predict_edits(client, llm_token, is_staff, body).await?;

let output_excerpt = response.output_excerpt;
log::debug!("completion response: {}", output_excerpt);
Expand Down Expand Up @@ -515,7 +517,7 @@ and then another
) -> Task<Result<Option<InlineCompletion>>> {
use std::future::ready;

self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response)))
self.request_completion_impl(buffer, position, cx, |_, _, _, _| ready(Ok(response)))
}

pub fn request_completion(
Expand All @@ -530,6 +532,7 @@ and then another
fn perform_predict_edits(
client: Arc<Client>,
llm_token: LlmApiToken,
is_staff: bool,
body: PredictEditsParams,
) -> impl Future<Output = Result<PredictEditsResponse>> {
async move {
Expand All @@ -538,14 +541,19 @@ and then another
let mut did_retry = false;

loop {
let request_builder = http_client::Request::builder();
let request = request_builder
.method(Method::POST)
.uri(
let request_builder = http_client::Request::builder().method(Method::POST);
let request_builder = if is_staff {
request_builder.uri(
"https://llm-worker-production.zed-industries.workers.dev/predict_edits",
)
} else {
request_builder.uri(
http_client
.build_zed_llm_url("/predict_edits", &[])?
.as_ref(),
)
};
let request = request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", token))
.body(serde_json::to_string(&body)?.into())?;
Expand Down

0 comments on commit 8603a90

Please sign in to comment.