1use bitflags::bitflags;
12use serde::{Deserialize, Serialize};
13
14bitflags! {
15 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
20 #[repr(transparent)]
21 pub struct ModelCapabilities: u32 {
22 const SUPPORTS_SYSTEM_ROLE = 0b0000_0001;
27
28 const REQUIRES_STRICT_TURNS = 0b0000_0010;
33
34 const SUPPORTS_TOOL_CALLS = 0b0000_0100;
39
40 const SUPPORTS_REASONING = 0b0000_1000;
45 }
46}
47
48impl Default for ModelCapabilities {
49 fn default() -> Self {
54 Self::empty()
55 }
56}
57
58impl Serialize for ModelCapabilities {
59 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60 where
61 S: serde::Serializer,
62 {
63 self.bits().serialize(serializer)
64 }
65}
66
67impl<'de> Deserialize<'de> for ModelCapabilities {
68 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
69 where
70 D: serde::Deserializer<'de>,
71 {
72 let bits = u32::deserialize(deserializer)?;
73 Ok(Self::from_bits_truncate(bits))
74 }
75}
76
77impl ModelCapabilities {
78 pub const fn supports_system_role(self) -> bool {
80 self.contains(Self::SUPPORTS_SYSTEM_ROLE)
81 }
82
83 pub const fn requires_strict_turns(self) -> bool {
85 self.contains(Self::REQUIRES_STRICT_TURNS)
86 }
87
88 pub const fn supports_tool_calls(self) -> bool {
90 self.contains(Self::SUPPORTS_TOOL_CALLS)
91 }
92
93 pub const fn supports_reasoning(self) -> bool {
95 self.contains(Self::SUPPORTS_REASONING)
96 }
97}
98
99pub fn infer_from_chat_template(
123 template: Option<&str>,
124 model_name: Option<&str>,
125) -> ModelCapabilities {
126 let mut caps = ModelCapabilities::empty();
127
128 let mut tool_detected_from_metadata = false;
133 let mut reasoning_detected_from_metadata = false;
134
135 if let Some(template) = template {
136 let forbids_system = template.contains("Only user, assistant and tool roles are supported")
139 || template.contains("got system")
140 || template.contains("Raise exception for unsupported roles");
141
142 if forbids_system {
143 } else {
145 caps |= ModelCapabilities::SUPPORTS_SYSTEM_ROLE;
146 }
147
148 let requires_alternation = template.contains("must alternate user and assistant")
151 || template.contains("conversation roles must alternate")
152 || template.contains("ns.index % 2");
153
154 if requires_alternation {
155 caps |= ModelCapabilities::REQUIRES_STRICT_TURNS;
156 }
157
158 let has_tool_patterns = template.contains("<tool_call>")
160 || template.contains("<|python_tag|>")
161 || template.contains("if tools")
162 || template.contains("tools is defined")
163 || template.contains("tool_calls")
164 || template.contains("function_call");
165
166 if has_tool_patterns {
167 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
168 tool_detected_from_metadata = true;
169 }
170
171 let has_reasoning_patterns = template.contains("<think>")
173 || template.contains("</think>")
174 || template.contains("<reasoning>")
175 || template.contains("</reasoning>")
176 || template.contains("enable_thinking")
177 || template.contains("thinking_forced_open")
178 || template.contains("reasoning_content");
179
180 if has_reasoning_patterns {
181 caps |= ModelCapabilities::SUPPORTS_REASONING;
182 reasoning_detected_from_metadata = true;
183 }
184 }
185
186 if let Some(name) = model_name {
194 let name_lower = name.to_lowercase();
195
196 if !tool_detected_from_metadata {
198 let has_tool_name = name_lower.contains("hermes")
199 || name_lower.contains("functionary")
200 || name_lower.contains("firefunction")
201 || name_lower.contains("gorilla");
202
203 if has_tool_name {
204 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
205 }
206 }
207
208 if !reasoning_detected_from_metadata {
210 let has_reasoning_name = name_lower.contains("deepseek-r1")
211 || name_lower.contains("qwq")
212 || name_lower.contains("-r1-")
213 || name_lower.contains("o1");
214
215 if has_reasoning_name {
216 caps |= ModelCapabilities::SUPPORTS_REASONING;
217 }
218 }
219 }
220
221 caps
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Infers capabilities from a template alone (no model name).
    fn from_template(template: &str) -> ModelCapabilities {
        infer_from_chat_template(Some(template), None)
    }

    /// Infers capabilities from a model name alone (no template).
    fn from_name(name: &str) -> ModelCapabilities {
        infer_from_chat_template(None, Some(name))
    }

    #[test]
    fn test_default_capabilities() {
        let defaults = ModelCapabilities::default();
        assert!(defaults.is_empty());
        assert!(!defaults.supports_system_role());
        assert!(!defaults.requires_strict_turns());
        assert!(!defaults.supports_tool_calls());
        assert!(!defaults.supports_reasoning());
    }

    #[test]
    fn test_infer_openai_style() {
        let caps = from_template(
            "{% for message in messages %}\n{{ message.role }}: {{ message.content }}\n{% endfor %}\n",
        );
        assert!(caps.supports_system_role());
        assert!(!caps.requires_strict_turns());
    }

    #[test]
    fn test_infer_mistral_style() {
        let caps = from_template(concat!(
            "{% if message.role == 'system' %}\n",
            "{{ raise_exception('Only user, assistant and tool roles are supported, got system.') }}\n",
            "{% endif %}\n",
            "{% if (message['role'] == 'user') != (ns.index % 2 == 0) %}\n",
            "{{ raise_exception('conversation roles must alternate user and assistant') }}\n",
            "{% endif %}\n",
        ));
        assert!(!caps.supports_system_role());
        assert!(caps.requires_strict_turns());
    }

    #[test]
    fn test_infer_missing_template() {
        let caps = infer_from_chat_template(None, None);
        assert!(caps.is_empty());
        assert!(!caps.supports_system_role());
    }

    #[test]
    fn test_tool_calling_from_template() {
        let caps = from_template(
            "{% if tools %}\n<tool_call>{{ message.tool_calls }}</tool_call>\n{% endif %}\n",
        );
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_reasoning_from_template() {
        let caps = from_template(
            "{% if enable_thinking %}\n<think>{{ message.thinking }}</think>\n{% endif %}\n",
        );
        assert!(caps.supports_reasoning());
    }

    #[test]
    fn test_tool_calling_name_fallback() {
        assert!(from_name("hermes-2-pro-7b").supports_tool_calls());
    }

    #[test]
    fn test_reasoning_name_fallback() {
        assert!(from_name("deepseek-r1-lite").supports_reasoning());
    }

    #[test]
    fn test_metadata_plus_name_fallback() {
        // The template has no tool markers, so the name decides.
        let caps = infer_from_chat_template(
            Some("simple template with no tool markers"),
            Some("hermes-model"),
        );
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_metadata_detected_skips_name_fallback() {
        // The template alone establishes tool support; the name is irrelevant.
        let caps = infer_from_chat_template(
            Some("<tool_call>detected</tool_call>"),
            Some("not-a-tool-model"),
        );
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_combined_detections() {
        let caps = from_template(
            "{% if tools %}<tool_call>{{ tool }}</tool_call>{% endif %}\n<think>{{ reasoning }}</think>\n",
        );
        assert!(caps.supports_tool_calls());
        assert!(caps.supports_reasoning());
    }
}
339
340#[derive(Debug, Clone, PartialEq, Eq)]
346pub struct ChatMessage {
347 pub role: String,
348 pub content: Option<String>,
349 pub tool_calls: Option<serde_json::Value>,
350}
351
352fn merge_consecutive_system_messages(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
367 if messages.is_empty() {
368 return messages;
369 }
370
371 let mut result: Vec<ChatMessage> = Vec::with_capacity(messages.len());
372
373 for msg in messages {
374 if let Some(last) = result.last_mut() {
375 if last.role == "system" && msg.role == "system" {
376 let last_content = last.content.take().unwrap_or_default();
378 let new_content = msg.content.unwrap_or_default();
379
380 last.content = Some(if last_content.is_empty() {
381 new_content
382 } else if new_content.is_empty() {
383 last_content
384 } else {
385 format!("{last_content}\n\n{new_content}")
386 });
387
388 continue; }
390 }
391 result.push(msg);
392 }
393
394 result
395}
396
397pub fn transform_messages_for_capabilities(
421 mut messages: Vec<ChatMessage>,
422 capabilities: ModelCapabilities,
423) -> Vec<ChatMessage> {
424 messages = merge_consecutive_system_messages(messages);
429
430 if capabilities.is_empty() {
432 return messages;
433 }
434
435 if !capabilities.contains(ModelCapabilities::SUPPORTS_SYSTEM_ROLE) {
437 for msg in &mut messages {
438 if msg.role == "system" {
439 msg.role = "user".to_string();
440 if let Some(content) = &mut msg.content {
441 *content = format!("[System]: {content}");
442 }
443 }
444 }
445 }
446
447 if capabilities.contains(ModelCapabilities::REQUIRES_STRICT_TURNS) {
449 let mut merged_messages = Vec::new();
450 for msg in messages {
451 if let Some(last) = merged_messages.last_mut() {
452 let last_msg: &mut ChatMessage = last;
453 let is_mergeable_role = msg.role == "user" || msg.role == "assistant";
455
456 if last_msg.role == msg.role && is_mergeable_role {
458 match (&mut last_msg.content, &msg.content) {
460 (Some(last_content), Some(msg_content)) => {
461 last_content.push_str("\n\n");
463 last_content.push_str(msg_content);
464 }
465 (None, Some(msg_content)) => {
466 last_msg.content = Some(msg_content.clone());
468 }
469 _ => {}
472 }
473
474 match (&mut last_msg.tool_calls, &msg.tool_calls) {
476 (Some(last_calls), Some(msg_calls)) => {
477 if let (Some(last_arr), Some(msg_arr)) =
479 (last_calls.as_array_mut(), msg_calls.as_array())
480 {
481 last_arr.extend_from_slice(msg_arr);
482 }
483 }
484 (None, Some(msg_calls)) => {
485 last_msg.tool_calls = Some(msg_calls.clone());
487 }
488 _ => {}
491 }
492
493 continue; }
495 }
496 merged_messages.push(msg);
497 }
498 return merged_messages;
499 }
500
501 messages
502}
503
#[cfg(test)]
mod transform_tests {
    use super::*;

    /// Builds a message with an arbitrary role, optional content and optional tool calls.
    fn message(
        role: &str,
        content: Option<&str>,
        tool_calls: Option<serde_json::Value>,
    ) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.map(str::to_string),
            tool_calls,
        }
    }

    /// Builds a plain text message without tool calls.
    fn text(role: &str, content: &str) -> ChatMessage {
        message(role, Some(content), None)
    }

    /// Builds a one-element OpenAI-style `tool_calls` array.
    fn single_call(id: &str, name: &str, arguments: &str) -> serde_json::Value {
        serde_json::json!([{
            "id": id,
            "type": "function",
            "function": { "name": name, "arguments": arguments }
        }])
    }

    /// The "strict turns only" capability set shared by most tests below.
    fn strict() -> ModelCapabilities {
        ModelCapabilities::REQUIRES_STRICT_TURNS
    }

    #[test]
    fn test_transform_unknown_passes_through_non_system() {
        let input = vec![text("user", "Hello"), text("assistant", "Hi there")];
        let expected = input.clone();
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());
        assert_eq!(output, expected);
    }

    #[test]
    fn test_merges_consecutive_system_messages_always() {
        let input = vec![
            text("system", "You are a helpful assistant."),
            text("system", "WORKING_MEMORY:\n- task1 (ok): done"),
            text("user", "Hello"),
        ];
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());

        assert_eq!(output.len(), 2);
        assert_eq!(output[0].role, "system");
        assert_eq!(
            output[0].content.as_deref(),
            Some("You are a helpful assistant.\n\nWORKING_MEMORY:\n- task1 (ok): done")
        );
        assert_eq!(output[1].role, "user");
    }

    #[test]
    fn test_merges_three_consecutive_system_messages() {
        let input = vec![
            text("system", "First."),
            text("system", "Second."),
            text("system", "Third."),
        ];
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].content.as_deref(), Some("First.\n\nSecond.\n\nThird."));
    }

    #[test]
    fn test_handles_empty_system_content() {
        let input = vec![text("system", ""), text("system", "Actual content")];
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].content.as_deref(), Some("Actual content"));
    }

    #[test]
    fn test_transform_system_to_user() {
        let input = vec![text("system", "You are helpful"), text("user", "Hello")];
        let output = transform_messages_for_capabilities(input, strict());

        // The demoted system message merges with the following user turn.
        assert_eq!(output.len(), 1);
        assert_eq!(output[0].role, "user");
        let merged = output[0].content.as_deref().unwrap();
        assert!(merged.contains("[System]: You are helpful"));
        assert!(merged.contains("Hello"));
    }

    #[test]
    fn test_transform_preserves_system_when_supported() {
        let input = vec![text("system", "You are helpful")];
        let output =
            transform_messages_for_capabilities(input, ModelCapabilities::SUPPORTS_SYSTEM_ROLE);
        assert_eq!(output[0].role, "system");
        assert_eq!(output[0].content, Some("You are helpful".to_string()));
    }

    #[test]
    fn test_transform_merges_consecutive_user_messages() {
        let input = vec![text("user", "First"), text("user", "Second")];
        let output = transform_messages_for_capabilities(input, strict());
        assert_eq!(output.len(), 1);
        assert_eq!(output[0].content, Some("First\n\nSecond".to_string()));
    }

    #[test]
    fn test_transform_does_not_merge_tool_messages() {
        let input = vec![text("tool", "Result 1"), text("tool", "Result 2")];
        let output = transform_messages_for_capabilities(input, strict());
        // Tool results must stay separate even under strict turns.
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_transform_combined_system_and_merge() {
        let input = vec![
            text("system", "Be helpful"),
            text("user", "First"),
            text("user", "Second"),
        ];
        let output = transform_messages_for_capabilities(input, strict());

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].role, "user");
        let merged = output[0].content.as_deref().unwrap();
        assert!(merged.contains("[System]: Be helpful"));
        assert!(merged.contains("First"));
        assert!(merged.contains("Second"));
    }

    #[test]
    fn test_merge_consecutive_assistant_with_tool_calls() {
        let input = vec![
            text("user", "What's the weather?"),
            message(
                "assistant",
                Some("Let me check..."),
                Some(single_call("call_1", "get_weather", r#"{"location":"Paris"}"#)),
            ),
            message(
                "assistant",
                Some("And the time..."),
                Some(single_call("call_2", "get_time", r#"{"timezone":"UTC"}"#)),
            ),
        ];
        let output = transform_messages_for_capabilities(input, strict());

        assert_eq!(output.len(), 2);
        assert_eq!(output[0].role, "user");
        assert_eq!(output[1].role, "assistant");
        assert_eq!(
            output[1].content,
            Some("Let me check...\n\nAnd the time...".to_string())
        );

        let calls = output[1].tool_calls.as_ref().unwrap().as_array().unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0]["id"], "call_1");
        assert_eq!(calls[1]["id"], "call_2");
    }

    #[test]
    fn test_merge_assistant_messages_only_first_has_content() {
        let input = vec![
            text("assistant", "Let me check..."),
            message(
                "assistant",
                None,
                Some(single_call("call_1", "get_weather", "{}")),
            ),
        ];
        let output = transform_messages_for_capabilities(input, strict());

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].content, Some("Let me check...".to_string()));
        assert!(output[0].tool_calls.is_some());
    }

    #[test]
    fn test_merge_assistant_messages_only_second_has_content() {
        let input = vec![
            message(
                "assistant",
                None,
                Some(single_call("call_1", "get_weather", "{}")),
            ),
            text("assistant", "Result received"),
        ];
        let output = transform_messages_for_capabilities(input, strict());

        assert_eq!(output.len(), 1);
        assert_eq!(output[0].content, Some("Result received".to_string()));
        assert!(output[0].tool_calls.is_some());
    }

    #[test]
    fn test_merge_assistant_messages_neither_has_content() {
        let input = vec![
            message("assistant", None, Some(single_call("call_1", "tool1", "{}"))),
            message("assistant", None, Some(single_call("call_2", "tool2", "{}"))),
        ];
        let output = transform_messages_for_capabilities(input, strict());

        assert_eq!(output.len(), 1);
        assert!(output[0].content.is_none());
        let calls = output[0].tool_calls.as_ref().unwrap().as_array().unwrap();
        assert_eq!(calls.len(), 2);
    }

    #[test]
    fn test_no_merge_without_strict_turns_capability() {
        let input = vec![text("assistant", "First"), text("assistant", "Second")];
        let output = transform_messages_for_capabilities(input, ModelCapabilities::empty());
        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_merge_preserves_different_role_boundaries() {
        let input = vec![
            text("user", "Question"),
            message(
                "assistant",
                Some("Answer"),
                Some(single_call("call_1", "tool1", "{}")),
            ),
            text("user", "Follow-up"),
        ];
        let output = transform_messages_for_capabilities(input, strict());

        let roles: Vec<&str> = output.iter().map(|m| m.role.as_str()).collect();
        assert_eq!(output.len(), 3);
        assert_eq!(roles, ["user", "assistant", "user"]);
    }
}