Skip to content

Commit

Permalink
avoid dangling clients in reload ext command implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pr2502 committed Mar 17, 2024
1 parent d780ba4 commit 189397d
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 61 deletions.
42 changes: 11 additions & 31 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub async fn process(
LspMuxOptions::PROTOCOL_VERSION,
);

debug!(?options, "lspmux initialization");
match options.method {
ext::Request::Connect { server, args, cwd } => {
connect(
Expand All @@ -73,7 +74,7 @@ pub async fn process(
.await
}
ext::Request::Status {} => status(instance_map, writer).await,
ext::Request::Reload { cwd } => reload(port, cwd, instance_map, writer).await,
ext::Request::Reload { cwd } => reload(cwd, instance_map, writer).await,
}
}

Expand Down Expand Up @@ -117,25 +118,28 @@ async fn status(
}

async fn reload(
port: u16,
cwd: String,
instance_map: Arc<Mutex<InstanceMap>>,
mut writer: LspWriter<OwnedWriteHalf>,
) -> Result<()> {
let mut receiver = if let Some(instance) = instance_map.lock().await.get_by_cwd(&cwd) {
let (client, receiver) = Client::new(port);
instance.add_client(client).await;
if let Some(instance) = instance_map.lock().await.get_by_cwd(&cwd) {
instance
.send_message(Message::Request(Request {
jsonrpc: Version,
method: "rust-analyzer/reloadWorkspace".into(),
params: Value::Null,
id: RequestId::Number(0).tag(Tag::Port(port)),
id: RequestId::Number(0).tag(Tag::Drop),
}))
.await
.ok()
.context("instance closed")?;
receiver

writer
.write_message(&Message::ResponseSuccess(ResponseSuccess::null(
RequestId::Number(0),
)))
.await
.context("writing response")?;
} else {
writer
.write_message(&Message::ResponseError(ResponseError {
Expand All @@ -150,30 +154,6 @@ async fn reload(
.await
.context("writing response")?;
debug!(?cwd, "no instance found for path");
return Ok(());
};

if let Some(response) = receiver.recv().await {
let message = match response
.into_response()
.context("received message was not a response")?
{
Ok(res) => Message::ResponseSuccess(ResponseSuccess {
jsonrpc: Version,
result: res.result,
id: RequestId::Number(0),
}),
Err(res) => Message::ResponseError(ResponseError {
jsonrpc: Version,
error: res.error,
id: RequestId::Number(0),
}),
};

writer
.write_message(&message)
.await
.context("writing response")?;
}

Ok(())
Expand Down
46 changes: 29 additions & 17 deletions src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,31 +636,43 @@ async fn stdout_task(instance: Arc<Instance>, mut reader: LspReader<BufReader<Ch
Message::ResponseSuccess(mut res) => {
// Forward successful response to the right client based on the
// Request ID tag.
if let (Some(Tag::Port(port)), id) = res.id.untag() {
res.id = id;
if let Some(client) = clients.get(&port) {
let _ = client.send_message(res.into()).await;
} else {
debug!(?port, "no matching client");
match res.id.untag() {
(Some(Tag::Port(port)), id) => {
res.id = id;
if let Some(client) = clients.get(&port) {
let _ = client.send_message(res.into()).await;
} else {
debug!(?port, "no matching client");
}
}
(Some(Tag::Drop), _) => {
// Drop the message
}
_ => {
warn!(?res, "ignoring improperly tagged server response")
}
} else {
warn!(?res, "ignoring improperly tagged server response")
}
}

Message::ResponseError(mut res) => {
// Forward the error response to the right client based on the
// Request ID tag.
if let (Some(Tag::Port(port)), id) = res.id.untag() {
warn!(?res, "server responded with error");
res.id = id;
if let Some(client) = clients.get(&port) {
let _ = client.send_message(res.into()).await;
} else {
debug!(?port, "no matching client");
match res.id.untag() {
(Some(Tag::Port(port)), id) => {
warn!(?res, "server responded with error");
res.id = id;
if let Some(client) = clients.get(&port) {
let _ = client.send_message(res.into()).await;
} else {
debug!(?port, "no matching client");
}
}
(Some(Tag::Drop), _) => {
// Drop the message
}
_ => {
warn!(?res, "ignoring improperly tagged server response")
}
} else {
warn!(?res, "ignoring improperly tagged server response")
}
}

Expand Down
13 changes: 13 additions & 0 deletions src/lsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@
use serde_derive::{Deserialize, Serialize};

macro_rules! impl_json_debug {
( $($type:ty),* $(,)? ) => {
$(
impl ::std::fmt::Debug for $type {
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
let json = ::serde_json::to_string(self).expect("BUG: invalid message");
f.write_str(&json)
}
}
)*
};
}

pub mod ext;
pub mod jsonrpc;
pub mod transport;
Expand Down
4 changes: 4 additions & 0 deletions src/lsp/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ pub struct LspMuxOptions {
pub method: Request,
}

impl_json_debug! {
LspMuxOptions,
}

impl LspMuxOptions {
/// Protocol version
///
Expand Down
13 changes: 0 additions & 13 deletions src/lsp/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,19 +161,6 @@ impl From<ResponseError> for Message {
}
}

macro_rules! impl_json_debug {
( $($type:ty),* $(,)? ) => {
$(
impl fmt::Debug for $type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let json = serde_json::to_string(self).expect("BUG: invalid message");
f.write_str(&json)
}
}
)*
};
}

impl_json_debug! {
Message, Request, Notification, ResponseSuccess, ResponseError,
}
Expand Down

0 comments on commit 189397d

Please sign in to comment.