1use bitflags::bitflags;
12use serde::{Deserialize, Serialize};
13
14bitflags! {
15 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20 #[repr(transparent)]
21 pub struct ModelCapabilities: u32 {
22 const SUPPORTS_SYSTEM_ROLE = 0b0000_0001;
27
28 const REQUIRES_STRICT_TURNS = 0b0000_0010;
33
34 const SUPPORTS_TOOL_CALLS = 0b0000_0100;
39
40 const SUPPORTS_REASONING = 0b0000_1000;
45 }
46}
47
48impl Default for ModelCapabilities {
49 fn default() -> Self {
54 Self::empty()
55 }
56}
57
58impl Serialize for ModelCapabilities {
59 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60 where
61 S: serde::Serializer,
62 {
63 self.bits().serialize(serializer)
64 }
65}
66
67impl<'de> Deserialize<'de> for ModelCapabilities {
68 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
69 where
70 D: serde::Deserializer<'de>,
71 {
72 let bits = u32::deserialize(deserializer)?;
73 Ok(Self::from_bits_truncate(bits))
74 }
75}
76
77impl ModelCapabilities {
78 pub const fn supports_system_role(self) -> bool {
80 self.contains(Self::SUPPORTS_SYSTEM_ROLE)
81 }
82
83 pub const fn requires_strict_turns(self) -> bool {
85 self.contains(Self::REQUIRES_STRICT_TURNS)
86 }
87
88 pub const fn supports_tool_calls(self) -> bool {
90 self.contains(Self::SUPPORTS_TOOL_CALLS)
91 }
92
93 pub const fn supports_reasoning(self) -> bool {
95 self.contains(Self::SUPPORTS_REASONING)
96 }
97}
98
99pub fn infer_from_chat_template(
123 template: Option<&str>,
124 model_name: Option<&str>,
125) -> ModelCapabilities {
126 let mut caps = ModelCapabilities::empty();
127
128 let mut tool_detected_from_metadata = false;
133 let mut reasoning_detected_from_metadata = false;
134
135 if let Some(template) = template {
136 let forbids_system = template.contains("Only user, assistant and tool roles are supported")
139 || template.contains("got system")
140 || template.contains("Raise exception for unsupported roles");
141
142 if forbids_system {
143 } else {
145 caps |= ModelCapabilities::SUPPORTS_SYSTEM_ROLE;
146 }
147
148 let requires_alternation = template.contains("must alternate user and assistant")
151 || template.contains("conversation roles must alternate")
152 || template.contains("ns.index % 2");
153
154 if requires_alternation {
155 caps |= ModelCapabilities::REQUIRES_STRICT_TURNS;
156 }
157
158 let has_tool_patterns = template.contains("<tool_call>")
160 || template.contains("<|python_tag|>")
161 || template.contains("if tools")
162 || template.contains("tools is defined")
163 || template.contains("tool_calls")
164 || template.contains("function_call");
165
166 if has_tool_patterns {
167 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
168 tool_detected_from_metadata = true;
169 }
170
171 let has_reasoning_patterns = template.contains("<think>")
173 || template.contains("</think>")
174 || template.contains("<reasoning>")
175 || template.contains("</reasoning>")
176 || template.contains("enable_thinking")
177 || template.contains("thinking_forced_open")
178 || template.contains("reasoning_content");
179
180 if has_reasoning_patterns {
181 caps |= ModelCapabilities::SUPPORTS_REASONING;
182 reasoning_detected_from_metadata = true;
183 }
184 }
185
186 if let Some(name) = model_name {
194 let name_lower = name.to_lowercase();
195
196 if !tool_detected_from_metadata {
198 let has_tool_name = name_lower.contains("hermes")
199 || name_lower.contains("functionary")
200 || name_lower.contains("firefunction")
201 || name_lower.contains("gorilla");
202
203 if has_tool_name {
204 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
205 }
206 }
207
208 if !reasoning_detected_from_metadata {
210 let has_reasoning_name = name_lower.contains("deepseek-r1")
211 || name_lower.contains("qwq")
212 || name_lower.contains("-r1-")
213 || name_lower.contains("o1");
214
215 if has_reasoning_name {
216 caps |= ModelCapabilities::SUPPORTS_REASONING;
217 }
218 }
219 }
220
221 caps
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs inference with only a chat template.
    fn infer_template(template: &str) -> ModelCapabilities {
        infer_from_chat_template(Some(template), None)
    }

    /// Runs inference with only a model name.
    fn infer_name(model_name: &str) -> ModelCapabilities {
        infer_from_chat_template(None, Some(model_name))
    }

    #[test]
    fn test_default_capabilities() {
        let defaults = ModelCapabilities::default();
        assert!(defaults.is_empty());
        assert!(!defaults.supports_system_role());
        assert!(!defaults.requires_strict_turns());
        assert!(!defaults.supports_tool_calls());
        assert!(!defaults.supports_reasoning());
    }

    #[test]
    fn test_infer_openai_style() {
        // A permissive template: no role restrictions, no alternation check.
        let inferred = infer_template(
            r"
            {% for message in messages %}
            {{ message.role }}: {{ message.content }}
            {% endfor %}
            ",
        );
        assert!(inferred.supports_system_role());
        assert!(!inferred.requires_strict_turns());
    }

    #[test]
    fn test_infer_mistral_style() {
        // Rejects the system role and enforces user/assistant alternation.
        let inferred = infer_template(
            r"
            {% if message.role == 'system' %}
            {{ raise_exception('Only user, assistant and tool roles are supported, got system.') }}
            {% endif %}
            {% if (message['role'] == 'user') != (ns.index % 2 == 0) %}
            {{ raise_exception('conversation roles must alternate user and assistant') }}
            {% endif %}
            ",
        );
        assert!(!inferred.supports_system_role());
        assert!(inferred.requires_strict_turns());
    }

    #[test]
    fn test_infer_missing_template() {
        let inferred = infer_from_chat_template(None, None);
        assert!(inferred.is_empty());
        assert!(!inferred.supports_system_role());
    }

    #[test]
    fn test_tool_calling_from_template() {
        let inferred = infer_template(
            r"
            {% if tools %}
            <tool_call>{{ message.tool_calls }}</tool_call>
            {% endif %}
            ",
        );
        assert!(inferred.supports_tool_calls());
    }

    #[test]
    fn test_reasoning_from_template() {
        let inferred = infer_template(
            r"
            {% if enable_thinking %}
            <think>{{ message.thinking }}</think>
            {% endif %}
            ",
        );
        assert!(inferred.supports_reasoning());
    }

    #[test]
    fn test_tool_calling_name_fallback() {
        assert!(infer_name("hermes-2-pro-7b").supports_tool_calls());
    }

    #[test]
    fn test_reasoning_name_fallback() {
        assert!(infer_name("deepseek-r1-lite").supports_reasoning());
    }

    #[test]
    fn test_metadata_plus_name_fallback() {
        // The template is silent about tools, so the name decides.
        let inferred = infer_from_chat_template(
            Some("simple template with no tool markers"),
            Some("hermes-model"),
        );
        assert!(inferred.supports_tool_calls());
    }

    #[test]
    fn test_metadata_detected_skips_name_fallback() {
        // The template already shows tool support; the name is irrelevant.
        let inferred = infer_from_chat_template(
            Some("<tool_call>detected</tool_call>"),
            Some("not-a-tool-model"),
        );
        assert!(inferred.supports_tool_calls());
    }

    #[test]
    fn test_combined_detections() {
        let inferred = infer_template(
            r"
            {% if tools %}<tool_call>{{ tool }}</tool_call>{% endif %}
            <think>{{ reasoning }}</think>
            ",
        );
        assert!(inferred.supports_tool_calls());
        assert!(inferred.supports_reasoning());
    }
}
339
340#[derive(Debug, Clone, PartialEq, Eq)]
346pub struct ChatMessage {
347 pub role: String,
348 pub content: Option<String>,
349 pub tool_calls: Option<serde_json::Value>,
350}
351
352fn merge_consecutive_system_messages(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
367 if messages.is_empty() {
368 return messages;
369 }
370
371 let mut result: Vec<ChatMessage> = Vec::with_capacity(messages.len());
372
373 for msg in messages {
374 if let Some(last) = result.last_mut() {
375 if last.role == "system" && msg.role == "system" {
376 let last_content = last.content.take().unwrap_or_default();
378 let new_content = msg.content.unwrap_or_default();
379
380 last.content = Some(if last_content.is_empty() {
381 new_content
382 } else if new_content.is_empty() {
383 last_content
384 } else {
385 format!("{last_content}\n\n{new_content}")
386 });
387
388 continue; }
390 }
391 result.push(msg);
392 }
393
394 result
395}
396
397pub fn transform_messages_for_capabilities(
421 mut messages: Vec<ChatMessage>,
422 capabilities: ModelCapabilities,
423) -> Vec<ChatMessage> {
424 messages = merge_consecutive_system_messages(messages);
429
430 if capabilities.is_empty() {
432 return messages;
433 }
434
435 if !capabilities.contains(ModelCapabilities::SUPPORTS_SYSTEM_ROLE) {
437 for msg in &mut messages {
438 if msg.role == "system" {
439 msg.role = "user".to_string();
440 if let Some(content) = &mut msg.content {
441 *content = format!("[System]: {content}");
442 }
443 }
444 }
445 }
446
447 if capabilities.contains(ModelCapabilities::REQUIRES_STRICT_TURNS) {
449 let mut merged_messages = Vec::new();
450 for msg in messages {
451 if let Some(last) = merged_messages.last_mut() {
452 let last_msg: &mut ChatMessage = last;
453 let is_mergeable_role = msg.role == "user" || msg.role == "assistant";
455 if last_msg.role == msg.role
456 && is_mergeable_role
457 && last_msg.content.is_some()
458 && msg.content.is_some()
459 && last_msg.tool_calls.is_none()
460 && msg.tool_calls.is_none()
461 {
462 if let (Some(last_content), Some(msg_content)) =
464 (&mut last_msg.content, &msg.content)
465 {
466 last_content.push_str("\n\n");
467 last_content.push_str(msg_content);
468 }
469 continue; }
471 }
472 merged_messages.push(msg);
473 }
474 return merged_messages;
475 }
476
477 messages
478}
479
#[cfg(test)]
mod transform_tests {
    use super::*;

    /// Builds a message with text content and no tool calls.
    fn message(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: Some(content.to_string()),
            tool_calls: None,
        }
    }

    #[test]
    fn test_transform_unknown_passes_through_non_system() {
        let expected = vec![message("user", "Hello"), message("assistant", "Hi there")];
        let actual =
            transform_messages_for_capabilities(expected.clone(), ModelCapabilities::empty());
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_merges_consecutive_system_messages_always() {
        let input = vec![
            message("system", "You are a helpful assistant."),
            message("system", "WORKING_MEMORY:\n- task1 (ok): done"),
            message("user", "Hello"),
        ];
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());

        assert_eq!(output.len(), 2);
        assert_eq!(output[0].role, "system");
        assert_eq!(
            output[0].content.as_deref(),
            Some("You are a helpful assistant.\n\nWORKING_MEMORY:\n- task1 (ok): done")
        );
        assert_eq!(output[1].role, "user");
    }

    #[test]
    fn test_merges_three_consecutive_system_messages() {
        let input = vec![
            message("system", "First."),
            message("system", "Second."),
            message("system", "Third."),
        ];
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());

        assert_eq!(output.len(), 1);
        assert_eq!(
            output[0].content.as_deref(),
            Some("First.\n\nSecond.\n\nThird.")
        );
    }

    #[test]
    fn test_handles_empty_system_content() {
        let input = vec![message("system", ""), message("system", "Actual content")];
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].content.as_deref(), Some("Actual content"));
    }

    #[test]
    fn test_transform_system_to_user() {
        let input = vec![message("system", "You are helpful"), message("user", "Hello")];
        // No SUPPORTS_SYSTEM_ROLE: the system message becomes a user message,
        // which strict turns then merges with the following user message.
        let output =
            transform_messages_for_capabilities(input, ModelCapabilities::REQUIRES_STRICT_TURNS);

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].role, "user");
        let text = output[0].content.as_ref().unwrap();
        assert!(text.contains("[System]: You are helpful"));
        assert!(text.contains("Hello"));
    }

    #[test]
    fn test_transform_preserves_system_when_supported() {
        let input = vec![message("system", "You are helpful")];
        let output =
            transform_messages_for_capabilities(input, ModelCapabilities::SUPPORTS_SYSTEM_ROLE);

        assert_eq!(output[0].role, "system");
        assert_eq!(output[0].content, Some("You are helpful".to_string()));
    }

    #[test]
    fn test_transform_merges_consecutive_user_messages() {
        let input = vec![message("user", "First"), message("user", "Second")];
        let output =
            transform_messages_for_capabilities(input, ModelCapabilities::REQUIRES_STRICT_TURNS);

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].content, Some("First\n\nSecond".to_string()));
    }

    #[test]
    fn test_transform_does_not_merge_tool_messages() {
        let input = vec![message("tool", "Result 1"), message("tool", "Result 2")];
        let output =
            transform_messages_for_capabilities(input, ModelCapabilities::REQUIRES_STRICT_TURNS);

        // Tool results must stay separate.
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_transform_combined_system_and_merge() {
        let input = vec![
            message("system", "Be helpful"),
            message("user", "First"),
            message("user", "Second"),
        ];
        let output =
            transform_messages_for_capabilities(input, ModelCapabilities::REQUIRES_STRICT_TURNS);

        // System is demoted to user, then all three user messages collapse.
        assert_eq!(output.len(), 1);
        assert_eq!(output[0].role, "user");
        let text = output[0].content.as_ref().unwrap();
        assert!(text.contains("[System]: Be helpful"));
        assert!(text.contains("First"));
        assert!(text.contains("Second"));
    }
}