diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index 988ac3bcd9eb0..5abb17799b3ac 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -4197,6 +4197,7 @@ async fn test_collaborating_with_lsp_progress_updates_and_diagnostics_ordering( }], }, ); + executor.run_until_parked(); } fake_language_server.notify::(&lsp::ProgressParams { token: lsp::NumberOrString::String("the-disk-based-token".to_string()), diff --git a/crates/editor/src/test/editor_lsp_test_context.rs b/crates/editor/src/test/editor_lsp_test_context.rs index 23e37a1267bdb..87be96afc7f1c 100644 --- a/crates/editor/src/test/editor_lsp_test_context.rs +++ b/crates/editor/src/test/editor_lsp_test_context.rs @@ -315,12 +315,12 @@ impl EditorLspTestContext { pub fn handle_request( &self, - mut handler: F, + handler: F, ) -> futures::channel::mpsc::UnboundedReceiver<()> where T: 'static + request::Request, T::Params: 'static + Send, - F: 'static + Send + FnMut(lsp::Url, T::Params, gpui::AsyncAppContext) -> Fut, + F: 'static + Send + Sync + Fn(lsp::Url, T::Params, gpui::AsyncAppContext) -> Fut, Fut: 'static + Send + Future>, { let url = self.buffer_lsp_url.clone(); diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index e9fa1caac2398..3e6f536ff4f61 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -45,7 +45,8 @@ const CONTENT_LEN_HEADER: &str = "Content-Length: "; const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2); const SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); -type NotificationHandler = Box, Value, AsyncAppContext)>; +type NotificationHandler = + Arc, Value, AsyncAppContext) -> Task<()>>; type ResponseHandler = Box)>; type IoHandler = Box; @@ -528,11 +529,15 @@ impl LanguageServer { while let Some(msg) = input_handler.notifications_channel.next().await { { - let mut notification_handlers = notification_handlers.lock(); - if let Some(handler) = notification_handlers.get_mut(msg.method.as_str()) { - handler(msg.id, msg.params.unwrap_or(Value::Null), cx.clone()); + let handler = { + notification_handlers + .lock() + .get(msg.method.as_str()) + .cloned() + }; + if let Some(handler) = handler { + handler(msg.id, msg.params.unwrap_or(Value::Null), cx.clone()).await; } else { - drop(notification_handlers); on_unhandled_notification(msg); } } @@ -890,7 +895,7 @@ impl LanguageServer { pub fn on_notification(&self, f: F) -> Subscription where T: notification::Notification, - F: 'static + Send + FnMut(T::Params, AsyncAppContext), + F: 'static + Send + Sync + Fn(T::Params, AsyncAppContext), { self.on_custom_notification(T::METHOD, f) } @@ -903,7 +908,7 @@ impl LanguageServer { where T: request::Request, T::Params: 'static + Send, - F: 'static + FnMut(T::Params, AsyncAppContext) -> Fut + Send, + F: 'static + Fn(T::Params, AsyncAppContext) -> Fut + Send + Sync, Fut: 'static + Future>, { self.on_custom_request(T::METHOD, f) @@ -939,17 +944,26 @@ impl LanguageServer { } #[must_use] - fn on_custom_notification(&self, method: &'static str, mut f: F) -> Subscription + fn on_custom_notification(&self, method: &'static str, f: F) -> Subscription where - F: 'static + FnMut(Params, AsyncAppContext) + Send, - Params: DeserializeOwned, + F: 'static + Fn(Params, AsyncAppContext) + Send + Sync, + Params: DeserializeOwned + Send + 'static, { + let callback = Arc::new(f); let prev_handler = self.notification_handlers.lock().insert( method, - Box::new(move |_, params, cx| { - if let Some(params) = serde_json::from_value(params).log_err() { - f(params, cx); - } + Arc::new(move |_, params, cx| { + let callback = callback.clone(); + + cx.spawn(move |cx| async move { + if let Some(params) = cx + .background_executor() + .spawn(async move { serde_json::from_value(params).log_err() }) + .await + { + callback(params, cx); + } + }) }), ); assert!( @@ -963,64 +977,75 @@ impl LanguageServer { } #[must_use] - fn on_custom_request(&self, method: &'static str, mut f: F) -> Subscription + fn on_custom_request(&self, method: &'static str, f: F) -> Subscription where - F: 'static + FnMut(Params, AsyncAppContext) -> Fut + Send, + F: 'static + Fn(Params, AsyncAppContext) -> Fut + Send + Sync, Fut: 'static + Future>, Params: DeserializeOwned + Send + 'static, Res: Serialize, { let outbound_tx = self.outbound_tx.clone(); + let f = Arc::new(f); let prev_handler = self.notification_handlers.lock().insert( method, - Box::new(move |id, params, cx| { + Arc::new(move |id, params, cx| { if let Some(id) = id { - match serde_json::from_value(params) { - Ok(params) => { - let response = f(params, cx.clone()); - cx.foreground_executor() - .spawn({ - let outbound_tx = outbound_tx.clone(); - async move { - let response = match response.await { - Ok(result) => Response { - jsonrpc: JSON_RPC_VERSION, - id, - value: LspResult::Ok(Some(result)), - }, - Err(error) => Response { - jsonrpc: JSON_RPC_VERSION, - id, - value: LspResult::Error(Some(Error { - message: error.to_string(), - })), - }, - }; - if let Some(response) = - serde_json::to_string(&response).log_err() - { - outbound_tx.try_send(response).ok(); - } + let f = f.clone(); + let deserialized_params = cx + .background_executor() + .spawn(async move { serde_json::from_value(params) }); + + cx.spawn({ + let outbound_tx = outbound_tx.clone(); + move |cx| async move { + match deserialized_params.await { + Ok(params) => { + let response = f(params, cx.clone()); + let response = match response.await { + Ok(result) => Response { + jsonrpc: JSON_RPC_VERSION, + id, + value: LspResult::Ok(Some(result)), + }, + Err(error) => Response { + jsonrpc: JSON_RPC_VERSION, + id, + value: LspResult::Error(Some(Error { + message: error.to_string(), + })), + }, + }; + if let Some(response) = + serde_json::to_string(&response).log_err() + { + outbound_tx.try_send(response).ok(); } - }) - .detach(); - } - - Err(error) => { - log::error!("error deserializing {} request: {:?}", method, error); - let response = AnyResponse { - jsonrpc: JSON_RPC_VERSION, - id, - result: None, - error: Some(Error { - message: error.to_string(), - }), - }; - if let Some(response) = serde_json::to_string(&response).log_err() { - outbound_tx.try_send(response).ok(); + } + Err(error) => { + log::error!( + "error deserializing {} request: {:?}", + method, + error + ); + let response = AnyResponse { + jsonrpc: JSON_RPC_VERSION, + id, + result: None, + error: Some(Error { + message: error.to_string(), + }), + }; + if let Some(response) = + serde_json::to_string(&response).log_err() + { + outbound_tx.try_send(response).ok(); + } + } } } - } + }) + } else { + Task::ready(()) } }), ); @@ -1425,12 +1450,12 @@ impl FakeLanguageServer { /// Registers a handler for a specific kind of request. Removes any existing handler for specified request type. pub fn handle_request( &self, - mut handler: F, + handler: F, ) -> futures::channel::mpsc::UnboundedReceiver<()> where T: 'static + request::Request, T::Params: 'static + Send, - F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext) -> Fut, + F: 'static + Send + Sync + Fn(T::Params, gpui::AsyncAppContext) -> Fut, Fut: 'static + Send + Future>, { let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded(); @@ -1454,12 +1479,12 @@ impl FakeLanguageServer { /// Registers a handler for a specific kind of notification. Removes any existing handler for specified notification type. pub fn handle_notification( &self, - mut handler: F, + handler: F, ) -> futures::channel::mpsc::UnboundedReceiver<()> where T: 'static + notification::Notification, T::Params: 'static + Send, - F: 'static + Send + FnMut(T::Params, gpui::AsyncAppContext), + F: 'static + Send + Sync + Fn(T::Params, gpui::AsyncAppContext), { let (handled_tx, handled_rx) = futures::channel::mpsc::unbounded(); self.server.remove_notification_handler::();