diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 9eec367ff2..81c9587abe 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **fixed:** Skip SSE incompatible chars of `serde_json::RawValue` in `Event::json_data` ([#2992]) - **breaking:** Move `Host` extractor to `axum-extra` ([#2956]) - **added:** Add `method_not_allowed_fallback` to set a fallback when a path matches but there is no handler for the given HTTP method ([#2903]) - **added:** Add `NoContent` as a self-described shortcut for `StatusCode::NO_CONTENT` ([#2978]) @@ -26,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#2961]: https://github.com/tokio-rs/axum/pull/2961 [#2974]: https://github.com/tokio-rs/axum/pull/2974 [#2978]: https://github.com/tokio-rs/axum/pull/2978 +[#2992]: https://github.com/tokio-rs/axum/pull/2992 # 0.8.0 diff --git a/axum/Cargo.toml b/axum/Cargo.toml index e9e6c646f7..8d4fc6c241 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -117,7 +117,7 @@ quickcheck = "1.0" quickcheck_macros = "1.0" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["raw_value"] } time = { version = "0.3", features = ["serde-human-readable"] } tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index e77b8c78a8..b414f05725 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -208,12 +208,29 @@ impl Event { where T: serde::Serialize, { + struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>); + impl std::io::Write for IgnoreNewLines<'_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut last_split = 0; + for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) { + self.0.write_all(&buf[last_split..delimiter])?; + last_split = delimiter + 1; + } + self.0.write_all(&buf[last_split..])?; + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.0.flush() + } + } if self.flags.contains(EventFlags::HAS_DATA) { panic!("Called `EventBuilder::json_data` multiple times"); } self.buffer.extend_from_slice(b"data: "); - serde_json::to_writer((&mut self.buffer).writer(), &data).map_err(axum_core::Error::new)?; + serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data) + .map_err(axum_core::Error::new)?; self.buffer.put_u8(b'\n'); self.flags.insert(EventFlags::HAS_DATA); @@ -515,6 +532,7 @@ mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; use futures_util::stream; + use serde_json::value::RawValue; use std::{collections::HashMap, convert::Infallible}; use tokio_stream::StreamExt as _; @@ -527,6 +545,18 @@ mod tests { assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n"); } + #[test] + fn valid_json_raw_value_chars_stripped() { + let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}"; + let json_raw_value_event = Event::default() + .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap()) + .unwrap(); + assert_eq!( + &*json_raw_value_event.finalize(), + format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes() + ); + } + #[crate::test] async fn basic() { let app = Router::new().route(