gglib_core/normalize/
stream.rs

1//! [`NormalizingStream`] — the single wrap point that canonicalises an
2//! LLM event stream.
3//!
4//! Adapters that implement [`crate::ports::LlmCompletionPort`] wrap the
5//! inner SSE-derived stream **once** with `NormalizingStream::new(inner,
6//! get_parser(&model.tags))`.  Every downstream consumer (Axum SSE, CLI,
7//! Tauri, the proxy, the agent loop) then sees a strict OpenAI-shaped
8//! sequence of [`LlmStreamEvent`] values, regardless of which dialect the
9//! underlying model speaks.
10//!
11//! ## Translation rules
12//!
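//! - `TextDelta` → routed through [`ToolCallParser::push_text`]; the parser
//!   may strip dialect markup and synthesise [`LlmStreamEvent::ToolCallDelta`]
//!   events for any extracted tool calls.  Stray `<think>` / `</think>` tags
//!   left in the forwarded text are then removed, and a text delta that
//!   becomes empty as a result is dropped.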
16//! - `ReasoningDelta` → routed through [`ToolCallParser::push_reasoning`]
17//!   symmetrically.
18//! - `ToolCallDelta` → forwarded unchanged (already conformant).  The
19//!   wrapper records the highest seen `index` so synthesised deltas use
20//!   non-colliding indices.
//! - `PromptProgress` and `NormalizationError` → forwarded unchanged.
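//! - `Done` → [`ToolCallParser::finish`] is called first, any flushed
//!   bytes / tool calls / errors are emitted, **then** `Done` is forwarded
//!   last.  A `finish_reason` of `"stop"` is rewritten to `"tool_calls"`
//!   when any tool call was emitted, because some models report the wrong
//!   reason.  The contract that every stream ends with exactly one `Done`
//!   item is preserved; if upstream ends without a `Done`, the parser is
//!   still flushed but no `Done` is synthesised.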
26//!
27//! ## Errors
28//!
29//! Upstream `Err` items terminate the stream early (we propagate them
30//! verbatim).  Non-fatal normalization issues from the parser are surfaced
31//! as [`LlmStreamEvent::NormalizationError`] events; they do **not**
32//! terminate the stream.
33
use std::collections::{HashMap, HashSet, VecDeque};
use std::pin::Pin;
use std::task::{Context, Poll};

use anyhow::Result;
use futures_core::Stream;

use super::parser::{ParserOutput, ToolCallParser};
use crate::domain::agent::{LlmStreamEvent, ToolCall};
43
44type InnerStream = Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>;
45
46/// Stream adapter that runs every event through a [`ToolCallParser`] before
47/// re-emitting the normalized result.  See module docs.
48pub struct NormalizingStream {
49    inner: InnerStream,
50    parser: Box<dyn ToolCallParser>,
51    /// Events ready to emit on the next poll.  A single upstream event can
52    /// expand to many downstream events (e.g. `Done` flushes parser state
53    /// before propagating).
54    queued: VecDeque<LlmStreamEvent>,
55    /// Lowest tool-call index that is safe to use for a synthesised delta.
56    /// Bumped past every upstream `index` we observe so downstream
57    /// collectors can use indices as keys without collision.
58    next_index: usize,
59    /// `true` once we've forwarded the upstream `Done` (or upstream ended
60    /// or errored).  Subsequent polls return `None`.
61    terminated: bool,
62}
63
64impl NormalizingStream {
65    /// Wrap `inner` so every event is normalized through `parser`.
66    #[must_use]
67    pub fn new(inner: InnerStream, parser: Box<dyn ToolCallParser>) -> Self {
68        Self {
69            inner,
70            parser,
71            queued: VecDeque::new(),
72            next_index: 0,
73            terminated: false,
74        }
75    }
76
77    /// Translate one parser output batch into the queued event sequence.
78    fn enqueue_parser_output(&mut self, mut out: ParserOutput) {
79        if !out.forward_text.is_empty() {
80            // Strip stray `<think>` / `</think>` boundary tags from text
81            // content.  Reasoning models (e.g. Qwen3) send their chain-of-
82            // thought in `reasoning_content` SSE fields but leak the closing
83            // `</think>` marker into the regular `content` field when
84            // transitioning back to output mode.  These tags carry no
85            // semantic meaning for the client and produce visible artefacts
86            // (e.g. `</think>` appearing verbatim in Zed's chat pane).
87            let text = std::mem::take(&mut out.forward_text);
88            let text = text.replace("</think>", "").replace("<think>", "");
89            if !text.is_empty() {
90                self.queued
91                    .push_back(LlmStreamEvent::TextDelta { content: text });
92            }
93        }
94        if !out.forward_reasoning.is_empty() {
95            self.queued.push_back(LlmStreamEvent::ReasoningDelta {
96                content: std::mem::take(&mut out.forward_reasoning),
97            });
98        }
99        for ToolCall {
100            id,
101            name,
102            arguments,
103        } in out.tool_calls
104        {
105            let index = self.next_index;
106            self.next_index += 1;
107            self.queued.push_back(LlmStreamEvent::ToolCallDelta {
108                index,
109                id: Some(id),
110                name: Some(name),
111                arguments: Some(arguments.to_string()),
112            });
113        }
114        for err in out.errors {
115            self.queued.push_back(LlmStreamEvent::NormalizationError {
116                kind: err.kind,
117                raw: err.raw,
118            });
119        }
120    }
121
122    /// Process one upstream event and queue the resulting downstream events.
123    fn handle_upstream(&mut self, event: LlmStreamEvent) {
124        match event {
125            LlmStreamEvent::TextDelta { content } => {
126                let out = self.parser.push_text(&content);
127                self.enqueue_parser_output(out);
128            }
129            LlmStreamEvent::ReasoningDelta { content } => {
130                let out = self.parser.push_reasoning(&content);
131                self.enqueue_parser_output(out);
132            }
133            LlmStreamEvent::ToolCallDelta {
134                index,
135                id,
136                name,
137                arguments,
138            } => {
139                if index >= self.next_index {
140                    self.next_index = index + 1;
141                }
142                self.queued.push_back(LlmStreamEvent::ToolCallDelta {
143                    index,
144                    id,
145                    name,
146                    arguments,
147                });
148            }
149            LlmStreamEvent::PromptProgress { .. } | LlmStreamEvent::NormalizationError { .. } => {
150                self.queued.push_back(event);
151            }
152            LlmStreamEvent::Done { finish_reason } => {
153                let out = self.parser.finish();
154                self.enqueue_parser_output(out);
155                // Qwen3.5 (and some other models) emit tool_calls in the
156                // stream but finish with `finish_reason: "stop"` instead of
157                // the required `"tool_calls"`.  Clients such as Zed check
158                // finish_reason to decide whether to dispatch tool results;
159                // a wrong value causes the conversation to hang.
160                let finish_reason = if finish_reason == "stop" && self.next_index > 0 {
161                    "tool_calls".to_owned()
162                } else {
163                    finish_reason
164                };
165                self.queued
166                    .push_back(LlmStreamEvent::Done { finish_reason });
167                self.terminated = true;
168            }
169        }
170    }
171}
172
173impl Stream for NormalizingStream {
174    type Item = Result<LlmStreamEvent>;
175
176    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
177        loop {
178            if let Some(ev) = self.queued.pop_front() {
179                return Poll::Ready(Some(Ok(ev)));
180            }
181            if self.terminated {
182                return Poll::Ready(None);
183            }
184            match self.inner.as_mut().poll_next(cx) {
185                Poll::Pending => return Poll::Pending,
186                Poll::Ready(Some(Ok(event))) => {
187                    self.handle_upstream(event);
188                    // Loop to drain `queued` (or poll inner again if empty).
189                }
190                Poll::Ready(Some(Err(e))) => {
191                    self.terminated = true;
192                    return Poll::Ready(Some(Err(e)));
193                }
194                Poll::Ready(None) => {
195                    // Upstream ended without a `Done`.  Flush any held-back
196                    // parser state so no bytes are lost, then end.
197                    let out = self.parser.finish();
198                    self.enqueue_parser_output(out);
199                    self.terminated = true;
200                    if let Some(ev) = self.queued.pop_front() {
201                        return Poll::Ready(Some(Ok(ev)));
202                    }
203                    return Poll::Ready(None);
204                }
205            }
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::normalize::{registry::get_parser, tags};
    use std::task::Poll;

    /// Minimal hand-rolled stream that yields a fixed sequence of items.
    struct VecStream {
        items: VecDeque<Result<LlmStreamEvent>>,
    }

    impl VecStream {
        fn new(items: Vec<Result<LlmStreamEvent>>) -> Self {
            Self {
                items: items.into(),
            }
        }
    }

    impl Stream for VecStream {
        type Item = Result<LlmStreamEvent>;
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(self.items.pop_front())
        }
    }

    /// Poll `s` once, synchronously, with std's no-op waker.  Our test
    /// stream is always Ready, so `Pending` here indicates a wrapper bug.
    fn poll_once(s: &mut NormalizingStream) -> Poll<Option<Result<LlmStreamEvent>>> {
        let mut cx = Context::from_waker(std::task::Waker::noop());
        Pin::new(s).poll_next(&mut cx)
    }

    /// Poll `s` to completion, panicking on any upstream error or `Pending`.
    fn drain(mut s: NormalizingStream) -> Vec<LlmStreamEvent> {
        let mut out = Vec::new();
        loop {
            match poll_once(&mut s) {
                Poll::Ready(Some(Ok(ev))) => out.push(ev),
                Poll::Ready(Some(Err(e))) => panic!("unexpected error: {e}"),
                Poll::Ready(None) => return out,
                Poll::Pending => panic!("test stream returned Pending"),
            }
        }
    }

    /// Wrap `events` (all `Ok`) with either the Qwen XML or the standard parser.
    fn wrap(events: Vec<LlmStreamEvent>, qwen: bool) -> NormalizingStream {
        let inner: InnerStream = Box::pin(VecStream::new(events.into_iter().map(Ok).collect()));
        let parser = if qwen {
            get_parser(&[tags::FORMAT_QWEN_XML.to_owned()])
        } else {
            get_parser(&[])
        };
        NormalizingStream::new(inner, parser)
    }

    /// Collect the `content` of every `TextDelta` in `events`, in order.
    fn text_deltas(events: &[LlmStreamEvent]) -> Vec<&str> {
        events
            .iter()
            .filter_map(|e| match e {
                LlmStreamEvent::TextDelta { content } => Some(content.as_str()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn standard_parser_is_passthrough() {
        let events = vec![
            LlmStreamEvent::TextDelta {
                content: "hello".into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "stop".into(),
            },
        ];
        let out = drain(wrap(events.clone(), false));
        assert_eq!(out, events);
    }

    #[test]
    fn qwen_xml_in_text_is_extracted_to_tool_call_delta() {
        let events = vec![
            LlmStreamEvent::TextDelta {
                content: r#"hi <tool_call>{"name":"foo","arguments":{"x":1}}</tool_call> done"#
                    .into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "tool_calls".into(),
            },
        ];
        let out = drain(wrap(events, true));
        // Expect: TextDelta("hi  done"), ToolCallDelta, Done.
        assert_eq!(out.len(), 3);
        assert!(matches!(
            &out[0],
            LlmStreamEvent::TextDelta { content } if content == "hi  done"
        ));
        match &out[1] {
            LlmStreamEvent::ToolCallDelta {
                index,
                id,
                name,
                arguments,
            } => {
                assert_eq!(*index, 0);
                assert_eq!(id.as_deref(), Some("call_qwen_0"));
                assert_eq!(name.as_deref(), Some("foo"));
                assert_eq!(arguments.as_deref(), Some(r#"{"x":1}"#));
            }
            other => panic!("expected ToolCallDelta, got {other:?}"),
        }
        assert!(matches!(out[2], LlmStreamEvent::Done { .. }));
    }

    #[test]
    fn qwen_xml_in_reasoning_is_extracted_and_text_clean() {
        let events = vec![
            LlmStreamEvent::ReasoningDelta {
                content: r#"think <tool_call>{"name":"foo","arguments":{}}</tool_call> end"#.into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "tool_calls".into(),
            },
        ];
        let out = drain(wrap(events, true));
        assert_eq!(out.len(), 3);
        assert!(matches!(
            &out[0],
            LlmStreamEvent::ReasoningDelta { content } if content == "think  end"
        ));
        assert!(matches!(out[1], LlmStreamEvent::ToolCallDelta { .. }));
        assert!(matches!(out[2], LlmStreamEvent::Done { .. }));
    }

    #[test]
    fn synthesised_index_does_not_collide_with_upstream() {
        let events = vec![
            LlmStreamEvent::ToolCallDelta {
                index: 0,
                id: Some("native".into()),
                name: Some("nat".into()),
                arguments: Some("{}".into()),
            },
            LlmStreamEvent::TextDelta {
                content: r#"<tool_call>{"name":"foo","arguments":{}}</tool_call>"#.into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "stop".into(),
            },
        ];
        let out = drain(wrap(events, true));
        // Native delta + synthesised delta + Done = 3.
        assert_eq!(out.len(), 3);
        let LlmStreamEvent::ToolCallDelta { index: idx0, .. } = &out[0] else {
            panic!("expected ToolCallDelta, got {:?}", out[0])
        };
        let LlmStreamEvent::ToolCallDelta { index: idx1, .. } = &out[1] else {
            panic!("expected ToolCallDelta, got {:?}", out[1])
        };
        assert_eq!(*idx0, 0);
        assert_eq!(*idx1, 1);
    }

    #[test]
    fn unclosed_tag_at_done_emits_normalization_error_then_done() {
        let events = vec![
            LlmStreamEvent::TextDelta {
                content: r#"<tool_call>{"name":"foo""#.into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "stop".into(),
            },
        ];
        let out = drain(wrap(events, true));
        // Expect at least: NormalizationError, Done.
        assert!(matches!(out.last(), Some(LlmStreamEvent::Done { .. })));
        assert!(
            out.iter()
                .any(|e| matches!(e, LlmStreamEvent::NormalizationError { .. }))
        );
    }

    #[test]
    fn upstream_ends_without_done_flushes_parser() {
        // No Done at all — wrapper should still terminate cleanly and
        // surface any held-back text.
        let events = vec![LlmStreamEvent::TextDelta {
            content: "<tool".into(),
        }];
        let out = drain(wrap(events, true));
        assert_eq!(out.len(), 1);
        assert!(matches!(
            &out[0],
            LlmStreamEvent::TextDelta { content } if content == "<tool"
        ));
    }

    /// An upstream `Err` is forwarded verbatim and ends the stream: nothing
    /// after it (not even a later `Done`) is emitted.
    #[test]
    fn upstream_error_is_propagated_and_terminates() {
        let inner: InnerStream = Box::pin(VecStream::new(vec![
            Ok(LlmStreamEvent::TextDelta {
                content: "partial".into(),
            }),
            Err(anyhow::anyhow!("connection reset")),
            Ok(LlmStreamEvent::Done {
                finish_reason: "stop".into(),
            }),
        ]));
        let mut s = NormalizingStream::new(inner, get_parser(&[]));
        match poll_once(&mut s) {
            Poll::Ready(Some(Ok(LlmStreamEvent::TextDelta { content }))) => {
                assert_eq!(content, "partial");
            }
            other => panic!("expected TextDelta, got {other:?}"),
        }
        match poll_once(&mut s) {
            Poll::Ready(Some(Err(e))) => assert_eq!(e.to_string(), "connection reset"),
            other => panic!("expected upstream error, got {other:?}"),
        }
        assert!(matches!(poll_once(&mut s), Poll::Ready(None)));
        // Stays terminated on further polls.
        assert!(matches!(poll_once(&mut s), Poll::Ready(None)));
    }

    /// Qwen3.5 emits `tool_calls` in the stream but finishes with
    /// `finish_reason: "stop"` instead of `"tool_calls"`.  The normalizer
    /// must correct this so clients that gate tool dispatch on `finish_reason`
    /// (e.g. Zed) do not hang.
    #[test]
    fn finish_reason_corrected_to_tool_calls_when_tool_calls_seen() {
        let events = vec![
            LlmStreamEvent::ToolCallDelta {
                index: 0,
                id: Some("call_0".into()),
                name: Some("read_file".into()),
                arguments: Some(r#"{"path":"/tmp/x"}"#.into()),
            },
            LlmStreamEvent::Done {
                finish_reason: "stop".into(), // wrong — model bug
            },
        ];
        let out = drain(wrap(events, false));
        assert_eq!(out.len(), 2);
        match &out[1] {
            LlmStreamEvent::Done { finish_reason } => {
                assert_eq!(finish_reason, "tool_calls");
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    /// When no tool calls were emitted, `finish_reason: "stop"` must be
    /// left unchanged.
    #[test]
    fn finish_reason_stop_unchanged_when_no_tool_calls() {
        let events = vec![
            LlmStreamEvent::TextDelta {
                content: "hello".into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "stop".into(),
            },
        ];
        let out = drain(wrap(events, false));
        assert_eq!(out.len(), 2);
        match &out[1] {
            LlmStreamEvent::Done { finish_reason } => {
                assert_eq!(finish_reason, "stop");
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    /// Stray `</think>` closing tags emitted in text content by reasoning
    /// models (e.g. Qwen3) must be stripped before reaching the client.
    #[test]
    fn stray_close_think_tag_stripped_from_text() {
        let events = vec![
            LlmStreamEvent::TextDelta {
                content: "</think>\n\n".into(),
            },
            LlmStreamEvent::TextDelta {
                content: "actual answer".into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "stop".into(),
            },
        ];
        let out = drain(wrap(events, false));
        let texts = text_deltas(&out);
        assert!(
            !texts.iter().any(|t| t.contains("</think>")),
            "found </think> in output: {texts:?}"
        );
        // Only deltas that become *exactly* empty after stripping are
        // dropped, so the whitespace that followed the tag is preserved and
        // the answer text passes through unchanged.
        assert_eq!(texts.concat(), "\n\nactual answer");
    }

    /// `<think>` open tags should also be stripped from text content.
    #[test]
    fn stray_open_think_tag_stripped_from_text() {
        let events = vec![
            LlmStreamEvent::TextDelta {
                content: "<think>spurious</think>real text".into(),
            },
            LlmStreamEvent::Done {
                finish_reason: "stop".into(),
            },
        ];
        let out = drain(wrap(events, false));
        let texts = text_deltas(&out);
        assert!(
            !texts
                .iter()
                .any(|t| t.contains("<think>") || t.contains("</think>"))
        );
        assert!(texts.iter().any(|t| t.contains("real text")));
    }
}
506}