1use crate::domain::{Model, NewModel};
4use crate::ports::{CoreError, GgufParserPort, ModelRepository, RepositoryError};
5use std::path::Path;
6use std::sync::Arc;
7
/// Tag changes produced by re-running capability detection on a model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetagDiff {
    /// Tags present after the retag that were not present before it.
    pub added: Vec<String>,
    /// Tags present before the retag that are no longer present after it.
    pub removed: Vec<String>,
}

impl RetagDiff {
    /// Reports whether the retag touched the tag set at all.
    ///
    /// Returns `false` only when both `added` and `removed` are empty.
    pub const fn is_changed(&self) -> bool {
        let unchanged = self.added.is_empty() && self.removed.is_empty();
        !unchanged
    }
}
24
25pub struct ModelService {
31 repo: Arc<dyn ModelRepository>,
32}
33
34impl ModelService {
35 pub fn new(repo: Arc<dyn ModelRepository>) -> Self {
37 Self { repo }
38 }
39
40 pub async fn list(&self) -> Result<Vec<Model>, CoreError> {
42 self.repo.list().await.map_err(CoreError::from)
43 }
44
45 pub async fn get(&self, identifier: &str) -> Result<Option<Model>, CoreError> {
47 if let Ok(id) = identifier.parse::<i64>() {
49 match self.repo.get_by_id(id).await {
50 Ok(model) => return Ok(Some(model)),
51 Err(RepositoryError::NotFound(_)) => {}
52 Err(e) => return Err(CoreError::from(e)),
53 }
54 }
55 match self.repo.get_by_name(identifier).await {
57 Ok(model) => Ok(Some(model)),
58 Err(RepositoryError::NotFound(_)) => Ok(None),
59 Err(e) => Err(CoreError::from(e)),
60 }
61 }
62
63 pub async fn get_by_id(&self, id: i64) -> Result<Option<Model>, CoreError> {
65 match self.repo.get_by_id(id).await {
66 Ok(model) => Ok(Some(model)),
67 Err(RepositoryError::NotFound(_)) => Ok(None),
68 Err(e) => Err(CoreError::from(e)),
69 }
70 }
71
72 pub async fn get_by_name(&self, name: &str) -> Result<Option<Model>, CoreError> {
74 match self.repo.get_by_name(name).await {
75 Ok(model) => Ok(Some(model)),
76 Err(RepositoryError::NotFound(_)) => Ok(None),
77 Err(e) => Err(CoreError::from(e)),
78 }
79 }
80
81 pub async fn find_by_identifier(&self, identifier: &str) -> Result<Model, CoreError> {
84 self.get(identifier)
85 .await?
86 .ok_or_else(|| CoreError::Validation(format!("Model not found: {identifier}")))
87 }
88
89 pub async fn tags_for(&self, identifier: &str) -> Vec<String> {
96 match self.get(identifier).await {
97 Ok(Some(m)) => m.tags,
98 _ => Vec::new(),
99 }
100 }
101
102 pub async fn find_by_name(&self, name: &str) -> Result<Model, CoreError> {
104 self.get_by_name(name)
105 .await?
106 .ok_or_else(|| CoreError::Validation(format!("Model not found: {name}")))
107 }
108
109 pub async fn add(&self, model: NewModel) -> Result<Model, CoreError> {
111 self.repo.insert(&model).await.map_err(CoreError::from)
112 }
113
114 pub async fn import_from_file(
139 &self,
140 file_path: &Path,
141 gguf_parser: &dyn GgufParserPort,
142 param_count_override: Option<f64>,
143 ) -> Result<Model, CoreError> {
144 let gguf_metadata = crate::utils::validation::validate_and_parse_gguf(
146 gguf_parser,
147 file_path
148 .to_str()
149 .ok_or_else(|| CoreError::Validation("Invalid file path encoding".to_string()))?,
150 )
151 .map_err(|e| CoreError::Validation(format!("GGUF validation failed: {e}")))?;
152
153 let param_count_b = param_count_override
155 .or(gguf_metadata.param_count_b)
156 .unwrap_or(0.0);
157
158 let gguf_capabilities = gguf_parser.detect_capabilities(&gguf_metadata);
160 let auto_tags = gguf_capabilities.to_tags();
161
162 let template = gguf_metadata.metadata.get("tokenizer.chat_template");
167 let name = gguf_metadata.metadata.get("general.name");
168 let from_template = crate::domain::infer_from_chat_template(
169 template.map(String::as_str),
170 name.map(String::as_str),
171 );
172 let from_arch =
173 crate::domain::capabilities_from_architecture(gguf_metadata.architecture.as_deref());
174 let model_capabilities = from_template | from_arch;
175
176 let new_model = NewModel {
178 name: name.cloned().unwrap_or_else(|| {
179 file_path
180 .file_stem()
181 .and_then(|s| s.to_str())
182 .unwrap_or("Unknown Model")
183 .to_string()
184 }),
185 file_path: file_path.to_path_buf(),
186 param_count_b,
187 architecture: gguf_metadata.architecture,
188 quantization: gguf_metadata.quantization,
189 context_length: gguf_metadata.context_length,
190 expert_count: gguf_metadata.expert_count,
191 expert_used_count: gguf_metadata.expert_used_count,
192 expert_shared_count: gguf_metadata.expert_shared_count,
193 metadata: gguf_metadata.metadata,
194 added_at: chrono::Utc::now(),
195 hf_repo_id: None,
196 hf_commit_sha: None,
197 hf_filename: None,
198 download_date: None,
199 last_update_check: None,
200 tags: auto_tags,
201 file_paths: None,
202 capabilities: model_capabilities,
203 inference_defaults: None,
204 };
205
206 self.repo.insert(&new_model).await.map_err(CoreError::from)
208 }
209
210 pub async fn update(&self, model: &Model) -> Result<(), CoreError> {
212 self.repo.update(model).await.map_err(CoreError::from)
213 }
214
215 pub async fn delete(&self, id: i64) -> Result<(), CoreError> {
217 self.repo.delete(id).await.map_err(CoreError::from)
218 }
219
220 pub async fn remove(&self, identifier: &str) -> Result<Model, CoreError> {
222 let model = self.find_by_identifier(identifier).await?;
223 self.repo.delete(model.id).await.map_err(CoreError::from)?;
224 Ok(model)
225 }
226
227 pub async fn list_tags(&self) -> Result<Vec<String>, CoreError> {
233 let models = self.repo.list().await.map_err(CoreError::from)?;
234 let mut all_tags = std::collections::HashSet::new();
235 for model in models {
236 for tag in model.tags {
237 all_tags.insert(tag);
238 }
239 }
240 let mut tags: Vec<String> = all_tags.into_iter().collect();
241 tags.sort();
242 Ok(tags)
243 }
244
245 pub async fn add_tag(&self, model_id: i64, tag: String) -> Result<(), CoreError> {
249 let mut model = self
250 .repo
251 .get_by_id(model_id)
252 .await
253 .map_err(CoreError::from)?;
254 if !model.tags.contains(&tag) {
255 model.tags.push(tag);
256 model.tags.sort();
257 self.repo.update(&model).await.map_err(CoreError::from)?;
258 }
259 Ok(())
260 }
261
262 pub async fn remove_tag(&self, model_id: i64, tag: &str) -> Result<(), CoreError> {
269 if crate::domain::is_system_tag(tag) {
270 return Err(CoreError::Validation(format!(
271 "tag '{tag}' is a system tag and cannot be removed via the standard API",
272 )));
273 }
274 self.remove_tag_force(model_id, tag).await
275 }
276
277 pub async fn remove_tag_force(&self, model_id: i64, tag: &str) -> Result<(), CoreError> {
284 let mut model = self
285 .repo
286 .get_by_id(model_id)
287 .await
288 .map_err(CoreError::from)?;
289 model.tags.retain(|t| t != tag);
290 self.repo.update(&model).await.map_err(CoreError::from)?;
291 Ok(())
292 }
293
294 pub async fn get_tags(&self, model_id: i64) -> Result<Vec<String>, CoreError> {
296 let model = self
297 .repo
298 .get_by_id(model_id)
299 .await
300 .map_err(CoreError::from)?;
301 Ok(model.tags)
302 }
303
304 pub async fn get_by_tag(&self, tag: &str) -> Result<Vec<Model>, CoreError> {
306 let models = self.repo.list().await.map_err(CoreError::from)?;
307 Ok(models
308 .into_iter()
309 .filter(|m| m.tags.contains(&tag.to_string()))
310 .collect())
311 }
312
313 pub async fn get_filter_options(&self) -> Result<crate::domain::ModelFilterOptions, CoreError> {
325 use crate::domain::{ModelFilterOptions, RangeValues};
326 use std::collections::HashSet;
327
328 let models = self.repo.list().await.map_err(CoreError::from)?;
329
330 let mut quantizations: Vec<String> = models
332 .iter()
333 .filter_map(|m| m.quantization.clone())
334 .filter(|q| !q.is_empty())
335 .collect::<HashSet<_>>()
336 .into_iter()
337 .collect();
338 quantizations.sort();
339
340 let param_range = if models.is_empty() {
342 None
343 } else {
344 let min = models
345 .iter()
346 .map(|m| m.param_count_b)
347 .fold(f64::INFINITY, f64::min);
348 let max = models
349 .iter()
350 .map(|m| m.param_count_b)
351 .fold(f64::NEG_INFINITY, f64::max);
352 if min.is_finite() && max.is_finite() {
353 Some(RangeValues { min, max })
354 } else {
355 None
356 }
357 };
358
359 let context_lengths: Vec<u64> = models.iter().filter_map(|m| m.context_length).collect();
361 #[allow(clippy::cast_precision_loss)]
362 let context_range = if context_lengths.is_empty() {
363 None
364 } else {
365 let min = *context_lengths.iter().min().unwrap() as f64;
366 let max = *context_lengths.iter().max().unwrap() as f64;
367 Some(RangeValues { min, max })
368 };
369
370 Ok(ModelFilterOptions {
371 quantizations,
372 param_range,
373 context_range,
374 })
375 }
376
377 pub async fn bootstrap_capabilities(&self) -> Result<(), CoreError> {
390 use crate::domain::{capabilities_from_architecture, infer_from_chat_template};
391
392 let models = self.repo.list().await.map_err(CoreError::from)?;
393
394 for mut model in models {
395 if model.capabilities.is_empty() {
397 let template = model.metadata.get("tokenizer.chat_template");
398 let name = model.metadata.get("general.name");
399 let arch = model.metadata.get("general.architecture");
400 let from_template = infer_from_chat_template(
401 template.map(String::as_str),
402 name.map(String::as_str),
403 );
404 let from_arch = capabilities_from_architecture(arch.map(String::as_str));
405 model.capabilities = from_template | from_arch;
406 self.repo.update(&model).await.map_err(CoreError::from)?;
407 }
408 }
409
410 Ok(())
411 }
412
413 pub async fn retag_model(
433 &self,
434 model_id: i64,
435 gguf_parser: &dyn GgufParserPort,
436 full: bool,
437 ) -> Result<Option<RetagDiff>, CoreError> {
438 let mut model = self
439 .repo
440 .get_by_id(model_id)
441 .await
442 .map_err(CoreError::from)?;
443
444 let gguf_metadata = crate::domain::gguf::GgufMetadata {
447 metadata: model.metadata.clone(),
448 ..Default::default()
449 };
450 let new_tags = gguf_parser.detect_capabilities(&gguf_metadata).to_tags();
451
452 let before: std::collections::BTreeSet<String> = model.tags.iter().cloned().collect();
453
454 if full {
455 const AUTO_TAG_NAMES: &[&str] = &["reasoning", "agent", "vision", "code", "moe"];
457 model.tags.retain(|t| {
458 !AUTO_TAG_NAMES.contains(&t.as_str()) && !crate::domain::is_system_tag(t)
459 });
460 }
461
462 for t in &new_tags {
463 if !model.tags.contains(t) {
464 model.tags.push(t.clone());
465 }
466 }
467 model.tags.sort();
468
469 let after: std::collections::BTreeSet<String> = model.tags.iter().cloned().collect();
470 if after == before {
471 return Ok(None);
472 }
473
474 self.repo.update(&model).await.map_err(CoreError::from)?;
475 Ok(Some(RetagDiff {
476 added: after.difference(&before).cloned().collect(),
477 removed: before.difference(&after).cloned().collect(),
478 }))
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ports::{ModelRepository, RepositoryError};
    use async_trait::async_trait;
    use chrono::Utc;

    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory `ModelRepository` so `ModelService` can be exercised without storage.
    struct MockRepo {
        models: Mutex<Vec<Model>>,
    }

    impl MockRepo {
        fn new() -> Self {
            Self {
                models: Mutex::new(vec![]),
            }
        }
    }

    /// Builds a `NewModel` with the given name, path and parameter count, stamped "now".
    fn sample_model(name: &str, path: &str, param_count_b: f64) -> NewModel {
        NewModel::new(name.to_string(), PathBuf::from(path), param_count_b, Utc::now())
    }

    #[async_trait]
    impl ModelRepository for MockRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            Ok(self.models.lock().unwrap().clone())
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.id == id)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.name == name)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
        }

        // The guard is deliberately held across id allocation and the push so both happen
        // under one lock.
        #[allow(clippy::significant_drop_tightening)]
        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            let mut models = self.models.lock().unwrap();
            // Allocate past the current maximum. `len + 1` would hand out an id that is
            // still in use once any earlier row has been deleted.
            let id = models.iter().map(|m| m.id).max().unwrap_or(0) + 1;
            let created = Model {
                id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                expert_count: model.expert_count,
                expert_used_count: model.expert_used_count,
                expert_shared_count: model.expert_shared_count,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                capabilities: model.capabilities,
                inference_defaults: model.inference_defaults.clone(),
            };
            models.push(created.clone());
            Ok(created)
        }

        async fn update(&self, model: &Model) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            models.iter_mut().find(|m| m.id == model.id).map_or_else(
                || Err(RepositoryError::NotFound(format!("id={}", model.id))),
                |m| {
                    m.clone_from(model);
                    Ok(())
                },
            )
        }

        async fn delete(&self, id: i64) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            let len_before = models.len();
            models.retain(|m| m.id != id);
            if models.len() == len_before {
                Err(RepositoryError::NotFound(format!("id={id}")))
            } else {
                Ok(())
            }
        }
    }

    #[tokio::test]
    async fn test_list_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);
        let models = service.list().await.unwrap();
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn test_add_and_get() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let created = service
            .add(sample_model("test-model", "/path/to/model.gguf", 7.0))
            .await
            .unwrap();
        assert_eq!(created.name, "test-model");

        let found = service.get_by_name("test-model").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, created.id);
    }

    #[tokio::test]
    async fn test_find_by_identifier_not_found() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let result = service.find_by_identifier("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_resolves_id_then_falls_back_to_name() {
        let service = ModelService::new(Arc::new(MockRepo::new()));
        let alpha = service
            .add(sample_model("alpha", "/a.gguf", 7.0))
            .await
            .unwrap();
        // A purely numeric name that does not collide with any allocated id.
        let numeric = service
            .add(sample_model("99", "/b.gguf", 7.0))
            .await
            .unwrap();

        let by_id = service.get(&alpha.id.to_string()).await.unwrap().unwrap();
        assert_eq!(by_id.name, "alpha");

        let by_numeric_name = service.get("99").await.unwrap().unwrap();
        assert_eq!(by_numeric_name.id, numeric.id);

        assert!(service.get("missing").await.unwrap().is_none());
        assert!(service.tags_for("missing").await.is_empty());
    }

    #[tokio::test]
    async fn test_remove_returns_model_and_deletes_it() {
        let service = ModelService::new(Arc::new(MockRepo::new()));
        let created = service
            .add(sample_model("doomed", "/d.gguf", 3.0))
            .await
            .unwrap();

        let removed = service.remove("doomed").await.unwrap();
        assert_eq!(removed.id, created.id);
        assert!(service.list().await.unwrap().is_empty());
        assert!(service.find_by_identifier("doomed").await.is_err());
    }

    #[tokio::test]
    async fn test_insert_after_delete_allocates_unused_id() {
        let service = ModelService::new(Arc::new(MockRepo::new()));
        let first = service.add(sample_model("a", "/a.gguf", 1.0)).await.unwrap();
        let second = service.add(sample_model("b", "/b.gguf", 1.0)).await.unwrap();
        service.delete(first.id).await.unwrap();

        let third = service.add(sample_model("c", "/c.gguf", 1.0)).await.unwrap();
        assert_ne!(third.id, second.id);
        let still_b = service.get_by_id(second.id).await.unwrap().unwrap();
        assert_eq!(still_b.name, "b");
    }

    #[tokio::test]
    async fn test_get_filter_options_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let options = service.get_filter_options().await.unwrap();
        assert!(options.quantizations.is_empty());
        assert!(options.param_range.is_none());
        assert!(options.context_range.is_none());
    }

    #[tokio::test]
    async fn test_get_filter_options_with_models() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let mut model1 = sample_model("model-1", "/path/to/model1.gguf", 7.0);
        model1.quantization = Some("Q4_K_M".to_string());
        model1.context_length = Some(4096);

        let mut model2 = sample_model("model-2", "/path/to/model2.gguf", 13.0);
        model2.quantization = Some("Q8_0".to_string());
        model2.context_length = Some(8192);

        // model3 shares a quantization with model1 and has no context length, so it
        // exercises de-duplication and the `None` context path.
        let mut model3 = sample_model("model-3", "/path/to/model3.gguf", 70.0);
        model3.quantization = Some("Q4_K_M".to_string());

        service.add(model1).await.unwrap();
        service.add(model2).await.unwrap();
        service.add(model3).await.unwrap();

        let options = service.get_filter_options().await.unwrap();

        assert_eq!(options.quantizations, vec!["Q4_K_M", "Q8_0"]);

        let param_range = options.param_range.unwrap();
        assert!((param_range.min - 7.0).abs() < 0.001);
        assert!((param_range.max - 70.0).abs() < 0.001);

        let context_range = options.context_range.unwrap();
        assert!((context_range.min - 4096.0).abs() < 0.001);
        assert!((context_range.max - 8192.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn test_list_tags_and_get_by_tag() {
        let service = ModelService::new(Arc::new(MockRepo::new()));

        let mut first = sample_model("first", "/1.gguf", 1.0);
        first.tags = vec!["beta".to_string(), "alpha".to_string()];
        let mut second = sample_model("second", "/2.gguf", 2.0);
        second.tags = vec!["alpha".to_string(), "gamma".to_string()];
        service.add(first).await.unwrap();
        service.add(second).await.unwrap();

        assert_eq!(
            service.list_tags().await.unwrap(),
            vec!["alpha", "beta", "gamma"]
        );

        let mut names: Vec<String> = service
            .get_by_tag("alpha")
            .await
            .unwrap()
            .into_iter()
            .map(|m| m.name)
            .collect();
        names.sort();
        assert_eq!(names, vec!["first", "second"]);
        assert!(service.get_by_tag("delta").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_add_tag_is_idempotent_and_sorted() {
        let service = ModelService::new(Arc::new(MockRepo::new()));
        let mut model = sample_model("t", "/t.gguf", 1.0);
        model.tags = vec!["zeta".to_string()];
        let created = service.add(model).await.unwrap();

        service.add_tag(created.id, "alpha".to_string()).await.unwrap();
        service.add_tag(created.id, "alpha".to_string()).await.unwrap();
        assert_eq!(
            service.get_tags(created.id).await.unwrap(),
            vec!["alpha", "zeta"]
        );
    }

    #[tokio::test]
    async fn test_remove_tag_force_missing_tag_is_ok() {
        let service = ModelService::new(Arc::new(MockRepo::new()));
        let mut model = sample_model("k", "/k.gguf", 1.0);
        model.tags = vec!["keep".to_string()];
        let created = service.add(model).await.unwrap();

        service.remove_tag_force(created.id, "absent").await.unwrap();
        assert_eq!(service.get_tags(created.id).await.unwrap(), vec!["keep"]);

        // An unknown model id is still an error.
        assert!(service
            .remove_tag_force(created.id + 100, "keep")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_remove_tag_rejects_system_tag() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let mut new_model = sample_model("qwen-test", "/path/to/m.gguf", 7.0);
        new_model.tags = vec!["chat".to_string(), "format:qwen-xml".to_string()];
        let created = service.add(new_model).await.unwrap();

        let err = service
            .remove_tag(created.id, "format:qwen-xml")
            .await
            .unwrap_err();
        assert!(matches!(err, CoreError::Validation(_)));

        // The rejected removal must not have touched the stored tags.
        let tags = service.get_tags(created.id).await.unwrap();
        assert!(tags.contains(&"format:qwen-xml".to_string()));

        service
            .remove_tag_force(created.id, "format:qwen-xml")
            .await
            .unwrap();
        let tags = service.get_tags(created.id).await.unwrap();
        assert!(!tags.contains(&"format:qwen-xml".to_string()));
    }

    #[tokio::test]
    async fn test_remove_tag_allows_user_tag() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let mut new_model = sample_model("u", "/p.gguf", 7.0);
        new_model.tags = vec!["chat".to_string(), "format:hermes".to_string()];
        let created = service.add(new_model).await.unwrap();

        service.remove_tag(created.id, "chat").await.unwrap();
        let tags = service.get_tags(created.id).await.unwrap();
        assert_eq!(tags, vec!["format:hermes".to_string()]);
    }

    /// Parser stub whose capability detection always yields `tags` as extension tags.
    struct StubCapsParser {
        tags: Vec<String>,
    }

    impl crate::ports::GgufParserPort for StubCapsParser {
        fn parse(
            &self,
            _file_path: &std::path::Path,
        ) -> std::result::Result<crate::ports::GgufMetadata, crate::ports::GgufParseError> {
            Ok(crate::ports::GgufMetadata::default())
        }

        fn detect_capabilities(
            &self,
            _metadata: &crate::ports::GgufMetadata,
        ) -> crate::ports::GgufCapabilities {
            crate::ports::GgufCapabilities {
                flags: crate::domain::gguf::CapabilityFlags::empty(),
                extensions: self.tags.iter().cloned().collect(),
            }
        }
    }

    #[tokio::test]
    async fn test_retag_additive_appends_missing_tags() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let mut new_model = sample_model("m", "/p.gguf", 7.0);
        new_model.tags = vec!["chat".to_string()];
        let created = service.add(new_model).await.unwrap();

        let parser = StubCapsParser {
            tags: vec!["format:qwen-xml".to_string()],
        };
        let diff = service
            .retag_model(created.id, &parser, false)
            .await
            .unwrap();
        assert_eq!(diff.unwrap().added, vec!["format:qwen-xml".to_string()]);

        let tags = service.get_tags(created.id).await.unwrap();
        assert!(tags.contains(&"chat".to_string()));
        assert!(tags.contains(&"format:qwen-xml".to_string()));
    }

    #[tokio::test]
    async fn test_retag_additive_noop_when_already_present() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let mut new_model = sample_model("m", "/p.gguf", 7.0);
        new_model.tags = vec!["format:qwen-xml".to_string()];
        let created = service.add(new_model).await.unwrap();

        let parser = StubCapsParser {
            tags: vec!["format:qwen-xml".to_string()],
        };
        let diff = service
            .retag_model(created.id, &parser, false)
            .await
            .unwrap();
        assert!(diff.is_none());
    }

    #[tokio::test]
    async fn test_retag_full_replaces_auto_tags_preserves_user() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let mut new_model = sample_model("m", "/p.gguf", 7.0);
        new_model.tags = vec![
            "favorite".to_string(),
            "format:hermes".to_string(),
            "reasoning".to_string(),
        ];
        let created = service.add(new_model).await.unwrap();

        let parser = StubCapsParser {
            tags: vec!["format:qwen-xml".to_string()],
        };
        service
            .retag_model(created.id, &parser, true)
            .await
            .unwrap();

        let tags = service.get_tags(created.id).await.unwrap();
        assert!(tags.contains(&"favorite".to_string()));
        assert!(tags.contains(&"format:qwen-xml".to_string()));
        assert!(!tags.contains(&"format:hermes".to_string()));
        assert!(!tags.contains(&"reasoning".to_string()));
    }
}