1use bitflags::bitflags;
76use serde::{Deserialize, Serialize};
77
78bitflags! {
79 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
84 #[repr(transparent)]
85 pub struct ModelCapabilities: u32 {
86 const SUPPORTS_SYSTEM_ROLE = 0b0000_0001;
91
92 const REQUIRES_STRICT_TURNS = 0b0000_0010;
97
98 const SUPPORTS_TOOL_CALLS = 0b0000_0100;
103
104 const SUPPORTS_REASONING = 0b0000_1000;
109 }
110}
111
112impl Default for ModelCapabilities {
113 fn default() -> Self {
118 Self::empty()
119 }
120}
121
122impl Serialize for ModelCapabilities {
123 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
124 where
125 S: serde::Serializer,
126 {
127 self.bits().serialize(serializer)
128 }
129}
130
131impl<'de> Deserialize<'de> for ModelCapabilities {
132 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
133 where
134 D: serde::Deserializer<'de>,
135 {
136 let bits = u32::deserialize(deserializer)?;
137 Ok(Self::from_bits_truncate(bits))
138 }
139}
140
141impl ModelCapabilities {
142 pub const fn supports_system_role(self) -> bool {
144 self.contains(Self::SUPPORTS_SYSTEM_ROLE)
145 }
146
147 pub const fn requires_strict_turns(self) -> bool {
149 self.contains(Self::REQUIRES_STRICT_TURNS)
150 }
151
152 pub const fn supports_tool_calls(self) -> bool {
154 self.contains(Self::SUPPORTS_TOOL_CALLS)
155 }
156
157 pub const fn supports_reasoning(self) -> bool {
159 self.contains(Self::SUPPORTS_REASONING)
160 }
161}
162
163pub fn infer_from_chat_template(
190 template: Option<&str>,
191 model_name: Option<&str>,
192) -> ModelCapabilities {
193 let mut caps = ModelCapabilities::empty();
194
195 let mut tool_detected_from_metadata = false;
200 let mut reasoning_detected_from_metadata = false;
201
202 if let Some(template) = template {
203 let supports_system_positive =
215 template.contains("[SYSTEM_PROMPT]") || template.contains("[AVAILABLE_TOOLS]");
216
217 let forbids_system = !supports_system_positive
218 && (template.contains("Only user, assistant and tool roles are supported")
219 || template.contains("got system")
220 || template.contains("Raise exception for unsupported roles"));
221
222 if !forbids_system {
223 caps |= ModelCapabilities::SUPPORTS_SYSTEM_ROLE;
224 }
225
226 let requires_alternation = template.contains("must alternate user and assistant")
229 || template.contains("conversation roles must alternate")
230 || template.contains("ns.index % 2");
231
232 if requires_alternation {
233 caps |= ModelCapabilities::REQUIRES_STRICT_TURNS;
234 }
235
236 let has_tool_patterns = template.contains("<tool_call>")
238 || template.contains("<|python_tag|>")
239 || template.contains("if tools")
240 || template.contains("tools is defined")
241 || template.contains("tool_calls")
242 || template.contains("function_call");
243
244 if has_tool_patterns {
245 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
246 tool_detected_from_metadata = true;
247 }
248
249 let has_reasoning_patterns = template.contains("<think>")
251 || template.contains("</think>")
252 || template.contains("<reasoning>")
253 || template.contains("</reasoning>")
254 || template.contains("enable_thinking")
255 || template.contains("thinking_forced_open")
256 || template.contains("reasoning_content");
257
258 if has_reasoning_patterns {
259 caps |= ModelCapabilities::SUPPORTS_REASONING;
260 reasoning_detected_from_metadata = true;
261 }
262 }
263
264 if let Some(name) = model_name {
272 let name_lower = name.to_lowercase();
273
274 if !tool_detected_from_metadata {
276 let has_tool_name = name_lower.contains("hermes")
277 || name_lower.contains("functionary")
278 || name_lower.contains("firefunction")
279 || name_lower.contains("gorilla");
280
281 if has_tool_name {
282 caps |= ModelCapabilities::SUPPORTS_TOOL_CALLS;
283 }
284 }
285
286 if !reasoning_detected_from_metadata {
288 let has_reasoning_name = name_lower.contains("deepseek-r1")
289 || name_lower.contains("qwq")
290 || name_lower.contains("-r1-")
291 || name_lower.contains("o1");
292
293 if has_reasoning_name {
294 caps |= ModelCapabilities::SUPPORTS_REASONING;
295 }
296 }
297 }
298
299 caps
300}
301
302#[must_use]
350pub fn capabilities_from_architecture(arch: Option<&str>) -> ModelCapabilities {
351 let Some(arch) = arch else {
352 return ModelCapabilities::empty();
353 };
354
355 match arch {
356 "mistral" => ModelCapabilities::REQUIRES_STRICT_TURNS,
360
361 "mistral3" => {
367 ModelCapabilities::REQUIRES_STRICT_TURNS | ModelCapabilities::SUPPORTS_SYSTEM_ROLE
368 }
369
370 _ => ModelCapabilities::empty(),
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Template-only inference (no model name).
    fn from_template(template: &str) -> ModelCapabilities {
        infer_from_chat_template(Some(template), None)
    }

    #[test]
    fn test_default_capabilities() {
        let caps = ModelCapabilities::default();
        assert!(caps.is_empty());
        let flags = [
            caps.supports_system_role(),
            caps.requires_strict_turns(),
            caps.supports_tool_calls(),
            caps.supports_reasoning(),
        ];
        assert_eq!(flags, [false; 4], "default must report no capabilities");
    }

    #[test]
    fn test_infer_openai_style() {
        let caps = from_template(
            r"
            {% for message in messages %}
            {{ message.role }}: {{ message.content }}
            {% endfor %}
            ",
        );
        assert!(caps.supports_system_role());
        assert!(!caps.requires_strict_turns());
    }

    #[test]
    fn test_infer_mistral_style() {
        let caps = from_template(
            r"
            {% if message.role == 'system' %}
            {{ raise_exception('Only user, assistant and tool roles are supported, got system.') }}
            {% endif %}
            {% if (message['role'] == 'user') != (ns.index % 2 == 0) %}
            {{ raise_exception('conversation roles must alternate user and assistant') }}
            {% endif %}
            ",
        );
        assert!(!caps.supports_system_role());
        assert!(caps.requires_strict_turns());
    }

    #[test]
    fn test_infer_missing_template() {
        let caps = infer_from_chat_template(None, None);
        assert!(caps.is_empty());
        assert!(!caps.supports_system_role());
    }

    #[test]
    fn test_tool_calling_from_template() {
        let caps = from_template(
            r"
            {% if tools %}
            <tool_call>{{ message.tool_calls }}</tool_call>
            {% endif %}
            ",
        );
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_reasoning_from_template() {
        let caps = from_template(
            r"
            {% if enable_thinking %}
            <think>{{ message.thinking }}</think>
            {% endif %}
            ",
        );
        assert!(caps.supports_reasoning());
    }

    #[test]
    fn test_tool_calling_name_fallback() {
        assert!(infer_from_chat_template(None, Some("hermes-2-pro-7b")).supports_tool_calls());
    }

    #[test]
    fn test_reasoning_name_fallback() {
        assert!(infer_from_chat_template(None, Some("deepseek-r1-lite")).supports_reasoning());
    }

    #[test]
    fn test_metadata_plus_name_fallback() {
        let caps = infer_from_chat_template(
            Some("simple template with no tool markers"),
            Some("hermes-model"),
        );
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_metadata_detected_skips_name_fallback() {
        let caps = infer_from_chat_template(
            Some("<tool_call>detected</tool_call>"),
            Some("not-a-tool-model"),
        );
        assert!(caps.supports_tool_calls());
    }

    #[test]
    fn test_combined_detections() {
        let caps = from_template(
            r"
            {% if tools %}<tool_call>{{ tool }}</tool_call>{% endif %}
            <think>{{ reasoning }}</think>
            ",
        );
        assert!(caps.supports_tool_calls());
        assert!(caps.supports_reasoning());
    }

    #[test]
    fn test_arch_none_returns_empty() {
        assert!(capabilities_from_architecture(None).is_empty());
    }

    #[test]
    fn test_arch_mistral_requires_strict_turns() {
        assert!(capabilities_from_architecture(Some("mistral")).requires_strict_turns());
    }

    #[test]
    fn test_arch_llama_returns_empty() {
        assert!(capabilities_from_architecture(Some("llama")).is_empty());
    }

    #[test]
    fn test_arch_unknown_returns_empty() {
        assert!(capabilities_from_architecture(Some("future-arch-xyz")).is_empty());
    }

    #[test]
    fn test_arch_mistral3_strict_turns_and_system_role() {
        let caps = capabilities_from_architecture(Some("mistral3"));
        assert!(
            caps.requires_strict_turns(),
            "mistral3 must enforce strict turns"
        );
        assert!(
            caps.supports_system_role(),
            "mistral3 supports system via [SYSTEM_PROMPT]"
        );
    }

    #[test]
    fn test_infer_mistral_v7_supports_system() {
        let caps = from_template(
            r"
            {% if messages[0].role == 'system' %}
            [SYSTEM_PROMPT]{{ messages[0].content }}[/SYSTEM_PROMPT]
            {% endif %}
            {% for message in messages %}
            {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
            {{ raise_exception('conversation roles must alternate') }}
            {% endif %}
            {% endfor %}
            ",
        );
        assert!(
            caps.supports_system_role(),
            "[SYSTEM_PROMPT] is positive evidence"
        );
        assert!(caps.requires_strict_turns(), "still enforces alternation");
    }

    #[test]
    fn test_infer_mistral_v3_supports_system() {
        let caps = from_template(
            r"
            {% if tools is defined %}[AVAILABLE_TOOLS]{{ tools | tojson }}[/AVAILABLE_TOOLS]{% endif %}
            {% for message in messages %}
            {% if message.role == 'user' %}[INST]{{ message.content }}[/INST]
            {% elif message.role == 'assistant' %}{{ message.content }}</s>
            {% endif %}
            {% endfor %}
            ",
        );
        assert!(
            caps.supports_system_role(),
            "[AVAILABLE_TOOLS] is positive evidence"
        );
    }

    #[test]
    fn test_infer_mistral_v1_forbids_system() {
        let caps = from_template(
            r"
            {% if message.role == 'system' %}
            {{ raise_exception('Only user, assistant and tool roles are supported, got system.') }}
            {% endif %}
            ",
        );
        assert!(
            !caps.supports_system_role(),
            "v1/v2 genuinely rejects system role"
        );
    }

    #[test]
    fn test_arch_or_template_additive() {
        let combined = from_template("<tool_call>{{ tool }}</tool_call>")
            | capabilities_from_architecture(Some("mistral"));
        assert!(combined.supports_tool_calls(), "tool calls from template");
        assert!(combined.requires_strict_turns(), "strict turns from arch");
    }
}
598
599#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
619#[serde(untagged)]
620pub enum MessageContent {
621 Text(String),
623 Parts(Vec<serde_json::Value>),
628}
629
630impl MessageContent {
631 pub fn as_str(&self) -> Option<&str> {
633 match self {
634 Self::Text(s) => Some(s),
635 Self::Parts(_) => None,
636 }
637 }
638
639 pub fn into_string(self) -> String {
650 match self {
651 Self::Text(s) => s,
652 Self::Parts(parts) => parts
653 .iter()
654 .filter_map(|p| p.get("text").and_then(|t| t.as_str()))
655 .collect::<Vec<_>>()
656 .join(""),
657 }
658 }
659
660 fn merge_with(self, other: Self) -> Self {
672 match (self, other) {
673 (Self::Text(mut a), Self::Text(b)) => {
674 if a.is_empty() {
675 return Self::Text(b);
676 }
677 if b.is_empty() {
678 return Self::Text(a);
679 }
680 a.push_str("\n\n");
681 a.push_str(&b);
682 Self::Text(a)
683 }
684 (Self::Parts(mut a), Self::Parts(b)) => {
685 a.extend(b);
686 Self::Parts(a)
687 }
688 (Self::Text(a), Self::Parts(b)) => {
689 let mut parts = vec![serde_json::json!({"type": "text", "text": a})];
690 parts.extend(b);
691 Self::Parts(parts)
692 }
693 (Self::Parts(mut a), Self::Text(b)) => {
694 a.push(serde_json::json!({"type": "text", "text": b}));
695 Self::Parts(a)
696 }
697 }
698 }
699}
700
701impl From<String> for MessageContent {
702 fn from(s: String) -> Self {
703 Self::Text(s)
704 }
705}
706
707impl From<&str> for MessageContent {
708 fn from(s: &str) -> Self {
709 Self::Text(s.to_string())
710 }
711}
712
713#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
719pub struct ChatMessage {
720 pub role: String,
721 #[serde(default, skip_serializing_if = "Option::is_none")]
722 pub content: Option<MessageContent>,
723 #[serde(default, skip_serializing_if = "Option::is_none")]
724 pub tool_calls: Option<serde_json::Value>,
725}
726
727impl ChatMessage {
728 fn merge_into(&mut self, other: Self) {
734 self.content = match (self.content.take(), other.content) {
735 (None, b) => b,
736 (a, None) => a,
737 (Some(a), Some(b)) => Some(a.merge_with(b)),
738 };
739 match (self.tool_calls.as_mut(), other.tool_calls) {
740 (_, None) => {}
741 (None, tc) => self.tool_calls = tc,
742 (Some(last_tc), Some(msg_tc)) => {
743 if let (Some(la), Some(ma)) = (last_tc.as_array_mut(), msg_tc.as_array()) {
744 la.extend_from_slice(ma);
745 }
746 }
747 }
748 }
749}
750
751fn merge_consecutive_system_messages(messages: Vec<ChatMessage>) -> Vec<ChatMessage> {
766 if messages.is_empty() {
767 return messages;
768 }
769
770 let mut result: Vec<ChatMessage> = Vec::with_capacity(messages.len());
771
772 for msg in messages {
773 let is_system_merge = result
774 .last()
775 .is_some_and(|last| last.role == "system" && msg.role == "system");
776 if is_system_merge {
777 let last = result.last_mut().unwrap();
778 last.content = match (last.content.take(), msg.content) {
779 (None, b) => b,
780 (a, None) => a,
781 (Some(a), Some(b)) => Some(a.merge_with(b)),
782 };
783 } else {
784 result.push(msg);
785 }
786 }
787
788 result
789}
790
791pub fn transform_messages_for_capabilities(
815 mut messages: Vec<ChatMessage>,
816 capabilities: ModelCapabilities,
817) -> Vec<ChatMessage> {
818 messages = merge_consecutive_system_messages(messages);
823
824 if capabilities.is_empty() {
826 return messages;
827 }
828
829 if !capabilities.contains(ModelCapabilities::SUPPORTS_SYSTEM_ROLE) {
831 for msg in &mut messages {
832 if msg.role == "system" {
833 msg.role = "user".to_string();
834 if let Some(content) = msg.content.take() {
835 msg.content = Some(MessageContent::Text(format!(
836 "[System]: {}",
837 content.into_string()
838 )));
839 }
840 }
841 }
842 }
843
844 if capabilities.contains(ModelCapabilities::REQUIRES_STRICT_TURNS) {
846 let mut merged: Vec<ChatMessage> = Vec::new();
847 for msg in messages {
848 let is_mergeable = msg.role == "user" || msg.role == "assistant";
849 let same_role_as_last = merged.last().is_some_and(|last| last.role == msg.role);
850 if is_mergeable && same_role_as_last {
851 merged.last_mut().unwrap().merge_into(msg);
852 } else {
853 merged.push(msg);
854 }
855 }
856 return merged;
857 }
858
859 messages
860}
861
#[cfg(test)]
mod transform_tests {
    use super::*;

    /// A message with plain-text content and no tool calls.
    fn text_msg(role: &str, text: &str) -> ChatMessage {
        msg(role, text_content(text), None)
    }

    /// A message with arbitrary content and tool calls.
    fn msg(
        role: &str,
        content: Option<MessageContent>,
        tool_calls: Option<serde_json::Value>,
    ) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content,
            tool_calls,
        }
    }

    /// `Some(MessageContent::Text(..))` for the given text.
    fn text_content(text: &str) -> Option<MessageContent> {
        Some(MessageContent::Text(text.to_string()))
    }

    /// A `tool_calls` array holding exactly one function call.
    fn single_call(id: &str, name: &str, arguments: &str) -> serde_json::Value {
        serde_json::json!([{
            "id": id,
            "type": "function",
            "function": { "name": name, "arguments": arguments }
        }])
    }

    /// The plain-text content of `m`, if it has any.
    fn text_of(m: &ChatMessage) -> Option<&str> {
        m.content.as_ref().and_then(MessageContent::as_str)
    }

    const STRICT: ModelCapabilities = ModelCapabilities::REQUIRES_STRICT_TURNS;

    #[test]
    fn test_transform_unknown_passes_through_non_system() {
        let messages = vec![text_msg("user", "Hello"), text_msg("assistant", "Hi there")];
        let original = messages.clone();
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());
        assert_eq!(result, original);
    }

    #[test]
    fn test_merges_consecutive_system_messages_always() {
        let messages = vec![
            text_msg("system", "You are a helpful assistant."),
            text_msg("system", "WORKING_MEMORY:\n- task1 (ok): done"),
            text_msg("user", "Hello"),
        ];
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, "system");
        assert_eq!(
            text_of(&result[0]),
            Some("You are a helpful assistant.\n\nWORKING_MEMORY:\n- task1 (ok): done")
        );
        assert_eq!(result[1].role, "user");
    }

    #[test]
    fn test_merges_three_consecutive_system_messages() {
        let messages = ["First.", "Second.", "Third."]
            .into_iter()
            .map(|text| text_msg("system", text))
            .collect();
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());

        assert_eq!(result.len(), 1);
        assert_eq!(text_of(&result[0]), Some("First.\n\nSecond.\n\nThird."));
    }

    #[test]
    fn test_handles_empty_system_content() {
        let messages = vec![text_msg("system", ""), text_msg("system", "Actual content")];
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());

        assert_eq!(result.len(), 1);
        assert_eq!(text_of(&result[0]), Some("Actual content"));
    }

    #[test]
    fn test_transform_system_to_user() {
        let messages = vec![text_msg("system", "You are helpful"), text_msg("user", "Hello")];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].role, "user");
        let merged = text_of(&result[0]).expect("merged message should be plain text");
        assert!(merged.contains("[System]: You are helpful"));
        assert!(merged.contains("Hello"));
    }

    #[test]
    fn test_transform_preserves_system_when_supported() {
        let messages = vec![text_msg("system", "You are helpful")];
        let result =
            transform_messages_for_capabilities(messages, ModelCapabilities::SUPPORTS_SYSTEM_ROLE);
        assert_eq!(result[0].role, "system");
        assert_eq!(result[0].content, text_content("You are helpful"));
    }

    #[test]
    fn test_transform_merges_consecutive_user_messages() {
        let messages = vec![text_msg("user", "First"), text_msg("user", "Second")];
        let result = transform_messages_for_capabilities(messages, STRICT);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, text_content("First\n\nSecond"));
    }

    #[test]
    fn test_transform_does_not_merge_tool_messages() {
        let messages = vec![text_msg("tool", "Result 1"), text_msg("tool", "Result 2")];
        let result = transform_messages_for_capabilities(messages, STRICT);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_transform_combined_system_and_merge() {
        let messages = vec![
            text_msg("system", "Be helpful"),
            text_msg("user", "First"),
            text_msg("user", "Second"),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].role, "user");
        let merged = text_of(&result[0]).expect("merged message should be plain text");
        for expected in ["[System]: Be helpful", "First", "Second"] {
            assert!(merged.contains(expected), "missing {expected:?} in {merged:?}");
        }
    }

    #[test]
    fn test_merge_consecutive_assistant_with_tool_calls() {
        let messages = vec![
            text_msg("user", "What's the weather?"),
            msg(
                "assistant",
                text_content("Let me check..."),
                Some(single_call("call_1", "get_weather", r#"{"location":"Paris"}"#)),
            ),
            msg(
                "assistant",
                text_content("And the time..."),
                Some(single_call("call_2", "get_time", r#"{"timezone":"UTC"}"#)),
            ),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, "user");
        assert_eq!(result[1].role, "assistant");
        assert_eq!(
            result[1].content,
            text_content("Let me check...\n\nAnd the time...")
        );

        let calls = result[1]
            .tool_calls
            .as_ref()
            .and_then(serde_json::Value::as_array)
            .expect("merged tool_calls should be an array");
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0]["id"], "call_1");
        assert_eq!(calls[1]["id"], "call_2");
    }

    #[test]
    fn test_merge_assistant_messages_only_first_has_content() {
        let messages = vec![
            text_msg("assistant", "Let me check..."),
            msg(
                "assistant",
                None,
                Some(single_call("call_1", "get_weather", "{}")),
            ),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, text_content("Let me check..."));
        assert!(result[0].tool_calls.is_some());
    }

    #[test]
    fn test_merge_assistant_messages_only_second_has_content() {
        let messages = vec![
            msg(
                "assistant",
                None,
                Some(single_call("call_1", "get_weather", "{}")),
            ),
            text_msg("assistant", "Result received"),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, text_content("Result received"));
        assert!(result[0].tool_calls.is_some());
    }

    #[test]
    fn test_merge_assistant_messages_neither_has_content() {
        let messages = vec![
            msg("assistant", None, Some(single_call("call_1", "tool1", "{}"))),
            msg("assistant", None, Some(single_call("call_2", "tool2", "{}"))),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 1);
        assert!(result[0].content.is_none());
        let calls = result[0]
            .tool_calls
            .as_ref()
            .and_then(serde_json::Value::as_array)
            .expect("merged tool_calls should be an array");
        assert_eq!(calls.len(), 2);
    }

    #[test]
    fn test_no_merge_without_strict_turns_capability() {
        let messages = vec![text_msg("assistant", "First"), text_msg("assistant", "Second")];
        let result = transform_messages_for_capabilities(messages, ModelCapabilities::empty());
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_merge_preserves_different_role_boundaries() {
        let messages = vec![
            text_msg("user", "Question"),
            msg(
                "assistant",
                text_content("Answer"),
                Some(single_call("call_1", "tool1", "{}")),
            ),
            text_msg("user", "Follow-up"),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        let roles: Vec<&str> = result.iter().map(|m| m.role.as_str()).collect();
        assert_eq!(roles, ["user", "assistant", "user"]);
    }

    #[test]
    fn test_array_content_deserializes_to_parts() {
        let json = serde_json::json!({
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}
            ]
        });
        let parsed: ChatMessage = serde_json::from_value(json).unwrap();
        assert!(matches!(parsed.content, Some(MessageContent::Parts(_))));

        let round_tripped = serde_json::to_value(&parsed).unwrap();
        assert!(round_tripped["content"].is_array());
    }

    #[test]
    fn test_merge_array_content_with_text_content() {
        let parts = MessageContent::Parts(vec![
            serde_json::json!({"type": "text", "text": "Look at this:"}),
            serde_json::json!({"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}),
        ]);
        let messages = vec![
            msg("user", Some(parts), None),
            text_msg("user", "What do you see?"),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 1);
        assert!(matches!(&result[0].content, Some(MessageContent::Parts(p)) if p.len() == 3));
    }

    #[test]
    fn test_tool_message_with_array_content_passes_through() {
        let tool_parts = MessageContent::Parts(vec![
            serde_json::json!({"type": "text", "text": "tool result here"}),
        ]);
        let messages = vec![
            text_msg("user", "Run the tool"),
            msg("tool", Some(tool_parts), None),
            text_msg("user", "Thanks"),
        ];
        let result = transform_messages_for_capabilities(messages, STRICT);

        assert_eq!(result.len(), 3);
        assert_eq!(result[1].role, "tool");
        assert!(matches!(&result[1].content, Some(MessageContent::Parts(_))));
    }
}