1use std::collections::VecDeque;
35use std::pin::Pin;
36use std::task::{Context, Poll};
37
38use anyhow::Result;
39use futures_core::Stream;
40
41use super::parser::{ParserOutput, ToolCallParser};
42use crate::domain::agent::{LlmStreamEvent, ToolCall};
43
44type InnerStream = Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>;
45
/// Stream adapter that feeds upstream [`LlmStreamEvent`]s through a
/// [`ToolCallParser`] and re-emits the resulting events.
pub struct NormalizingStream {
    /// Upstream event source being wrapped.
    inner: InnerStream,
    /// Parser that receives every upstream text and reasoning delta and is
    /// flushed with `finish()` when the stream ends.
    parser: Box<dyn ToolCallParser>,
    /// Events that have been produced but not yet handed to the consumer.
    queued: VecDeque<LlmStreamEvent>,
    /// Index given to the next tool call extracted by the parser; it is moved
    /// past any index observed on an upstream `ToolCallDelta`.
    next_index: usize,
    /// Set once `Done` has been handled, or the upstream has ended or failed;
    /// after that the upstream is no longer polled.
    terminated: bool,
}
63
64impl NormalizingStream {
65 #[must_use]
67 pub fn new(inner: InnerStream, parser: Box<dyn ToolCallParser>) -> Self {
68 Self {
69 inner,
70 parser,
71 queued: VecDeque::new(),
72 next_index: 0,
73 terminated: false,
74 }
75 }
76
77 fn enqueue_parser_output(&mut self, mut out: ParserOutput) {
79 if !out.forward_text.is_empty() {
80 let text = std::mem::take(&mut out.forward_text);
88 let text = text.replace("</think>", "").replace("<think>", "");
89 if !text.is_empty() {
90 self.queued
91 .push_back(LlmStreamEvent::TextDelta { content: text });
92 }
93 }
94 if !out.forward_reasoning.is_empty() {
95 self.queued.push_back(LlmStreamEvent::ReasoningDelta {
96 content: std::mem::take(&mut out.forward_reasoning),
97 });
98 }
99 for ToolCall {
100 id,
101 name,
102 arguments,
103 } in out.tool_calls
104 {
105 let index = self.next_index;
106 self.next_index += 1;
107 self.queued.push_back(LlmStreamEvent::ToolCallDelta {
108 index,
109 id: Some(id),
110 name: Some(name),
111 arguments: Some(arguments.to_string()),
112 });
113 }
114 for err in out.errors {
115 self.queued.push_back(LlmStreamEvent::NormalizationError {
116 kind: err.kind,
117 raw: err.raw,
118 });
119 }
120 }
121
122 fn handle_upstream(&mut self, event: LlmStreamEvent) {
124 match event {
125 LlmStreamEvent::TextDelta { content } => {
126 let out = self.parser.push_text(&content);
127 self.enqueue_parser_output(out);
128 }
129 LlmStreamEvent::ReasoningDelta { content } => {
130 let out = self.parser.push_reasoning(&content);
131 self.enqueue_parser_output(out);
132 }
133 LlmStreamEvent::ToolCallDelta {
134 index,
135 id,
136 name,
137 arguments,
138 } => {
139 if index >= self.next_index {
140 self.next_index = index + 1;
141 }
142 self.queued.push_back(LlmStreamEvent::ToolCallDelta {
143 index,
144 id,
145 name,
146 arguments,
147 });
148 }
149 LlmStreamEvent::PromptProgress { .. } | LlmStreamEvent::NormalizationError { .. } => {
150 self.queued.push_back(event);
151 }
152 LlmStreamEvent::Done { finish_reason } => {
153 let out = self.parser.finish();
154 self.enqueue_parser_output(out);
155 let finish_reason = if finish_reason == "stop" && self.next_index > 0 {
161 "tool_calls".to_owned()
162 } else {
163 finish_reason
164 };
165 self.queued
166 .push_back(LlmStreamEvent::Done { finish_reason });
167 self.terminated = true;
168 }
169 }
170 }
171}
172
173impl Stream for NormalizingStream {
174 type Item = Result<LlmStreamEvent>;
175
176 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
177 loop {
178 if let Some(ev) = self.queued.pop_front() {
179 return Poll::Ready(Some(Ok(ev)));
180 }
181 if self.terminated {
182 return Poll::Ready(None);
183 }
184 match self.inner.as_mut().poll_next(cx) {
185 Poll::Pending => return Poll::Pending,
186 Poll::Ready(Some(Ok(event))) => {
187 self.handle_upstream(event);
188 }
190 Poll::Ready(Some(Err(e))) => {
191 self.terminated = true;
192 return Poll::Ready(Some(Err(e)));
193 }
194 Poll::Ready(None) => {
195 let out = self.parser.finish();
198 self.enqueue_parser_output(out);
199 self.terminated = true;
200 if let Some(ev) = self.queued.pop_front() {
201 return Poll::Ready(Some(Ok(ev)));
202 }
203 return Poll::Ready(None);
204 }
205 }
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::normalize::{registry::get_parser, tags};
    use std::task::Poll;

    /// In-memory upstream that is always ready and yields its items in order.
    struct VecStream {
        items: VecDeque<Result<LlmStreamEvent>>,
    }

    impl VecStream {
        fn new(items: Vec<Result<LlmStreamEvent>>) -> Self {
            Self {
                items: VecDeque::from(items),
            }
        }
    }

    impl Stream for VecStream {
        type Item = Result<LlmStreamEvent>;
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let next = self.items.pop_front();
            Poll::Ready(next)
        }
    }

    /// Polls `s` to completion with a no-op waker and returns every event.
    /// Panics on an error item or on `Pending`.
    fn drain(mut s: NormalizingStream) -> Vec<LlmStreamEvent> {
        let mut cx = Context::from_waker(std::task::Waker::noop());
        let mut collected = Vec::new();
        loop {
            let polled = Pin::new(&mut s).poll_next(&mut cx);
            match polled {
                Poll::Pending => panic!("test stream returned Pending"),
                Poll::Ready(None) => break collected,
                Poll::Ready(Some(Err(e))) => panic!("unexpected error: {e}"),
                Poll::Ready(Some(Ok(ev))) => collected.push(ev),
            }
        }
    }

    /// Builds a `NormalizingStream` over `events`, using the Qwen XML parser
    /// when `qwen` is true and the parser for an empty format list otherwise.
    fn wrap(events: Vec<LlmStreamEvent>, qwen: bool) -> NormalizingStream {
        let upstream = VecStream::new(events.into_iter().map(Ok).collect());
        let boxed: InnerStream = Box::pin(upstream);
        let parser = if qwen {
            get_parser(&[tags::FORMAT_QWEN_XML.to_owned()])
        } else {
            get_parser(&[])
        };
        NormalizingStream::new(boxed, parser)
    }

    /// Shorthand for a `TextDelta` event.
    fn text(content: &str) -> LlmStreamEvent {
        LlmStreamEvent::TextDelta {
            content: content.into(),
        }
    }

    /// Shorthand for a `ReasoningDelta` event.
    fn reasoning(content: &str) -> LlmStreamEvent {
        LlmStreamEvent::ReasoningDelta {
            content: content.into(),
        }
    }

    /// Shorthand for a `Done` event.
    fn done(reason: &str) -> LlmStreamEvent {
        LlmStreamEvent::Done {
            finish_reason: reason.into(),
        }
    }

    /// Returns the finish reason of a `Done` event, panicking on anything else.
    fn done_reason(ev: &LlmStreamEvent) -> &str {
        match ev {
            LlmStreamEvent::Done { finish_reason } => finish_reason.as_str(),
            other => panic!("expected Done, got {other:?}"),
        }
    }

    /// Collects the contents of every `TextDelta` in `events`, in order.
    fn text_contents(events: &[LlmStreamEvent]) -> Vec<&str> {
        let mut found = Vec::new();
        for ev in events {
            if let LlmStreamEvent::TextDelta { content } = ev {
                found.push(content.as_str());
            }
        }
        found
    }

    #[test]
    fn standard_parser_is_passthrough() {
        let events = vec![text("hello"), done("stop")];
        let out = drain(wrap(events.clone(), false));
        assert_eq!(out, events);
    }

    #[test]
    fn qwen_xml_in_text_is_extracted_to_tool_call_delta() {
        let upstream = vec![
            text(r#"hi <tool_call>{"name":"foo","arguments":{"x":1}}</tool_call> done"#),
            done("tool_calls"),
        ];
        let out = drain(wrap(upstream, true));
        assert_eq!(out.len(), 3);

        // Surrounding text is forwarded with the tool-call markup removed.
        assert!(matches!(
            &out[0],
            LlmStreamEvent::TextDelta { content } if content == "hi done"
        ));

        let other = &out[1];
        let LlmStreamEvent::ToolCallDelta {
            index,
            id,
            name,
            arguments,
        } = other
        else {
            panic!("expected ToolCallDelta, got {other:?}");
        };
        assert_eq!(*index, 0);
        assert_eq!(id.as_deref(), Some("call_qwen_0"));
        assert_eq!(name.as_deref(), Some("foo"));
        assert_eq!(arguments.as_deref(), Some(r#"{"x":1}"#));

        assert!(matches!(out[2], LlmStreamEvent::Done { .. }));
    }

    #[test]
    fn qwen_xml_in_reasoning_is_extracted_and_text_clean() {
        let upstream = vec![
            reasoning(r#"think <tool_call>{"name":"foo","arguments":{}}</tool_call> end"#),
            done("tool_calls"),
        ];
        let out = drain(wrap(upstream, true));
        assert_eq!(out.len(), 3);
        assert!(matches!(
            &out[0],
            LlmStreamEvent::ReasoningDelta { content } if content == "think end"
        ));
        assert!(matches!(out[1], LlmStreamEvent::ToolCallDelta { .. }));
        assert!(matches!(out[2], LlmStreamEvent::Done { .. }));
    }

    #[test]
    fn synthesised_index_does_not_collide_with_upstream() {
        let upstream = vec![
            LlmStreamEvent::ToolCallDelta {
                index: 0,
                id: Some("native".into()),
                name: Some("nat".into()),
                arguments: Some("{}".into()),
            },
            text(r#"<tool_call>{"name":"foo","arguments":{}}</tool_call>"#),
            done("stop"),
        ];
        let out = drain(wrap(upstream, true));
        assert_eq!(out.len(), 3);

        // The first two events must both be tool-call deltas: the upstream
        // one at index 0 and the extracted one at index 1.
        let indices: Vec<usize> = out[..2]
            .iter()
            .map(|ev| match ev {
                LlmStreamEvent::ToolCallDelta { index, .. } => *index,
                _ => panic!(),
            })
            .collect();
        assert_eq!(indices, vec![0usize, 1]);
    }

    #[test]
    fn unclosed_tag_at_done_emits_normalization_error_then_done() {
        let upstream = vec![text(r#"<tool_call>{"name":"foo""#), done("stop")];
        let out = drain(wrap(upstream, true));

        let ends_with_done = matches!(out.last(), Some(LlmStreamEvent::Done { .. }));
        assert!(ends_with_done);

        let saw_error = out
            .iter()
            .any(|e| matches!(e, LlmStreamEvent::NormalizationError { .. }));
        assert!(saw_error);
    }

    #[test]
    fn upstream_ends_without_done_flushes_parser() {
        let out = drain(wrap(vec![text("<tool")], true));
        assert_eq!(out.len(), 1);
        assert!(matches!(
            &out[0],
            LlmStreamEvent::TextDelta { content } if content == "<tool"
        ));
    }

    #[test]
    fn finish_reason_corrected_to_tool_calls_when_tool_calls_seen() {
        let upstream = vec![
            LlmStreamEvent::ToolCallDelta {
                index: 0,
                id: Some("call_0".into()),
                name: Some("read_file".into()),
                arguments: Some(r#"{"path":"/tmp/x"}"#.into()),
            },
            done("stop"),
        ];
        let out = drain(wrap(upstream, false));
        assert_eq!(out.len(), 2);
        assert_eq!(done_reason(&out[1]), "tool_calls");
    }

    #[test]
    fn finish_reason_stop_unchanged_when_no_tool_calls() {
        let upstream = vec![text("hello"), done("stop")];
        let out = drain(wrap(upstream, false));
        assert_eq!(done_reason(&out[1]), "stop");
    }

    #[test]
    fn stray_close_think_tag_stripped_from_text() {
        let upstream = vec![text("</think>\n\n"), text("actual answer"), done("stop")];
        let out = drain(wrap(upstream, false));
        let texts = text_contents(&out);
        assert!(
            texts.iter().all(|t| !t.contains("</think>")),
            "found </think> in output: {texts:?}"
        );
        assert!(texts.iter().any(|t| t.contains("actual answer")));
    }

    #[test]
    fn stray_open_think_tag_stripped_from_text() {
        let upstream = vec![text("<think>spurious</think>real text"), done("stop")];
        let out = drain(wrap(upstream, false));
        let texts = text_contents(&out);
        assert!(
            texts
                .iter()
                .all(|t| !t.contains("<think>") && !t.contains("</think>"))
        );
        assert!(texts.iter().any(|t| t.contains("real text")));
    }
}