//! Application-layer service for managing models through the `ModelRepository` port.

use crate::domain::{Model, NewModel};
use crate::ports::{CoreError, GgufParserPort, ModelRepository, RepositoryError};
use std::path::Path;
use std::sync::Arc;

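/// Core service for model management: lookups, GGUF imports, tags, filter
/// options, and capability bootstrapping, all delegated to a
/// [`ModelRepository`] implementation.
///
/// # Example
///
/// Illustrative only; `SqliteModelRepository` stands in for whatever adapter
/// implements [`ModelRepository`] in your application.
///
/// ```ignore
/// let repo = Arc::new(SqliteModelRepository::new(pool));
/// let service = ModelService::new(repo);
/// let model = service.find_by_identifier("llama-3").await?;
/// ```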
pub struct ModelService {
    repo: Arc<dyn ModelRepository>,
}

impl ModelService {
    pub fn new(repo: Arc<dyn ModelRepository>) -> Self {
        Self { repo }
    }

    pub async fn list(&self) -> Result<Vec<Model>, CoreError> {
        self.repo.list().await.map_err(CoreError::from)
    }

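    /// Resolves a model by identifier: if the identifier parses as an `i64`
    /// it is first tried as a database id, then falls back to a name lookup.
    /// Returns `Ok(None)` when neither lookup matches.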
    pub async fn get(&self, identifier: &str) -> Result<Option<Model>, CoreError> {
        if let Ok(id) = identifier.parse::<i64>() {
            match self.repo.get_by_id(id).await {
                Ok(model) => return Ok(Some(model)),
                Err(RepositoryError::NotFound(_)) => {}
                Err(e) => return Err(CoreError::from(e)),
            }
        }
        match self.repo.get_by_name(identifier).await {
            Ok(model) => Ok(Some(model)),
            Err(RepositoryError::NotFound(_)) => Ok(None),
            Err(e) => Err(CoreError::from(e)),
        }
    }

    pub async fn get_by_id(&self, id: i64) -> Result<Option<Model>, CoreError> {
        match self.repo.get_by_id(id).await {
            Ok(model) => Ok(Some(model)),
            Err(RepositoryError::NotFound(_)) => Ok(None),
            Err(e) => Err(CoreError::from(e)),
        }
    }

    pub async fn get_by_name(&self, name: &str) -> Result<Option<Model>, CoreError> {
        match self.repo.get_by_name(name).await {
            Ok(model) => Ok(Some(model)),
            Err(RepositoryError::NotFound(_)) => Ok(None),
            Err(e) => Err(CoreError::from(e)),
        }
    }

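    /// Like [`Self::get`], but treats a missing model as an error instead of
    /// returning `None`.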
    pub async fn find_by_identifier(&self, identifier: &str) -> Result<Model, CoreError> {
        self.get(identifier)
            .await?
            .ok_or_else(|| CoreError::Validation(format!("Model not found: {identifier}")))
    }

    pub async fn find_by_name(&self, name: &str) -> Result<Model, CoreError> {
        self.get_by_name(name)
            .await?
            .ok_or_else(|| CoreError::Validation(format!("Model not found: {name}")))
    }

    pub async fn add(&self, model: NewModel) -> Result<Model, CoreError> {
        self.repo.insert(&model).await.map_err(CoreError::from)
    }

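    /// Imports a GGUF file: validates and parses it via the given
    /// [`GgufParserPort`], derives tags and capabilities from the metadata,
    /// and inserts the resulting model. `param_count_override` takes
    /// precedence over the parameter count reported by the GGUF metadata.
    ///
    /// ```ignore
    /// // Illustrative only; `parser` is any GgufParserPort implementation.
    /// let model = service
    ///     .import_from_file(Path::new("/models/foo.gguf"), &parser, None)
    ///     .await?;
    /// ```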
    pub async fn import_from_file(
        &self,
        file_path: &Path,
        gguf_parser: &dyn GgufParserPort,
        param_count_override: Option<f64>,
    ) -> Result<Model, CoreError> {
        let gguf_metadata = crate::utils::validation::validate_and_parse_gguf(
            gguf_parser,
            file_path
                .to_str()
                .ok_or_else(|| CoreError::Validation("Invalid file path encoding".to_string()))?,
        )
        .map_err(|e| CoreError::Validation(format!("GGUF validation failed: {e}")))?;

        // Prefer an explicit override, then the value parsed from the GGUF metadata.
        let param_count_b = param_count_override
            .or(gguf_metadata.param_count_b)
            .unwrap_or(0.0);

        let gguf_capabilities = gguf_parser.detect_capabilities(&gguf_metadata);
        let auto_tags = gguf_capabilities.to_tags();

        // Infer model capabilities from the embedded chat template and model name.
        let template = gguf_metadata.metadata.get("tokenizer.chat_template");
        let name = gguf_metadata.metadata.get("general.name");
        let model_capabilities = crate::domain::infer_from_chat_template(
            template.map(String::as_str),
            name.map(String::as_str),
        );

        let new_model = NewModel {
            // Fall back to the file stem when the GGUF metadata carries no name.
            name: name.cloned().unwrap_or_else(|| {
                file_path
                    .file_stem()
                    .and_then(|s| s.to_str())
                    .unwrap_or("Unknown Model")
                    .to_string()
            }),
            file_path: file_path.to_path_buf(),
            param_count_b,
            architecture: gguf_metadata.architecture,
            quantization: gguf_metadata.quantization,
            context_length: gguf_metadata.context_length,
            metadata: gguf_metadata.metadata,
            added_at: chrono::Utc::now(),
            hf_repo_id: None,
            hf_commit_sha: None,
            hf_filename: None,
            download_date: None,
            last_update_check: None,
            tags: auto_tags,
            file_paths: None,
            capabilities: model_capabilities,
            inference_defaults: None,
        };

        self.repo.insert(&new_model).await.map_err(CoreError::from)
    }

    pub async fn update(&self, model: &Model) -> Result<(), CoreError> {
        self.repo.update(model).await.map_err(CoreError::from)
    }

    pub async fn delete(&self, id: i64) -> Result<(), CoreError> {
        self.repo.delete(id).await.map_err(CoreError::from)
    }

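    /// Resolves a model by identifier, deletes it, and returns the removed
    /// model so callers can report what was deleted.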
    pub async fn remove(&self, identifier: &str) -> Result<Model, CoreError> {
        let model = self.find_by_identifier(identifier).await?;
        self.repo.delete(model.id).await.map_err(CoreError::from)?;
        Ok(model)
    }

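    /// Returns the sorted, de-duplicated set of tags across all models.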
    pub async fn list_tags(&self) -> Result<Vec<String>, CoreError> {
        let models = self.repo.list().await.map_err(CoreError::from)?;
        let mut all_tags = std::collections::HashSet::new();
        for model in models {
            for tag in model.tags {
                all_tags.insert(tag);
            }
        }
        let mut tags: Vec<String> = all_tags.into_iter().collect();
        tags.sort();
        Ok(tags)
    }

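    /// Adds a tag to the model if it is not already present, keeping the tag
    /// list sorted. Adding an existing tag is a no-op.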
    pub async fn add_tag(&self, model_id: i64, tag: String) -> Result<(), CoreError> {
        let mut model = self
            .repo
            .get_by_id(model_id)
            .await
            .map_err(CoreError::from)?;
        if !model.tags.contains(&tag) {
            model.tags.push(tag);
            model.tags.sort();
            self.repo.update(&model).await.map_err(CoreError::from)?;
        }
        Ok(())
    }

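    /// Removes a tag from the model; removing a tag that is not present
    /// still succeeds.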
    pub async fn remove_tag(&self, model_id: i64, tag: &str) -> Result<(), CoreError> {
        let mut model = self
            .repo
            .get_by_id(model_id)
            .await
            .map_err(CoreError::from)?;
        model.tags.retain(|t| t != tag);
        self.repo.update(&model).await.map_err(CoreError::from)?;
        Ok(())
    }

    pub async fn get_tags(&self, model_id: i64) -> Result<Vec<String>, CoreError> {
        let model = self
            .repo
            .get_by_id(model_id)
            .await
            .map_err(CoreError::from)?;
        Ok(model.tags)
    }

    pub async fn get_by_tag(&self, tag: &str) -> Result<Vec<Model>, CoreError> {
        let models = self.repo.list().await.map_err(CoreError::from)?;
        Ok(models
            .into_iter()
            .filter(|m| m.tags.contains(&tag.to_string()))
            .collect())
    }

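    /// Aggregates filter options over all models: distinct quantization
    /// labels plus the observed min/max ranges for parameter count and
    /// context length. A range is `None` when no model provides the
    /// corresponding value.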
    pub async fn get_filter_options(&self) -> Result<crate::domain::ModelFilterOptions, CoreError> {
        use crate::domain::{ModelFilterOptions, RangeValues};
        use std::collections::HashSet;

        let models = self.repo.list().await.map_err(CoreError::from)?;

        let mut quantizations: Vec<String> = models
            .iter()
            .filter_map(|m| m.quantization.clone())
            .filter(|q| !q.is_empty())
            .collect::<HashSet<_>>()
            .into_iter()
            .collect();
        quantizations.sort();

        let param_range = if models.is_empty() {
            None
        } else {
            let min = models
                .iter()
                .map(|m| m.param_count_b)
                .fold(f64::INFINITY, f64::min);
            let max = models
                .iter()
                .map(|m| m.param_count_b)
                .fold(f64::NEG_INFINITY, f64::max);
            if min.is_finite() && max.is_finite() {
                Some(RangeValues { min, max })
            } else {
                None
            }
        };

        let context_lengths: Vec<u64> = models.iter().filter_map(|m| m.context_length).collect();
        #[allow(clippy::cast_precision_loss)]
        let context_range = if context_lengths.is_empty() {
            None
        } else {
            let min = *context_lengths.iter().min().unwrap() as f64;
            let max = *context_lengths.iter().max().unwrap() as f64;
            Some(RangeValues { min, max })
        };

        Ok(ModelFilterOptions {
            quantizations,
            param_range,
            context_range,
        })
    }

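    /// Backfills capabilities for models whose capability set is empty,
    /// inferring them from the stored chat template and model name and
    /// persisting the result.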
    pub async fn bootstrap_capabilities(&self) -> Result<(), CoreError> {
        use crate::domain::infer_from_chat_template;

        let models = self.repo.list().await.map_err(CoreError::from)?;

        for mut model in models {
            if model.capabilities.is_empty() {
                let template = model.metadata.get("tokenizer.chat_template");
                let name = model.metadata.get("general.name");
                let inferred = infer_from_chat_template(
                    template.map(String::as_str),
                    name.map(String::as_str),
                );

                model.capabilities = inferred;
                self.repo.update(&model).await.map_err(CoreError::from)?;
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ports::{ModelRepository, RepositoryError};
    use async_trait::async_trait;
    use chrono::Utc;

    use std::path::PathBuf;
    use std::sync::Mutex;

    struct MockRepo {
        models: Mutex<Vec<Model>>,
    }

    impl MockRepo {
        fn new() -> Self {
            Self {
                models: Mutex::new(vec![]),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            Ok(self.models.lock().unwrap().clone())
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.id == id)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.name == name)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
        }

        #[allow(clippy::cast_possible_wrap, clippy::significant_drop_tightening)]
        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            let mut models = self.models.lock().unwrap();
            let id = models.len() as i64 + 1;
            let created = Model {
                id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                capabilities: model.capabilities,
                inference_defaults: model.inference_defaults.clone(),
            };
            models.push(created.clone());
            Ok(created)
        }

        async fn update(&self, model: &Model) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            models.iter_mut().find(|m| m.id == model.id).map_or_else(
                || Err(RepositoryError::NotFound(format!("id={}", model.id))),
                |m| {
                    m.clone_from(model);
                    Ok(())
                },
            )
        }

        async fn delete(&self, id: i64) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            let len_before = models.len();
            models.retain(|m| m.id != id);
            if models.len() == len_before {
                Err(RepositoryError::NotFound(format!("id={id}")))
            } else {
                Ok(())
            }
        }
    }

    #[tokio::test]
    async fn test_list_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);
        let models = service.list().await.unwrap();
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn test_add_and_get() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let new_model = NewModel::new(
            "test-model".to_string(),
            PathBuf::from("/path/to/model.gguf"),
            7.0,
            Utc::now(),
        );

        let created = service.add(new_model).await.unwrap();
        assert_eq!(created.name, "test-model");

        let found = service.get_by_name("test-model").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, created.id);
    }

    #[tokio::test]
    async fn test_find_by_identifier_not_found() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let result = service.find_by_identifier("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_filter_options_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let options = service.get_filter_options().await.unwrap();
        assert!(options.quantizations.is_empty());
        assert!(options.param_range.is_none());
        assert!(options.context_range.is_none());
    }

    #[tokio::test]
    async fn test_get_filter_options_with_models() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let mut model1 = NewModel::new(
            "model-1".to_string(),
            PathBuf::from("/path/to/model1.gguf"),
            7.0,
            Utc::now(),
        );
        model1.quantization = Some("Q4_K_M".to_string());
        model1.context_length = Some(4096);

        let mut model2 = NewModel::new(
            "model-2".to_string(),
            PathBuf::from("/path/to/model2.gguf"),
            13.0,
            Utc::now(),
        );
        model2.quantization = Some("Q8_0".to_string());
        model2.context_length = Some(8192);

        let mut model3 = NewModel::new(
            "model-3".to_string(),
            PathBuf::from("/path/to/model3.gguf"),
            70.0,
            Utc::now(),
        );
        model3.quantization = Some("Q4_K_M".to_string());

        service.add(model1).await.unwrap();
        service.add(model2).await.unwrap();
        service.add(model3).await.unwrap();

        let options = service.get_filter_options().await.unwrap();

        assert_eq!(options.quantizations, vec!["Q4_K_M", "Q8_0"]);

        let param_range = options.param_range.unwrap();
        assert!((param_range.min - 7.0).abs() < 0.001);
        assert!((param_range.max - 70.0).abs() < 0.001);

        let context_range = options.context_range.unwrap();
        assert!((context_range.min - 4096.0).abs() < 0.001);
        assert!((context_range.max - 8192.0).abs() < 0.001);
    }
}