1use serde_json::{Value, json};
29
30use crate::LlmStreamEvent;
31
/// Per-stream metadata stamped onto every encoded chunk.
///
/// The three fields are copied verbatim into the `id`, `model` and `created`
/// keys of each JSON payload produced for the stream.
#[derive(Debug, Clone)]
pub struct SseEncoder {
    /// Identifier emitted as the `id` key of every chunk.
    pub id: String,
    /// Model name emitted as the `model` key of every chunk.
    pub model: String,
    /// Value emitted as the `created` key of every chunk
    /// (presumably a Unix timestamp in seconds — confirm with callers).
    pub created: u64,
}
45
46impl SseEncoder {
47 #[must_use]
49 pub fn new(id: impl Into<String>, model: impl Into<String>, created: u64) -> Self {
50 Self {
51 id: id.into(),
52 model: model.into(),
53 created,
54 }
55 }
56
57 #[must_use]
68 pub fn encode(&self, event: &LlmStreamEvent) -> Option<String> {
69 match event {
70 LlmStreamEvent::TextDelta { content } => Some(self.frame(&json!({
71 "index": 0,
72 "delta": { "content": content },
73 "finish_reason": Value::Null,
74 }))),
75 LlmStreamEvent::ReasoningDelta { content } => Some(self.frame(&json!({
76 "index": 0,
77 "delta": { "reasoning_content": content },
78 "finish_reason": Value::Null,
79 }))),
80 LlmStreamEvent::ToolCallDelta {
81 index,
82 id,
83 name,
84 arguments,
85 } => {
86 let mut tc = json!({ "index": index });
87 if let Some(id) = id {
88 tc["id"] = json!(id);
89 tc["type"] = json!("function");
92 }
93 let mut function = json!({});
94 if let Some(name) = name {
95 function["name"] = json!(name);
96 }
97 if let Some(arguments) = arguments {
98 function["arguments"] = json!(arguments);
99 }
100 if function.as_object().is_some_and(|o| !o.is_empty()) {
101 tc["function"] = function;
102 }
103 Some(self.frame(&json!({
104 "index": 0,
105 "delta": { "tool_calls": [tc] },
106 "finish_reason": Value::Null,
107 })))
108 }
109 LlmStreamEvent::PromptProgress {
110 processed,
111 total,
112 cached,
113 time_ms,
114 } => {
115 let value = json!({
117 "id": self.id,
118 "object": "chat.completion.chunk",
119 "created": self.created,
120 "model": self.model,
121 "prompt_progress": {
122 "processed": processed,
123 "total": total,
124 "cache": cached,
125 "time_ms": time_ms,
126 },
127 });
128 Some(format!("data: {value}\n\n"))
129 }
130 LlmStreamEvent::Done { finish_reason } => {
131 let chunk = self.frame(&json!({
132 "index": 0,
133 "delta": {},
134 "finish_reason": finish_reason,
135 }));
136 Some(format!("{chunk}data: [DONE]\n\n"))
137 }
138 LlmStreamEvent::NormalizationError { .. } => None,
139 }
140 }
141
142 fn frame(&self, choice: &Value) -> String {
144 let value = json!({
145 "id": self.id,
146 "object": "chat.completion.chunk",
147 "created": self.created,
148 "model": self.model,
149 "choices": [choice],
150 });
151 format!("data: {value}\n\n")
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::normalize::NormalizationErrorKind;

    /// Encoder fixture shared by every test.
    fn enc() -> SseEncoder {
        SseEncoder::new("chatcmpl-1", "test-model", 1_729_000_000)
    }

    /// Parses the JSON payload carried by the first line of `out`, which must
    /// start with `data: `.
    fn parse_data_frame(out: &str) -> serde_json::Value {
        let payload = out
            .lines()
            .next()
            .expect("non-empty output")
            .strip_prefix("data: ")
            .expect("data: prefix");
        serde_json::from_str(payload).expect("valid JSON")
    }

    /// A text delta is framed as a chunk with `delta.content` and a null
    /// `finish_reason`.
    #[test]
    fn text_delta_encodes_to_content_chunk() {
        let event = LlmStreamEvent::TextDelta {
            content: "hello".to_owned(),
        };
        let out = enc().encode(&event).expect("frame");
        assert!(out.starts_with("data: "));
        assert!(out.ends_with("\n\n"));

        let v = parse_data_frame(&out);
        assert_eq!(v["object"], "chat.completion.chunk");
        assert_eq!(v["id"], "chatcmpl-1");
        assert_eq!(v["model"], "test-model");

        let choice = &v["choices"][0];
        assert_eq!(choice["delta"]["content"], "hello");
        assert!(choice["finish_reason"].is_null());
    }

    /// A reasoning delta lands under `delta.reasoning_content`.
    #[test]
    fn reasoning_delta_encodes_to_reasoning_content_chunk() {
        let event = LlmStreamEvent::ReasoningDelta {
            content: "think".to_owned(),
        };
        let out = enc().encode(&event).expect("frame");
        let v = parse_data_frame(&out);
        assert_eq!(v["choices"][0]["delta"]["reasoning_content"], "think");
    }

    /// A tool-call delta that carries an id also carries `type: "function"`.
    #[test]
    fn tool_call_delta_first_frame_includes_id_and_type() {
        let event = LlmStreamEvent::ToolCallDelta {
            index: 0,
            id: Some("tc1".to_owned()),
            name: Some("search".to_owned()),
            arguments: Some(r#"{"q":"r"}"#.to_owned()),
        };
        let out = enc().encode(&event).expect("frame");
        let v = parse_data_frame(&out);
        let tc = &v["choices"][0]["delta"]["tool_calls"][0];
        assert_eq!(tc["index"], 0);
        assert_eq!(tc["id"], "tc1");
        assert_eq!(tc["type"], "function");

        let function = &tc["function"];
        assert_eq!(function["name"], "search");
        assert_eq!(function["arguments"], r#"{"q":"r"}"#);
    }

    /// A tool-call delta without an id carries neither `id` nor `type`.
    #[test]
    fn tool_call_delta_continuation_omits_id_and_type() {
        let event = LlmStreamEvent::ToolCallDelta {
            index: 0,
            id: None,
            name: None,
            arguments: Some("more".to_owned()),
        };
        let out = enc().encode(&event).expect("frame");
        let v = parse_data_frame(&out);
        let tc = &v["choices"][0]["delta"]["tool_calls"][0];
        assert!(tc.get("id").is_none(), "id must be omitted on continuation");
        assert!(
            tc.get("type").is_none(),
            "type must be omitted on continuation"
        );
        assert_eq!(tc["function"]["arguments"], "more");
    }

    /// `Done` yields the finishing chunk followed by the `[DONE]` sentinel.
    #[test]
    fn done_event_emits_finish_chunk_and_done_sentinel() {
        let event = LlmStreamEvent::Done {
            finish_reason: "stop".to_owned(),
        };
        let out = enc().encode(&event).expect("frame");
        // Drop the blank separator lines so only the `data:` lines remain.
        let lines: Vec<&str> = out.lines().filter(|line| !line.is_empty()).collect();
        assert_eq!(lines.len(), 2, "Done emits two data: lines");

        let finish_payload = lines[0].strip_prefix("data: ").unwrap();
        let v: serde_json::Value = serde_json::from_str(finish_payload).unwrap();
        assert_eq!(v["choices"][0]["finish_reason"], "stop");
        assert_eq!(lines[1], "data: [DONE]");
    }

    /// Prompt progress is a top-level object and the chunk has no `choices`.
    #[test]
    fn prompt_progress_encodes_to_top_level_field() {
        let event = LlmStreamEvent::PromptProgress {
            processed: 2,
            total: 8,
            cached: 1,
            time_ms: 100,
        };
        let out = enc().encode(&event).expect("frame");
        let v = parse_data_frame(&out);

        let progress = &v["prompt_progress"];
        assert_eq!(progress["processed"], 2);
        assert_eq!(progress["total"], 8);
        assert_eq!(progress["cache"], 1);
        assert_eq!(progress["time_ms"], 100);
        assert!(v.get("choices").is_none());
    }

    /// Normalization errors produce no output at all.
    #[test]
    fn normalization_error_is_suppressed() {
        let event = LlmStreamEvent::NormalizationError {
            kind: NormalizationErrorKind::MalformedToolCallJson {
                raw: "<tool_call>oops".to_owned(),
            },
            raw: "<tool_call>oops".to_owned(),
        };
        let out = enc().encode(&event);
        assert!(
            out.is_none(),
            "NormalizationError must never reach the wire"
        );
    }
}