1use crate::domain::{Model, NewModel};
4use crate::ports::{CoreError, GgufParserPort, ModelRepository, RepositoryError};
5use std::path::Path;
6use std::sync::Arc;
7
/// Application service for the model catalog: CRUD, tag management, GGUF
/// import, and filter-option aggregation over an abstract [`ModelRepository`].
pub struct ModelService {
    // Shared handle to the persistence port; trait object so any backend fits.
    repo: Arc<dyn ModelRepository>,
}
16
17impl ModelService {
18 pub fn new(repo: Arc<dyn ModelRepository>) -> Self {
20 Self { repo }
21 }
22
23 pub async fn list(&self) -> Result<Vec<Model>, CoreError> {
25 self.repo.list().await.map_err(CoreError::from)
26 }
27
28 pub async fn get(&self, identifier: &str) -> Result<Option<Model>, CoreError> {
30 if let Ok(id) = identifier.parse::<i64>() {
32 match self.repo.get_by_id(id).await {
33 Ok(model) => return Ok(Some(model)),
34 Err(RepositoryError::NotFound(_)) => {}
35 Err(e) => return Err(CoreError::from(e)),
36 }
37 }
38 match self.repo.get_by_name(identifier).await {
40 Ok(model) => Ok(Some(model)),
41 Err(RepositoryError::NotFound(_)) => Ok(None),
42 Err(e) => Err(CoreError::from(e)),
43 }
44 }
45
46 pub async fn get_by_id(&self, id: i64) -> Result<Option<Model>, CoreError> {
48 match self.repo.get_by_id(id).await {
49 Ok(model) => Ok(Some(model)),
50 Err(RepositoryError::NotFound(_)) => Ok(None),
51 Err(e) => Err(CoreError::from(e)),
52 }
53 }
54
55 pub async fn get_by_name(&self, name: &str) -> Result<Option<Model>, CoreError> {
57 match self.repo.get_by_name(name).await {
58 Ok(model) => Ok(Some(model)),
59 Err(RepositoryError::NotFound(_)) => Ok(None),
60 Err(e) => Err(CoreError::from(e)),
61 }
62 }
63
64 pub async fn find_by_identifier(&self, identifier: &str) -> Result<Model, CoreError> {
67 self.get(identifier)
68 .await?
69 .ok_or_else(|| CoreError::Validation(format!("Model not found: {identifier}")))
70 }
71
72 pub async fn find_by_name(&self, name: &str) -> Result<Model, CoreError> {
74 self.get_by_name(name)
75 .await?
76 .ok_or_else(|| CoreError::Validation(format!("Model not found: {name}")))
77 }
78
79 pub async fn add(&self, model: NewModel) -> Result<Model, CoreError> {
81 self.repo.insert(&model).await.map_err(CoreError::from)
82 }
83
    /// Imports a GGUF model file into the catalog.
    ///
    /// The file is validated and parsed through `gguf_parser`; tags come from
    /// the detected capabilities and capability flags from chat-template
    /// inference over the parsed metadata. The model name falls back to the
    /// file stem when the metadata carries no `general.name`.
    ///
    /// * `param_count_override` — when `Some`, takes precedence over the
    ///   parameter count parsed from the GGUF metadata; if both are absent
    ///   the count defaults to `0.0`.
    ///
    /// # Errors
    /// Returns `CoreError::Validation` for a non-UTF-8 path or a failed GGUF
    /// parse; repository failures from the final insert convert via `From`.
    pub async fn import_from_file(
        &self,
        file_path: &Path,
        gguf_parser: &dyn GgufParserPort,
        param_count_override: Option<f64>,
    ) -> Result<Model, CoreError> {
        let gguf_metadata = crate::utils::validation::validate_and_parse_gguf(
            gguf_parser,
            file_path
                .to_str()
                .ok_or_else(|| CoreError::Validation("Invalid file path encoding".to_string()))?,
        )
        .map_err(|e| CoreError::Validation(format!("GGUF validation failed: {e}")))?;

        // Explicit override wins; fall back to the parsed count, then 0.0.
        let param_count_b = param_count_override
            .or(gguf_metadata.param_count_b)
            .unwrap_or(0.0);

        let gguf_capabilities = gguf_parser.detect_capabilities(&gguf_metadata);
        let auto_tags = gguf_capabilities.to_tags();

        // Borrow template/name out of the metadata map before the map itself
        // is moved into the NewModel literal below.
        let template = gguf_metadata.metadata.get("tokenizer.chat_template");
        let name = gguf_metadata.metadata.get("general.name");
        let model_capabilities = crate::domain::infer_from_chat_template(
            template.map(String::as_str),
            name.map(String::as_str),
        );

        let new_model = NewModel {
            // Prefer the embedded model name; otherwise use the file stem.
            name: name.cloned().unwrap_or_else(|| {
                file_path
                    .file_stem()
                    .and_then(|s| s.to_str())
                    .unwrap_or("Unknown Model")
                    .to_string()
            }),
            file_path: file_path.to_path_buf(),
            param_count_b,
            architecture: gguf_metadata.architecture,
            quantization: gguf_metadata.quantization,
            context_length: gguf_metadata.context_length,
            expert_count: gguf_metadata.expert_count,
            expert_used_count: gguf_metadata.expert_used_count,
            expert_shared_count: gguf_metadata.expert_shared_count,
            metadata: gguf_metadata.metadata,
            added_at: chrono::Utc::now(),
            // Local file import: Hugging Face provenance fields stay unset.
            hf_repo_id: None,
            hf_commit_sha: None,
            hf_filename: None,
            download_date: None,
            last_update_check: None,
            tags: auto_tags,
            file_paths: None,
            capabilities: model_capabilities,
            inference_defaults: None,
        };

        self.repo.insert(&new_model).await.map_err(CoreError::from)
    }
173
174 pub async fn update(&self, model: &Model) -> Result<(), CoreError> {
176 self.repo.update(model).await.map_err(CoreError::from)
177 }
178
179 pub async fn delete(&self, id: i64) -> Result<(), CoreError> {
181 self.repo.delete(id).await.map_err(CoreError::from)
182 }
183
184 pub async fn remove(&self, identifier: &str) -> Result<Model, CoreError> {
186 let model = self.find_by_identifier(identifier).await?;
187 self.repo.delete(model.id).await.map_err(CoreError::from)?;
188 Ok(model)
189 }
190
191 pub async fn list_tags(&self) -> Result<Vec<String>, CoreError> {
197 let models = self.repo.list().await.map_err(CoreError::from)?;
198 let mut all_tags = std::collections::HashSet::new();
199 for model in models {
200 for tag in model.tags {
201 all_tags.insert(tag);
202 }
203 }
204 let mut tags: Vec<String> = all_tags.into_iter().collect();
205 tags.sort();
206 Ok(tags)
207 }
208
209 pub async fn add_tag(&self, model_id: i64, tag: String) -> Result<(), CoreError> {
213 let mut model = self
214 .repo
215 .get_by_id(model_id)
216 .await
217 .map_err(CoreError::from)?;
218 if !model.tags.contains(&tag) {
219 model.tags.push(tag);
220 model.tags.sort();
221 self.repo.update(&model).await.map_err(CoreError::from)?;
222 }
223 Ok(())
224 }
225
226 pub async fn remove_tag(&self, model_id: i64, tag: &str) -> Result<(), CoreError> {
230 let mut model = self
231 .repo
232 .get_by_id(model_id)
233 .await
234 .map_err(CoreError::from)?;
235 model.tags.retain(|t| t != tag);
236 self.repo.update(&model).await.map_err(CoreError::from)?;
237 Ok(())
238 }
239
240 pub async fn get_tags(&self, model_id: i64) -> Result<Vec<String>, CoreError> {
242 let model = self
243 .repo
244 .get_by_id(model_id)
245 .await
246 .map_err(CoreError::from)?;
247 Ok(model.tags)
248 }
249
250 pub async fn get_by_tag(&self, tag: &str) -> Result<Vec<Model>, CoreError> {
252 let models = self.repo.list().await.map_err(CoreError::from)?;
253 Ok(models
254 .into_iter()
255 .filter(|m| m.tags.contains(&tag.to_string()))
256 .collect())
257 }
258
259 pub async fn get_filter_options(&self) -> Result<crate::domain::ModelFilterOptions, CoreError> {
271 use crate::domain::{ModelFilterOptions, RangeValues};
272 use std::collections::HashSet;
273
274 let models = self.repo.list().await.map_err(CoreError::from)?;
275
276 let mut quantizations: Vec<String> = models
278 .iter()
279 .filter_map(|m| m.quantization.clone())
280 .filter(|q| !q.is_empty())
281 .collect::<HashSet<_>>()
282 .into_iter()
283 .collect();
284 quantizations.sort();
285
286 let param_range = if models.is_empty() {
288 None
289 } else {
290 let min = models
291 .iter()
292 .map(|m| m.param_count_b)
293 .fold(f64::INFINITY, f64::min);
294 let max = models
295 .iter()
296 .map(|m| m.param_count_b)
297 .fold(f64::NEG_INFINITY, f64::max);
298 if min.is_finite() && max.is_finite() {
299 Some(RangeValues { min, max })
300 } else {
301 None
302 }
303 };
304
305 let context_lengths: Vec<u64> = models.iter().filter_map(|m| m.context_length).collect();
307 #[allow(clippy::cast_precision_loss)]
308 let context_range = if context_lengths.is_empty() {
309 None
310 } else {
311 let min = *context_lengths.iter().min().unwrap() as f64;
312 let max = *context_lengths.iter().max().unwrap() as f64;
313 Some(RangeValues { min, max })
314 };
315
316 Ok(ModelFilterOptions {
317 quantizations,
318 param_range,
319 context_range,
320 })
321 }
322
323 pub async fn bootstrap_capabilities(&self) -> Result<(), CoreError> {
336 use crate::domain::infer_from_chat_template;
337
338 let models = self.repo.list().await.map_err(CoreError::from)?;
339
340 for mut model in models {
341 if model.capabilities.is_empty() {
343 let template = model.metadata.get("tokenizer.chat_template");
344 let name = model.metadata.get("general.name");
345 let inferred = infer_from_chat_template(
346 template.map(String::as_str),
347 name.map(String::as_str),
348 );
349
350 model.capabilities = inferred;
351 self.repo.update(&model).await.map_err(CoreError::from)?;
352 }
353 }
354
355 Ok(())
356 }
357}
358
359#[cfg(test)]
360mod tests {
361 use super::*;
362 use crate::ports::{ModelRepository, RepositoryError};
363 use async_trait::async_trait;
364 use chrono::Utc;
365
366 use std::path::PathBuf;
367 use std::sync::Mutex;
368
    /// In-memory [`ModelRepository`] test double backed by a mutex-guarded Vec.
    struct MockRepo {
        // Interior mutability so the &self trait methods can mutate the store.
        models: Mutex<Vec<Model>>,
    }
372
373 impl MockRepo {
374 fn new() -> Self {
375 Self {
376 models: Mutex::new(vec![]),
377 }
378 }
379 }
380
381 #[async_trait]
382 impl ModelRepository for MockRepo {
383 async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
384 Ok(self.models.lock().unwrap().clone())
385 }
386
387 async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
388 self.models
389 .lock()
390 .unwrap()
391 .iter()
392 .find(|m| m.id == id)
393 .cloned()
394 .ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
395 }
396
397 async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
398 self.models
399 .lock()
400 .unwrap()
401 .iter()
402 .find(|m| m.name == name)
403 .cloned()
404 .ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
405 }
406
407 #[allow(clippy::cast_possible_wrap, clippy::significant_drop_tightening)]
408 async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
409 let mut models = self.models.lock().unwrap();
410 let id = models.len() as i64 + 1;
411 let created = Model {
412 id,
413 name: model.name.clone(),
414 file_path: model.file_path.clone(),
415 param_count_b: model.param_count_b,
416 architecture: model.architecture.clone(),
417 quantization: model.quantization.clone(),
418 context_length: model.context_length,
419 expert_count: model.expert_count,
420 expert_used_count: model.expert_used_count,
421 expert_shared_count: model.expert_shared_count,
422 metadata: model.metadata.clone(),
423 added_at: model.added_at,
424 hf_repo_id: model.hf_repo_id.clone(),
425 hf_commit_sha: model.hf_commit_sha.clone(),
426 hf_filename: model.hf_filename.clone(),
427 download_date: model.download_date,
428 last_update_check: model.last_update_check,
429 tags: model.tags.clone(),
430 capabilities: model.capabilities,
431 inference_defaults: model.inference_defaults.clone(),
432 };
433 models.push(created.clone());
434 Ok(created)
435 }
436
437 async fn update(&self, model: &Model) -> Result<(), RepositoryError> {
438 let mut models = self.models.lock().unwrap();
439 models.iter_mut().find(|m| m.id == model.id).map_or_else(
440 || Err(RepositoryError::NotFound(format!("id={}", model.id))),
441 |m| {
442 m.clone_from(model);
443 Ok(())
444 },
445 )
446 }
447
448 async fn delete(&self, id: i64) -> Result<(), RepositoryError> {
449 let mut models = self.models.lock().unwrap();
450 let len_before = models.len();
451 models.retain(|m| m.id != id);
452 if models.len() == len_before {
453 Err(RepositoryError::NotFound(format!("id={id}")))
454 } else {
455 Ok(())
456 }
457 }
458 }
459
460 #[tokio::test]
461 async fn test_list_empty() {
462 let repo = Arc::new(MockRepo::new());
463 let service = ModelService::new(repo);
464 let models = service.list().await.unwrap();
465 assert!(models.is_empty());
466 }
467
468 #[tokio::test]
469 async fn test_add_and_get() {
470 let repo = Arc::new(MockRepo::new());
471 let service = ModelService::new(repo);
472
473 let new_model = NewModel::new(
474 "test-model".to_string(),
475 PathBuf::from("/path/to/model.gguf"),
476 7.0,
477 Utc::now(),
478 );
479
480 let created = service.add(new_model).await.unwrap();
481 assert_eq!(created.name, "test-model");
482
483 let found = service.get_by_name("test-model").await.unwrap();
484 assert!(found.is_some());
485 assert_eq!(found.unwrap().id, created.id);
486 }
487
488 #[tokio::test]
489 async fn test_find_by_identifier_not_found() {
490 let repo = Arc::new(MockRepo::new());
491 let service = ModelService::new(repo);
492
493 let result = service.find_by_identifier("nonexistent").await;
494 assert!(result.is_err());
495 }
496
497 #[tokio::test]
498 async fn test_get_filter_options_empty() {
499 let repo = Arc::new(MockRepo::new());
500 let service = ModelService::new(repo);
501
502 let options = service.get_filter_options().await.unwrap();
503 assert!(options.quantizations.is_empty());
504 assert!(options.param_range.is_none());
505 assert!(options.context_range.is_none());
506 }
507
508 #[tokio::test]
509 async fn test_get_filter_options_with_models() {
510 let repo = Arc::new(MockRepo::new());
511 let service = ModelService::new(repo);
512
513 let mut model1 = NewModel::new(
515 "model-1".to_string(),
516 PathBuf::from("/path/to/model1.gguf"),
517 7.0,
518 Utc::now(),
519 );
520 model1.quantization = Some("Q4_K_M".to_string());
521 model1.context_length = Some(4096);
522
523 let mut model2 = NewModel::new(
524 "model-2".to_string(),
525 PathBuf::from("/path/to/model2.gguf"),
526 13.0,
527 Utc::now(),
528 );
529 model2.quantization = Some("Q8_0".to_string());
530 model2.context_length = Some(8192);
531
532 let mut model3 = NewModel::new(
533 "model-3".to_string(),
534 PathBuf::from("/path/to/model3.gguf"),
535 70.0,
536 Utc::now(),
537 );
538 model3.quantization = Some("Q4_K_M".to_string()); service.add(model1).await.unwrap();
542 service.add(model2).await.unwrap();
543 service.add(model3).await.unwrap();
544
545 let options = service.get_filter_options().await.unwrap();
546
547 assert_eq!(options.quantizations, vec!["Q4_K_M", "Q8_0"]);
549
550 let param_range = options.param_range.unwrap();
552 assert!((param_range.min - 7.0).abs() < 0.001);
553 assert!((param_range.max - 70.0).abs() < 0.001);
554
555 let context_range = options.context_range.unwrap();
557 assert!((context_range.min - 4096.0).abs() < 0.001);
558 assert!((context_range.max - 8192.0).abs() < 0.001);
559 }
560}