1use bitflags::bitflags;
12use serde::{Deserialize, Serialize};
13
14bitflags! {
15 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20 #[repr(transparent)]
21 pub struct ModelCapabilities: u32 {
22 const SUPPORTS_SYSTEM_ROLE = 0b0000_0001;
27
28 const REQUIRES_STRICT_TURNS = 0b0000_0010;
33
34 const SUPPORTS_TOOL_CALLS = 0b0000_0100;
39
40 const SUPPORTS_REASONING = 0b0000_1000;
45 }
46}
47
48impl Default for ModelCapabilities {
49 fn default() -> Self {
54 Self::empty()
55 }
56}
57
58impl Serialize for ModelCapabilities {
59 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60 where
61 S: serde::Serializer,
62 {
63 self.bits().serialize(serializer)
64 }
65}
66
67impl<'de> Deserialize<'de> for ModelCapabilities {
68 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
69 where
70 D: serde::Deserializer<'de>,
71 {
72 let bits = u32::deserialize(deserializer)?;
73 Ok(Self::from_bits_truncate(bits))
74 }
75}
76
77impl ModelCapabilities {
78 pub const fn supports_system_role(self) -> bool {
80 self.contains(Self::SUPPORTS_SYSTEM_ROLE)
81 }
82
83 pub const fn requires_strict_turns(self) -> bool {
85 self.contains(Self::REQUIRES_STRICT_TURNS)
86 }
87
88 pub const fn supports_tool_calls(self) -> bool {
90 self.contains(Self::SUPPORTS_TOOL_CALLS)
91 }
92
93 pub const fn supports_reasoning(self) -> bool {
95 self.contains(Self::SUPPORTS_REASONING)
96 }
97}
98
99pub fn infer_from_chat_template(
123 template: Option<&str>,
124 model_name: Option<&str>,
125) -> ModelCapabilities {
126 let mut caps = ModelCapabilities::empty();
127
128 let mut tool_detected_from_metadata = false;
133 let mut reasoning_detected_from_metadata = false;
134
135 if let Some(template) = template {
136 let forbids_system = template.contains("Only user, assistant and tool roles are supported")
139 || template.contains("got system")
140 || template.contains("Raise exception for unsupported roles");
141
142 if forbids_system {
143 } else {
145 caps |= ModelCapabilities::SUPPORTS_SYSTEM_ROLE;
146 }
147
148 let requires_alternation = template.contains("must alternate user and assistant")
151 || template.contains("conversation roles must alternate")
152 || template.contains("ns.index % 2");
153
154 if requires_alternation {
155 caps |= ModelCapabilities::REQUIRES_STRICT_TURNS;
156 }
157
158 let has_tool_patterns = template.contains("<tool_call>")
160 || template.contains("<|python_tag|>")
161 || template.contains("if tools")
162 || template.contains("tools is defined")
163 || template.contains("tool_calls")
164 || template.contains("function_call");
165
166 if has_tool_patterns {
167 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
168 tool_detected_from_metadata = true;
169 }
170
171 let has_reasoning_patterns = template.contains("<think>")
173 || template.contains("</think>")
174 || template.contains("<reasoning>")
175 || template.contains("</reasoning>")
176 || template.contains("enable_thinking")
177 || template.contains("thinking_forced_open")
178 || template.contains("reasoning_content");
179
180 if has_reasoning_patterns {
181 caps |= ModelCapabilities::SUPPORTS_REASONING;
182 reasoning_detected_from_metadata = true;
183 }
184 }
185
186 if let Some(name) = model_name {
194 let name_lower = name.to_lowercase();
195
196 if !tool_detected_from_metadata {
198 let has_tool_name = name_lower.contains("hermes")
199 || name_lower.contains("functionary")
200 || name_lower.contains("firefunction")
201 || name_lower.contains("gorilla");
202
203 if has_tool_name {
204 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
205 }
206 }
207
208 if !reasoning_detected_from_metadata {
210 let has_reasoning_name = name_lower.contains("deepseek-r1")
211 || name_lower.contains("qwq")
212 || name_lower.contains("-r1-")
213 || name_lower.contains("o1");
214
215 if has_reasoning_name {
216 caps |= ModelCapabilities::SUPPORTS_REASONING;
217 }
218 }
219 }
220
221 caps
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Infers capabilities from a template alone (no model-name fallback).
    fn from_template(template: &str) -> ModelCapabilities {
        infer_from_chat_template(Some(template), None)
    }

    /// Infers capabilities from a model name alone (no template).
    fn from_name(name: &str) -> ModelCapabilities {
        infer_from_chat_template(None, Some(name))
    }

    #[test]
    fn test_default_capabilities() {
        let defaults = ModelCapabilities::default();
        assert!(defaults.is_empty());
        let reported = [
            defaults.supports_system_role(),
            defaults.requires_strict_turns(),
            defaults.supports_tool_calls(),
            defaults.supports_reasoning(),
        ];
        assert!(reported.iter().all(|&flag| !flag));
    }

    #[test]
    fn test_infer_openai_style() {
        let inferred = from_template(
            r"
            {% for message in messages %}
            {{ message.role }}: {{ message.content }}
            {% endfor %}
            ",
        );
        assert!(inferred.supports_system_role());
        assert!(!inferred.requires_strict_turns());
    }

    #[test]
    fn test_infer_mistral_style() {
        let inferred = from_template(
            r"
            {% if message.role == 'system' %}
            {{ raise_exception('Only user, assistant and tool roles are supported, got system.') }}
            {% endif %}
            {% if (message['role'] == 'user') != (ns.index % 2 == 0) %}
            {{ raise_exception('conversation roles must alternate user and assistant') }}
            {% endif %}
            ",
        );
        assert!(!inferred.supports_system_role());
        assert!(inferred.requires_strict_turns());
    }

    #[test]
    fn test_infer_missing_template() {
        let inferred = infer_from_chat_template(None, None);
        assert!(inferred.is_empty());
        assert!(!inferred.supports_system_role());
    }

    #[test]
    fn test_tool_calling_from_template() {
        let inferred = from_template(
            r"
            {% if tools %}
            <tool_call>{{ message.tool_calls }}</tool_call>
            {% endif %}
            ",
        );
        assert!(inferred.supports_tool_calls());
    }

    #[test]
    fn test_reasoning_from_template() {
        let inferred = from_template(
            r"
            {% if enable_thinking %}
            <think>{{ message.thinking }}</think>
            {% endif %}
            ",
        );
        assert!(inferred.supports_reasoning());
    }

    #[test]
    fn test_tool_calling_name_fallback() {
        assert!(from_name("hermes-2-pro-7b").supports_tool_calls());
    }

    #[test]
    fn test_reasoning_name_fallback() {
        assert!(from_name("deepseek-r1-lite").supports_reasoning());
    }

    #[test]
    fn test_metadata_plus_name_fallback() {
        // The template has no tool markers, so only the name heuristic can set the flag.
        let inferred = infer_from_chat_template(
            Some("simple template with no tool markers"),
            Some("hermes-model"),
        );
        assert!(inferred.supports_tool_calls());
    }

    #[test]
    fn test_metadata_detected_skips_name_fallback() {
        let inferred = infer_from_chat_template(
            Some("<tool_call>detected</tool_call>"),
            Some("not-a-tool-model"),
        );
        assert!(inferred.supports_tool_calls());
    }

    #[test]
    fn test_combined_detections() {
        let inferred = from_template(
            r"
            {% if tools %}<tool_call>{{ tool }}</tool_call>{% endif %}
            <think>{{ reasoning }}</think>
            ",
        );
        assert!(inferred.supports_tool_calls());
        assert!(inferred.supports_reasoning());
    }
}
339
/// A single chat message as processed by `transform_messages_for_capabilities`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Message role. The transforms in this module act on `"system"`, `"user"` and
    /// `"assistant"`; any other role (e.g. `"tool"`) is passed through without merging.
    pub role: String,
    /// Text content, if the message has any.
    pub content: Option<String>,
    /// Raw tool-call payload. NOTE(review): callers appear to pass a JSON array of call
    /// objects, but the shape is not validated here.
    pub tool_calls: Option<serde_json::Value>,
}
351
352fn merge_consecutive_system_messages(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
367 if messages.is_empty() {
368 return messages;
369 }
370
371 let mut result: Vec<ChatMessage> = Vec::with_capacity(messages.len());
372
373 for msg in messages {
374 if let Some(last) = result.last_mut()
375 && last.role == "system"
376 && msg.role == "system"
377 {
378 let last_content = last.content.take().unwrap_or_default();
380 let new_content = msg.content.unwrap_or_default();
381
382 last.content = Some(if last_content.is_empty() {
383 new_content
384 } else if new_content.is_empty() {
385 last_content
386 } else {
387 format!("{last_content}\n\n{new_content}")
388 });
389
390 continue; }
392 result.push(msg);
393 }
394
395 result
396}
397
398pub fn transform_messages_for_capabilities(
422 mut messages: Vec<ChatMessage>,
423 capabilities: ModelCapabilities,
424) -> Vec<ChatMessage> {
425 messages = merge_consecutive_system_messages(messages);
430
431 if capabilities.is_empty() {
433 return messages;
434 }
435
436 if !capabilities.contains(ModelCapabilities::SUPPORTS_SYSTEM_ROLE) {
438 for msg in &mut messages {
439 if msg.role == "system" {
440 msg.role = "user".to_string();
441 if let Some(content) = &mut msg.content {
442 *content = format!("[System]: {content}");
443 }
444 }
445 }
446 }
447
448 if capabilities.contains(ModelCapabilities::REQUIRES_STRICT_TURNS) {
450 let mut merged_messages = Vec::new();
451 for msg in messages {
452 if let Some(last) = merged_messages.last_mut() {
453 let last_msg: &mut ChatMessage = last;
454 let is_mergeable_role = msg.role == "user" || msg.role == "assistant";
456
457 if last_msg.role == msg.role && is_mergeable_role {
459 match (&mut last_msg.content, &msg.content) {
461 (Some(last_content), Some(msg_content)) => {
462 last_content.push_str("\n\n");
464 last_content.push_str(msg_content);
465 }
466 (None, Some(msg_content)) => {
467 last_msg.content = Some(msg_content.clone());
469 }
470 _ => {}
473 }
474
475 match (&mut last_msg.tool_calls, &msg.tool_calls) {
477 (Some(last_calls), Some(msg_calls)) => {
478 if let (Some(last_arr), Some(msg_arr)) =
480 (last_calls.as_array_mut(), msg_calls.as_array())
481 {
482 last_arr.extend_from_slice(msg_arr);
483 }
484 }
485 (None, Some(msg_calls)) => {
486 last_msg.tool_calls = Some(msg_calls.clone());
488 }
489 _ => {}
492 }
493
494 continue; }
496 }
497 merged_messages.push(msg);
498 }
499 return merged_messages;
500 }
501
502 messages
503}
504
#[cfg(test)]
mod transform_tests {
    use super::*;

    /// Builds a message from a role, optional text and optional tool calls.
    fn message(
        role: &str,
        content: Option<&str>,
        tool_calls: Option<serde_json::Value>,
    ) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.map(String::from),
            tool_calls,
        }
    }

    /// Builds a plain-text message without tool calls.
    fn text(role: &str, content: &str) -> ChatMessage {
        message(role, Some(content), None)
    }

    /// Builds a one-element tool-call array with a single `function` call.
    fn tool_call(id: &str, name: &str, arguments: &str) -> serde_json::Value {
        serde_json::json!([
            {
                "id": id,
                "type": "function",
                "function": {
                    "name": name,
                    "arguments": arguments
                }
            }
        ])
    }

    /// Runs the transform with the strict-turns capability only.
    fn strict(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
        transform_messages_for_capabilities(messages, ModelCapabilities::REQUIRES_STRICT_TURNS)
    }

    /// Runs the transform with no known capabilities.
    fn unknown(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
        transform_messages_for_capabilities(messages, ModelCapabilities::empty())
    }

    #[test]
    fn test_transform_unknown_passes_through_non_system() {
        let conversation = vec![text("user", "Hello"), text("assistant", "Hi there")];
        let expected = conversation.clone();
        assert_eq!(unknown(conversation), expected);
    }

    #[test]
    fn test_merges_consecutive_system_messages_always() {
        let out = unknown(vec![
            text("system", "You are a helpful assistant."),
            text("system", "WORKING_MEMORY:\n- task1 (ok): done"),
            text("user", "Hello"),
        ]);

        assert_eq!(out.len(), 2);
        assert_eq!(out[0].role, "system");
        assert_eq!(
            out[0].content.as_deref(),
            Some("You are a helpful assistant.\n\nWORKING_MEMORY:\n- task1 (ok): done")
        );
        assert_eq!(out[1].role, "user");
    }

    #[test]
    fn test_merges_three_consecutive_system_messages() {
        let out = unknown(vec![
            text("system", "First."),
            text("system", "Second."),
            text("system", "Third."),
        ]);

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].content.as_deref(), Some("First.\n\nSecond.\n\nThird."));
    }

    #[test]
    fn test_handles_empty_system_content() {
        let out = unknown(vec![text("system", ""), text("system", "Actual content")]);

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].content.as_deref(), Some("Actual content"));
    }

    #[test]
    fn test_transform_system_to_user() {
        let out = strict(vec![text("system", "You are helpful"), text("user", "Hello")]);

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].role, "user");
        let merged = out[0].content.as_ref().unwrap();
        assert!(merged.contains("[System]: You are helpful"));
        assert!(merged.contains("Hello"));
    }

    #[test]
    fn test_transform_preserves_system_when_supported() {
        let out = transform_messages_for_capabilities(
            vec![text("system", "You are helpful")],
            ModelCapabilities::SUPPORTS_SYSTEM_ROLE,
        );
        assert_eq!(out[0].role, "system");
        assert_eq!(out[0].content, Some("You are helpful".to_string()));
    }

    #[test]
    fn test_transform_merges_consecutive_user_messages() {
        let out = strict(vec![text("user", "First"), text("user", "Second")]);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].content, Some("First\n\nSecond".to_string()));
    }

    #[test]
    fn test_transform_does_not_merge_tool_messages() {
        let out = strict(vec![text("tool", "Result 1"), text("tool", "Result 2")]);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn test_transform_combined_system_and_merge() {
        let out = strict(vec![
            text("system", "Be helpful"),
            text("user", "First"),
            text("user", "Second"),
        ]);

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].role, "user");
        let merged = out[0].content.as_ref().unwrap();
        assert!(merged.contains("[System]: Be helpful"));
        assert!(merged.contains("First"));
        assert!(merged.contains("Second"));
    }

    #[test]
    fn test_merge_consecutive_assistant_with_tool_calls() {
        let out = strict(vec![
            text("user", "What's the weather?"),
            message(
                "assistant",
                Some("Let me check..."),
                Some(tool_call("call_1", "get_weather", "{\"location\":\"Paris\"}")),
            ),
            message(
                "assistant",
                Some("And the time..."),
                Some(tool_call("call_2", "get_time", "{\"timezone\":\"UTC\"}")),
            ),
        ]);

        assert_eq!(out.len(), 2);
        assert_eq!(out[0].role, "user");
        assert_eq!(out[1].role, "assistant");
        assert_eq!(
            out[1].content,
            Some("Let me check...\n\nAnd the time...".to_string())
        );

        let calls = out[1].tool_calls.as_ref().unwrap().as_array().unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0]["id"], "call_1");
        assert_eq!(calls[1]["id"], "call_2");
    }

    #[test]
    fn test_merge_assistant_messages_only_first_has_content() {
        let out = strict(vec![
            message("assistant", Some("Let me check..."), None),
            message(
                "assistant",
                None,
                Some(tool_call("call_1", "get_weather", "{}")),
            ),
        ]);

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].content, Some("Let me check...".to_string()));
        assert!(out[0].tool_calls.is_some());
    }

    #[test]
    fn test_merge_assistant_messages_only_second_has_content() {
        let out = strict(vec![
            message(
                "assistant",
                None,
                Some(tool_call("call_1", "get_weather", "{}")),
            ),
            message("assistant", Some("Result received"), None),
        ]);

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].content, Some("Result received".to_string()));
        assert!(out[0].tool_calls.is_some());
    }

    #[test]
    fn test_merge_assistant_messages_neither_has_content() {
        let out = strict(vec![
            message("assistant", None, Some(tool_call("call_1", "tool1", "{}"))),
            message("assistant", None, Some(tool_call("call_2", "tool2", "{}"))),
        ]);

        assert_eq!(out.len(), 1);
        assert!(out[0].content.is_none());

        let calls = out[0].tool_calls.as_ref().unwrap().as_array().unwrap();
        assert_eq!(calls.len(), 2);
    }

    #[test]
    fn test_no_merge_without_strict_turns_capability() {
        let out = unknown(vec![text("assistant", "First"), text("assistant", "Second")]);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn test_merge_preserves_different_role_boundaries() {
        let out = strict(vec![
            text("user", "Question"),
            message(
                "assistant",
                Some("Answer"),
                Some(tool_call("call_1", "tool1", "{}")),
            ),
            text("user", "Follow-up"),
        ]);

        assert_eq!(out.len(), 3);
        let roles: Vec<&str> = out.iter().map(|m| m.role.as_str()).collect();
        assert_eq!(roles, ["user", "assistant", "user"]);
    }
}