gglib_core/domain/capabilities.rs

//! Model capability detection and inference.
//!
//! Capabilities describe model constraints derived from chat templates, with
//! model-name heuristics as a fallback. An empty capability set means
//! "unknown" and MUST NOT trigger any role-specific message rewriting.
//!
//! # Invariant
//!
//! Message rewriting is only permitted when model capabilities explicitly
//! forbid the current message structure. Default behavior is pass-through;
//! the one exception is that consecutive system messages are always merged.
10
11use bitflags::bitflags;
12use serde::{Deserialize, Serialize};
13
14bitflags! {
15    /// Model capabilities inferred from chat template analysis.
16    ///
17    /// These flags describe what the model's chat template can handle.
18    /// Absence means "we don't know" or "not needed", not "forbidden".
19    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20    #[repr(transparent)]
21    pub struct ModelCapabilities: u32 {
22        /// Model supports system role natively in its chat template.
23        ///
24        /// When set: system messages can be passed through unchanged.
25        /// When unset: system messages must be converted to user messages.
26        const SUPPORTS_SYSTEM_ROLE    = 0b0000_0001;
27
28        /// Model requires strict user/assistant alternation.
29        ///
30        /// When set: consecutive messages of same role must be merged.
31        /// When unset: message order can be arbitrary (OpenAI-style).
32        const REQUIRES_STRICT_TURNS   = 0b0000_0010;
33
34        /// Model supports tool/function calling.
35        ///
36        /// When set: tool_calls and tool role messages are supported.
37        /// When unset: tool functionality should not be used.
38        const SUPPORTS_TOOL_CALLS     = 0b0000_0100;
39
40        /// Model has reasoning/thinking capability.
41        ///
42        /// When set: model may produce <think> tags or reasoning_content.
43        /// When unset: model produces only standard responses.
44        const SUPPORTS_REASONING      = 0b0000_1000;
45    }
46}
47
48impl Default for ModelCapabilities {
49    /// Default capabilities represent "unknown" state.
50    ///
51    /// Models start with empty capabilities and must be explicitly inferred.
52    /// This prevents incorrect assumptions about model constraints.
53    fn default() -> Self {
54        Self::empty()
55    }
56}
57
58impl Serialize for ModelCapabilities {
59    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60    where
61        S: serde::Serializer,
62    {
63        self.bits().serialize(serializer)
64    }
65}
66
67impl<'de> Deserialize<'de> for ModelCapabilities {
68    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
69    where
70        D: serde::Deserializer<'de>,
71    {
72        let bits = u32::deserialize(deserializer)?;
73        Ok(Self::from_bits_truncate(bits))
74    }
75}
76
77impl ModelCapabilities {
78    /// Check if model supports system role.
79    pub const fn supports_system_role(self) -> bool {
80        self.contains(Self::SUPPORTS_SYSTEM_ROLE)
81    }
82
83    /// Check if model requires strict user/assistant alternation.
84    pub const fn requires_strict_turns(self) -> bool {
85        self.contains(Self::REQUIRES_STRICT_TURNS)
86    }
87
88    /// Check if model supports tool/function calls.
89    pub const fn supports_tool_calls(self) -> bool {
90        self.contains(Self::SUPPORTS_TOOL_CALLS)
91    }
92
93    /// Check if model supports reasoning phases.
94    pub const fn supports_reasoning(self) -> bool {
95        self.contains(Self::SUPPORTS_REASONING)
96    }
97}
98
99/// Infer model capabilities from chat template Jinja source and model name.
100///
101/// Uses string heuristics to detect template constraints. Returns safe
102/// defaults if template is missing or unparseable.
103///
104/// # Detection Strategy
105///
106/// Two-layer approach:
107/// - **Layer 1 (Metadata)**: Check chat template for reliable signals (preferred)
108/// - **Layer 2 (Name Heuristics)**: Use model name patterns as fallback when metadata is missing
109///
110/// # Capabilities Detected
111///
112/// - System role: Looks for explicit rejection messages in template
113/// - Strict turns: Looks for alternation enforcement logic
114/// - Tool calling: Checks for `<tool_call>`, `if tools`, `function_call` patterns (metadata);
115///   falls back to model name patterns like "hermes", "functionary" (heuristic)
116/// - Reasoning: Checks for `<think>`, `<reasoning>`, `enable_thinking` (metadata);
117///   falls back to model name patterns like "deepseek-r1", "qwq", "o1" (heuristic)
118///
119/// # Fallback Behavior
120///
121/// Missing or unparseable templates default to empty capabilities (unknown state).
122pub fn infer_from_chat_template(
123    template: Option<&str>,
124    model_name: Option<&str>,
125) -> ModelCapabilities {
126    let mut caps = ModelCapabilities::empty();
127
128    // ─────────────────────────────────────────────────────────────────────────────
129    // Layer 1: Metadata-based detection (chat template analysis)
130    // ─────────────────────────────────────────────────────────────────────────────
131
132    let mut tool_detected_from_metadata = false;
133    let mut reasoning_detected_from_metadata = false;
134
135    if let Some(template) = template {
136        // Check for system role restrictions
137        // Mistral-style templates explicitly reject system role in error messages
138        let forbids_system = template.contains("Only user, assistant and tool roles are supported")
139            || template.contains("got system")
140            || template.contains("Raise exception for unsupported roles");
141
142        if forbids_system {
143            // Absence of SUPPORTS_SYSTEM_ROLE means transformation required
144        } else {
145            caps |= ModelCapabilities::SUPPORTS_SYSTEM_ROLE;
146        }
147
148        // Check for strict alternation requirements
149        // Mistral-style templates enforce user/assistant alternation with modulo checks
150        let requires_alternation = template.contains("must alternate user and assistant")
151            || template.contains("conversation roles must alternate")
152            || template.contains("ns.index % 2");
153
154        if requires_alternation {
155            caps |= ModelCapabilities::REQUIRES_STRICT_TURNS;
156        }
157
158        // Detect tool calling support from template
159        let has_tool_patterns = template.contains("<tool_call>")
160            || template.contains("<|python_tag|>")
161            || template.contains("if tools")
162            || template.contains("tools is defined")
163            || template.contains("tool_calls")
164            || template.contains("function_call");
165
166        if has_tool_patterns {
167            caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
168            tool_detected_from_metadata = true;
169        }
170
171        // Detect reasoning/thinking support from template
172        let has_reasoning_patterns = template.contains("<think>")
173            || template.contains("</think>")
174            || template.contains("<reasoning>")
175            || template.contains("</reasoning>")
176            || template.contains("enable_thinking")
177            || template.contains("thinking_forced_open")
178            || template.contains("reasoning_content");
179
180        if has_reasoning_patterns {
181            caps |= ModelCapabilities::SUPPORTS_REASONING;
182            reasoning_detected_from_metadata = true;
183        }
184    }
185
186    // ─────────────────────────────────────────────────────────────────────────────
187    // Layer 2: Name-based heuristic fallback (when metadata is inconclusive)
188    // ─────────────────────────────────────────────────────────────────────────────
189    //
190    // Only use name patterns when chat template didn't provide clear evidence.
191    // This is less reliable but helps with models that have incomplete metadata.
192
193    if let Some(name) = model_name {
194        let name_lower = name.to_lowercase();
195
196        // Heuristic: Tool calling support based on model name
197        if !tool_detected_from_metadata {
198            let has_tool_name = name_lower.contains("hermes")
199                || name_lower.contains("functionary")
200                || name_lower.contains("firefunction")
201                || name_lower.contains("gorilla");
202
203            if has_tool_name {
204                caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
205            }
206        }
207
208        // Heuristic: Reasoning support based on model name
209        if !reasoning_detected_from_metadata {
210            let has_reasoning_name = name_lower.contains("deepseek-r1")
211                || name_lower.contains("qwq")
212                || name_lower.contains("-r1-")
213                || name_lower.contains("o1");
214
215            if has_reasoning_name {
216                caps |= ModelCapabilities::SUPPORTS_REASONING;
217            }
218        }
219    }
220
221    caps
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_capabilities() {
        let caps = ModelCapabilities::default();
        // Default is "unknown" - no capabilities set
        assert!(caps.is_empty());
        assert!(!caps.supports_system_role());
        assert!(!caps.requires_strict_turns());
        assert!(!caps.supports_tool_calls());
        assert!(!caps.supports_reasoning());
    }

    #[test]
    fn test_serde_uses_raw_bits() {
        let caps = ModelCapabilities::SUPPORTS_SYSTEM_ROLE | ModelCapabilities::SUPPORTS_TOOL_CALLS;
        let json = serde_json::to_string(&caps).unwrap();
        assert_eq!(json, "5");
        let back: ModelCapabilities = serde_json::from_str(&json).unwrap();
        assert_eq!(back, caps);
    }

    #[test]
    fn test_deserialize_drops_unknown_bits() {
        // Bits outside the defined flags are truncated rather than rejected.
        let caps: ModelCapabilities = serde_json::from_str("255").unwrap();
        assert_eq!(caps, ModelCapabilities::all());
    }

    #[test]
    fn test_infer_openai_style() {
        let template = r"
            {% for message in messages %}
                {{ message.role }}: {{ message.content }}
            {% endfor %}
        ";
        let caps = infer_from_chat_template(Some(template), None);
        assert!(caps.supports_system_role());
        assert!(!caps.requires_strict_turns());
    }

    #[test]
    fn test_infer_mistral_style() {
        let template = r"
            {% if message.role == 'system' %}
                {{ raise_exception('Only user, assistant and tool roles are supported, got system.') }}
            {% endif %}
            {% if (message['role'] == 'user') != (ns.index % 2 == 0) %}
                {{ raise_exception('conversation roles must alternate user and assistant') }}
            {% endif %}
        ";
        let caps = infer_from_chat_template(Some(template), None);
        assert!(!caps.supports_system_role());
        assert!(caps.requires_strict_turns());
    }

    #[test]
    fn test_infer_missing_template() {
        let caps = infer_from_chat_template(None, None);
        // Missing template means unknown capabilities - no assumptions made
        assert!(caps.is_empty());
        assert!(!caps.supports_system_role());
    }

    #[test]
    fn test_unrelated_name_without_template_is_unknown() {
        // A name matching no heuristic adds nothing, so capabilities stay unknown.
        let caps = infer_from_chat_template(None, Some("llama-3-8b-instruct"));
        assert!(caps.is_empty());
    }

    #[test]
    fn test_tool_calling_from_template() {
        let template = r"
            {% if tools %}
                <tool_call>{{ message.tool_calls }}</tool_call>
            {% endif %}
        ";
        let caps = infer_from_chat_template(Some(template), None);
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_reasoning_from_template() {
        let template = r"
            {% if enable_thinking %}
                <think>{{ message.thinking }}</think>
            {% endif %}
        ";
        let caps = infer_from_chat_template(Some(template), None);
        assert!(caps.supports_reasoning());
    }

    #[test]
    fn test_tool_calling_name_fallback() {
        // No template, but model name suggests tool support
        let caps = infer_from_chat_template(None, Some("hermes-2-pro-7b"));
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_reasoning_name_fallback() {
        // No template, but model name suggests reasoning support
        let caps = infer_from_chat_template(None, Some("deepseek-r1-lite"));
        assert!(caps.supports_reasoning());
    }

    #[test]
    fn test_reasoning_name_fallback_is_case_insensitive() {
        // Names are lowercased before matching.
        let caps = infer_from_chat_template(None, Some("DeepSeek-R1-Distill-Qwen-7B"));
        assert!(caps.supports_reasoning());
    }

    #[test]
    fn test_metadata_plus_name_fallback() {
        // Template present but has no tool markers - should still use name fallback
        let template = "simple template with no tool markers";
        let caps = infer_from_chat_template(Some(template), Some("hermes-model"));
        // Name fallback should kick in because metadata didn't detect tools
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_metadata_detected_skips_name_fallback() {
        // When metadata detects capability, name pattern is irrelevant
        let template = "<tool_call>detected</tool_call>";
        let caps = infer_from_chat_template(Some(template), Some("not-a-tool-model"));
        // Metadata detected it, so tool support is enabled regardless of name
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_combined_detections() {
        let template = r"
            {% if tools %}<tool_call>{{ tool }}</tool_call>{% endif %}
            <think>{{ reasoning }}</think>
        ";
        let caps = infer_from_chat_template(Some(template), None);
        assert!(caps.supports_tool_calls());
        assert!(caps.supports_reasoning());
    }
}
339
340// ─────────────────────────────────────────────────────────────────────────────
341// Message Transformation
342// ─────────────────────────────────────────────────────────────────────────────
343
344/// A chat message for transformation.
345#[derive(Debug, Clone, PartialEq, Eq)]
346pub struct ChatMessage {
347    pub role: String,
348    pub content: Option<String>,
349    pub tool_calls: Option<serde_json::Value>,
350}
351
352/// Merge consecutive system messages into a single message.
353///
354/// This is universally safe because:
355/// - No model template requires multiple system messages
356/// - Merging preserves all content with clear separation
357/// - It prevents errors in strict-turn templates (e.g., gemma3/medgemma)
358///
359/// # Arguments
360///
361/// * `messages` - The input chat messages
362///
363/// # Returns
364///
365/// Messages with consecutive system messages merged
366fn merge_consecutive_system_messages(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
367    if messages.is_empty() {
368        return messages;
369    }
370
371    let mut result: Vec<ChatMessage> = Vec::with_capacity(messages.len());
372
373    for msg in messages {
374        if let Some(last) = result.last_mut() {
375            if last.role == "system" && msg.role == "system" {
376                // Merge: append content with separator
377                let last_content = last.content.take().unwrap_or_default();
378                let new_content = msg.content.unwrap_or_default();
379
380                last.content = Some(if last_content.is_empty() {
381                    new_content
382                } else if new_content.is_empty() {
383                    last_content
384                } else {
385                    format!("{last_content}\n\n{new_content}")
386                });
387
388                continue; // Don't push, we merged into last
389            }
390        }
391        result.push(msg);
392    }
393
394    result
395}
396
397/// Transform chat messages based on model capabilities.
398///
399/// This is a pure function that applies capability-aware transformations:
400/// - Merges consecutive system messages (always, for all models)
401/// - Converts system messages to user messages when model doesn't support system role
402/// - Merges consecutive same-role messages when model requires strict alternation
403///
404/// # Invariant
405///
406/// Consecutive system messages are ALWAYS merged, regardless of capabilities.
407/// This prevents Jinja template errors in models with strict role alternation.
408///
409/// When capabilities are unknown (empty), only system message merging is applied.
410/// This prevents degrading standard models while ensuring universal compatibility.
411///
412/// # Arguments
413///
414/// * `messages` - The input chat messages to transform
415/// * `capabilities` - The model's capability flags
416///
417/// # Returns
418///
419/// Transformed messages suitable for the model's constraints
420pub fn transform_messages_for_capabilities(
421    mut messages: Vec<ChatMessage>,
422    capabilities: ModelCapabilities,
423) -> Vec<ChatMessage> {
424    // STEP 0 (ALWAYS): Merge consecutive system messages.
425    // This is safe for ALL models and prevents Jinja template errors
426    // in models with strict role alternation (e.g., gemma3/medgemma).
427    // Must run BEFORE the capabilities check to protect unknown models.
428    messages = merge_consecutive_system_messages(messages);
429
430    // Pass through if capabilities are unknown
431    if capabilities.is_empty() {
432        return messages;
433    }
434
435    // STEP 1: Transform system messages if the model doesn't support them
436    if !capabilities.contains(ModelCapabilities::SUPPORTS_SYSTEM_ROLE) {
437        for msg in &mut messages {
438            if msg.role == "system" {
439                msg.role = "user".to_string();
440                if let Some(content) = &mut msg.content {
441                    *content = format!("[System]: {content}");
442                }
443            }
444        }
445    }
446
447    // STEP 2: Merge consecutive same-role messages if strict turns are required
448    if capabilities.contains(ModelCapabilities::REQUIRES_STRICT_TURNS) {
449        let mut merged_messages = Vec::new();
450        for msg in messages {
451            if let Some(last) = merged_messages.last_mut() {
452                let last_msg: &mut ChatMessage = last;
453                // Only merge user/assistant messages to avoid tool-call ordering issues
454                let is_mergeable_role = msg.role == "user" || msg.role == "assistant";
455                if last_msg.role == msg.role
456                    && is_mergeable_role
457                    && last_msg.content.is_some()
458                    && msg.content.is_some()
459                    && last_msg.tool_calls.is_none()
460                    && msg.tool_calls.is_none()
461                {
462                    // Merge content
463                    if let (Some(last_content), Some(msg_content)) =
464                        (&mut last_msg.content, &msg.content)
465                    {
466                        last_content.push_str("\n\n");
467                        last_content.push_str(msg_content);
468                    }
469                    continue; // Skip adding this message separately
470                }
471            }
472            merged_messages.push(msg);
473        }
474        return merged_messages;
475    }
476
477    messages
478}
479
#[cfg(test)]
mod transform_tests {
    use super::*;

    /// Build a message with text content and no tool calls.
    fn text(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: Some(content.to_string()),
            tool_calls: None,
        }
    }

    #[test]
    fn test_transform_unknown_passes_through_non_system() {
        // Non-system messages pass through unchanged with empty capabilities
        let messages = vec![text("user", "Hello"), text("assistant", "Hi there")];
        let original = messages.clone();
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());
        assert_eq!(result, original);
    }

    #[test]
    fn test_unknown_keeps_single_system_message() {
        // Unknown capabilities never convert the system role
        let messages = vec![text("system", "Be brief"), text("user", "Hello")];
        let original = messages.clone();
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());
        assert_eq!(result, original);
    }

    #[test]
    fn test_merges_consecutive_system_messages_always() {
        // Even with empty capabilities, consecutive system messages should merge
        let messages = vec![
            text("system", "You are a helpful assistant."),
            text("system", "WORKING_MEMORY:\n- task1 (ok): done"),
            text("user", "Hello"),
        ];
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, "system");
        assert_eq!(
            result[0].content.as_deref(),
            Some("You are a helpful assistant.\n\nWORKING_MEMORY:\n- task1 (ok): done")
        );
        assert_eq!(result[1].role, "user");
    }

    #[test]
    fn test_merges_three_consecutive_system_messages() {
        let messages = vec![
            text("system", "First."),
            text("system", "Second."),
            text("system", "Third."),
        ];
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());

        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0].content.as_deref(),
            Some("First.\n\nSecond.\n\nThird.")
        );
    }

    #[test]
    fn test_handles_empty_system_content() {
        let messages = vec![text("system", ""), text("system", "Actual content")];
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content.as_deref(), Some("Actual content"));
    }

    #[test]
    fn test_transform_system_to_user() {
        let messages = vec![text("system", "You are helpful"), text("user", "Hello")];
        // REQUIRES_STRICT_TURNS alone lacks SUPPORTS_SYSTEM_ROLE, so the system
        // message is rewritten as a user message; strict turns then merge it with
        // the following user message.
        let caps = ModelCapabilities::REQUIRES_STRICT_TURNS;
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].role, "user");
        let content = result[0].content.as_deref().unwrap();
        assert!(content.contains("[System]: You are helpful"));
        assert!(content.contains("Hello"));
    }

    #[test]
    fn test_system_to_user_without_strict_turns_is_not_merged() {
        // Known capabilities without system support convert the role, but without
        // REQUIRES_STRICT_TURNS the resulting user messages stay separate.
        let messages = vec![text("system", "A"), text("user", "B")];
        let caps = ModelCapabilities::SUPPORTS_TOOL_CALLS;
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, "user");
        assert_eq!(result[0].content.as_deref(), Some("[System]: A"));
        assert_eq!(result[1].content.as_deref(), Some("B"));
    }

    #[test]
    fn test_transform_preserves_system_when_supported() {
        let messages = vec![text("system", "You are helpful")];
        let caps = ModelCapabilities::SUPPORTS_SYSTEM_ROLE;
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result[0].role, "system");
        assert_eq!(result[0].content, Some("You are helpful".to_string()));
    }

    #[test]
    fn test_strict_turns_keep_supported_system_message() {
        let messages = vec![text("system", "A"), text("system", "B"), text("user", "Hi")];
        let caps =
            ModelCapabilities::SUPPORTS_SYSTEM_ROLE | ModelCapabilities::REQUIRES_STRICT_TURNS;
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, "system");
        assert_eq!(result[0].content.as_deref(), Some("A\n\nB"));
        assert_eq!(result[1].role, "user");
    }

    #[test]
    fn test_transform_merges_consecutive_user_messages() {
        let messages = vec![text("user", "First"), text("user", "Second")];
        let caps = ModelCapabilities::REQUIRES_STRICT_TURNS;
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, Some("First\n\nSecond".to_string()));
    }

    #[test]
    fn test_transform_does_not_merge_tool_messages() {
        let messages = vec![text("tool", "Result 1"), text("tool", "Result 2")];
        let caps = ModelCapabilities::REQUIRES_STRICT_TURNS;
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result.len(), 2); // Should not merge tool messages
    }

    #[test]
    fn test_transform_does_not_merge_assistant_with_tool_calls() {
        let with_call = ChatMessage {
            role: "assistant".to_string(),
            content: Some("calling".to_string()),
            tool_calls: Some(serde_json::json!([{ "id": "1" }])),
        };
        let messages = vec![text("user", "Q"), with_call, text("assistant", "done")];
        let caps = ModelCapabilities::REQUIRES_STRICT_TURNS;
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result.len(), 3);
        assert!(result[1].tool_calls.is_some());
        assert_eq!(result[2].content.as_deref(), Some("done"));
    }

    #[test]
    fn test_transform_combined_system_and_merge() {
        let messages = vec![
            text("system", "Be helpful"),
            text("user", "First"),
            text("user", "Second"),
        ];
        let caps = ModelCapabilities::REQUIRES_STRICT_TURNS; // No system support + strict turns
        let result = transform_messages_for_capabilities(messages, caps);
        assert_eq!(result.len(), 1); // System→user + merge
        assert_eq!(result[0].role, "user");
        let content = result[0].content.as_deref().unwrap();
        assert!(content.contains("[System]: Be helpful"));
        assert!(content.contains("First"));
        assert!(content.contains("Second"));
    }
}