gglib_core/services/model_service.rs

//! Model service - orchestrates model CRUD operations.

use crate::domain::{Model, NewModel};
use crate::ports::{CoreError, GgufParserPort, ModelRepository, RepositoryError};
use std::path::Path;
use std::sync::Arc;

/// Service for model operations.
///
/// This service provides high-level model management on top of the injected
/// `ModelRepository`. Basic CRUD calls delegate straight to the repository,
/// while the richer operations (GGUF import, tag management, filter
/// aggregation, capability bootstrap) add orchestration without any
/// persistence logic of their own.
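///
/// # Example
///
/// A minimal construction sketch; `SqliteModelRepository` is a hypothetical
/// adapter type standing in for any concrete `ModelRepository` implementation:
///
/// ```ignore
/// use std::sync::Arc;
///
/// let repo = Arc::new(SqliteModelRepository::open("gglib.db")?);
/// let service = ModelService::new(repo);
/// let models = service.list().await?;
/// ```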
pub struct ModelService {
    repo: Arc<dyn ModelRepository>,
}

impl ModelService {
    /// Create a new model service with the given repository.
    pub fn new(repo: Arc<dyn ModelRepository>) -> Self {
        Self { repo }
    }

    /// List all models.
    pub async fn list(&self) -> Result<Vec<Model>, CoreError> {
        self.repo.list().await.map_err(CoreError::from)
    }

    /// Get a model by its identifier (id, name, or HF ID).
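    ///
    /// # Example
    ///
    /// A sketch of the lookup order implemented below (numeric identifiers are
    /// tried as database IDs first, with the name lookup as a fallback):
    ///
    /// ```ignore
    /// let by_id = service.get("42").await?;          // parses as i64, ID lookup first
    /// let by_name = service.get("llama-7b").await?;  // non-numeric, name lookup
    /// ```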
    pub async fn get(&self, identifier: &str) -> Result<Option<Model>, CoreError> {
        // Try by ID first
        if let Ok(id) = identifier.parse::<i64>() {
            match self.repo.get_by_id(id).await {
                Ok(model) => return Ok(Some(model)),
                Err(RepositoryError::NotFound(_)) => {}
                Err(e) => return Err(CoreError::from(e)),
            }
        }
        // Try by name
        match self.repo.get_by_name(identifier).await {
            Ok(model) => Ok(Some(model)),
            Err(RepositoryError::NotFound(_)) => Ok(None),
            Err(e) => Err(CoreError::from(e)),
        }
    }

    /// Get a model by its database ID.
    pub async fn get_by_id(&self, id: i64) -> Result<Option<Model>, CoreError> {
        match self.repo.get_by_id(id).await {
            Ok(model) => Ok(Some(model)),
            Err(RepositoryError::NotFound(_)) => Ok(None),
            Err(e) => Err(CoreError::from(e)),
        }
    }

    /// Get a model by name.
    pub async fn get_by_name(&self, name: &str) -> Result<Option<Model>, CoreError> {
        match self.repo.get_by_name(name).await {
            Ok(model) => Ok(Some(model)),
            Err(RepositoryError::NotFound(_)) => Ok(None),
            Err(e) => Err(CoreError::from(e)),
        }
    }

    /// Find a model by identifier (id, name, or HF ID).
    /// Returns error if not found.
    pub async fn find_by_identifier(&self, identifier: &str) -> Result<Model, CoreError> {
        self.get(identifier)
            .await?
            .ok_or_else(|| CoreError::Validation(format!("Model not found: {identifier}")))
    }

    /// Find a model by name. Returns error if not found.
    pub async fn find_by_name(&self, name: &str) -> Result<Model, CoreError> {
        self.get_by_name(name)
            .await?
            .ok_or_else(|| CoreError::Validation(format!("Model not found: {name}")))
    }

    /// Add a new model.
    pub async fn add(&self, model: NewModel) -> Result<Model, CoreError> {
        self.repo.insert(&model).await.map_err(CoreError::from)
    }

    /// Import a model from a local GGUF file with full metadata extraction.
    ///
    /// Validates the file, parses GGUF metadata, detects capabilities, and
    /// registers the model with rich metadata. This is the canonical way to
    /// add local models.
    ///
    /// # Arguments
    ///
    /// * `file_path` - Absolute path to the GGUF file
    /// * `gguf_parser` - Parser implementation for metadata extraction
    /// * `param_count_override` - Optional user override for the parameter count
    ///
    /// # Returns
    ///
    /// Returns the registered `Model` with full metadata, or a validation error.
    ///
    /// # Design
    ///
    /// This method orchestrates (see the example below):
    /// 1. File validation (existence, extension)
    /// 2. GGUF metadata parsing (architecture, quantization, context)
    /// 3. Capability detection (reasoning, tool-calling from metadata)
    /// 4. Chat template inference (additional capability signals)
    /// 5. Auto-tag generation from detected capabilities
    /// 6. Model persistence with a complete `NewModel` struct
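    ///
    /// # Example
    ///
    /// A sketch of a typical call, assuming `parser` is some value implementing
    /// `GgufParserPort` (the concrete parser type is provided elsewhere and is
    /// not shown here):
    ///
    /// ```ignore
    /// let model = service
    ///     .import_from_file(Path::new("/models/example.Q4_K_M.gguf"), &parser, None)
    ///     .await?;
    /// println!("registered {} ({} B params)", model.name, model.param_count_b);
    /// ```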
    pub async fn import_from_file(
        &self,
        file_path: &Path,
        gguf_parser: &dyn GgufParserPort,
        param_count_override: Option<f64>,
    ) -> Result<Model, CoreError> {
        // 1. Validate and parse GGUF file
        let gguf_metadata = crate::utils::validation::validate_and_parse_gguf(
            gguf_parser,
            file_path
                .to_str()
                .ok_or_else(|| CoreError::Validation("Invalid file path encoding".to_string()))?,
        )
        .map_err(|e| CoreError::Validation(format!("GGUF validation failed: {e}")))?;

        // 2. Resolve parameter count (override > metadata > 0.0 fallback)
        let param_count_b = param_count_override
            .or(gguf_metadata.param_count_b)
            .unwrap_or(0.0);

        // 3. Detect capabilities from GGUF metadata
        let gguf_capabilities = gguf_parser.detect_capabilities(&gguf_metadata);
        let auto_tags = gguf_capabilities.to_tags();

        // 4. Infer additional capabilities from chat template
        let template = gguf_metadata.metadata.get("tokenizer.chat_template");
        let name = gguf_metadata.metadata.get("general.name");
        let model_capabilities = crate::domain::infer_from_chat_template(
            template.map(String::as_str),
            name.map(String::as_str),
        );

        // 5. Construct fully-populated NewModel
        let new_model = NewModel {
            name: name.cloned().unwrap_or_else(|| {
                file_path
                    .file_stem()
                    .and_then(|s| s.to_str())
                    .unwrap_or("Unknown Model")
                    .to_string()
            }),
            file_path: file_path.to_path_buf(),
            param_count_b,
            architecture: gguf_metadata.architecture,
            quantization: gguf_metadata.quantization,
            context_length: gguf_metadata.context_length,
            metadata: gguf_metadata.metadata,
            added_at: chrono::Utc::now(),
            hf_repo_id: None,
            hf_commit_sha: None,
            hf_filename: None,
            download_date: None,
            last_update_check: None,
            tags: auto_tags,
            file_paths: None,
            capabilities: model_capabilities,
            inference_defaults: None,
        };

        // 6. Persist to repository
        self.repo.insert(&new_model).await.map_err(CoreError::from)
    }

    /// Update a model.
    pub async fn update(&self, model: &Model) -> Result<(), CoreError> {
        self.repo.update(model).await.map_err(CoreError::from)
    }

    /// Delete a model by ID.
    pub async fn delete(&self, id: i64) -> Result<(), CoreError> {
        self.repo.delete(id).await.map_err(CoreError::from)
    }

    /// Remove a model by identifier. Returns the removed model.
    pub async fn remove(&self, identifier: &str) -> Result<Model, CoreError> {
        let model = self.find_by_identifier(identifier).await?;
        self.repo.delete(model.id).await.map_err(CoreError::from)?;
        Ok(model)
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Tag Operations
    // ─────────────────────────────────────────────────────────────────────────

    /// List all unique tags used across all models.
    pub async fn list_tags(&self) -> Result<Vec<String>, CoreError> {
        let models = self.repo.list().await.map_err(CoreError::from)?;
        let mut all_tags = std::collections::HashSet::new();
        for model in models {
            for tag in model.tags {
                all_tags.insert(tag);
            }
        }
        let mut tags: Vec<String> = all_tags.into_iter().collect();
        tags.sort();
        Ok(tags)
    }

    /// Add a tag to a model.
    ///
    /// If the tag already exists on the model, this is a no-op.
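    ///
    /// # Example
    ///
    /// A small usage sketch (the `model` value is assumed to come from an
    /// earlier lookup):
    ///
    /// ```ignore
    /// service.add_tag(model.id, "reasoning".to_string()).await?;
    /// assert!(service.get_tags(model.id).await?.contains(&"reasoning".to_string()));
    /// ```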
    pub async fn add_tag(&self, model_id: i64, tag: String) -> Result<(), CoreError> {
        let mut model = self
            .repo
            .get_by_id(model_id)
            .await
            .map_err(CoreError::from)?;
        if !model.tags.contains(&tag) {
            model.tags.push(tag);
            model.tags.sort();
            self.repo.update(&model).await.map_err(CoreError::from)?;
        }
        Ok(())
    }

    /// Remove a tag from a model.
    ///
    /// If the tag doesn't exist on the model, this is a no-op.
    pub async fn remove_tag(&self, model_id: i64, tag: &str) -> Result<(), CoreError> {
        let mut model = self
            .repo
            .get_by_id(model_id)
            .await
            .map_err(CoreError::from)?;
        model.tags.retain(|t| t != tag);
        self.repo.update(&model).await.map_err(CoreError::from)?;
        Ok(())
    }

    /// Get all tags for a specific model.
    pub async fn get_tags(&self, model_id: i64) -> Result<Vec<String>, CoreError> {
        let model = self
            .repo
            .get_by_id(model_id)
            .await
            .map_err(CoreError::from)?;
        Ok(model.tags)
    }

    /// Get all models that have a specific tag.
    pub async fn get_by_tag(&self, tag: &str) -> Result<Vec<Model>, CoreError> {
        let models = self.repo.list().await.map_err(CoreError::from)?;
        Ok(models
            .into_iter()
            .filter(|m| m.tags.contains(&tag.to_string()))
            .collect())
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Filter/Aggregate Operations
    // ─────────────────────────────────────────────────────────────────────────

    /// Get filter options aggregated from all models.
    ///
    /// Returns distinct quantizations, parameter count range, and context length range
    /// for use in the GUI filter popover.
    ///
    /// Note: Uses in-memory aggregation for simplicity. This is acceptable for typical
    /// model libraries (<100 models). Revisit if libraries grow large.
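    ///
    /// # Example
    ///
    /// A sketch of how a caller might consume the aggregated options:
    ///
    /// ```ignore
    /// let options = service.get_filter_options().await?;
    /// if let Some(range) = &options.param_range {
    ///     println!("parameter counts span {}B to {}B", range.min, range.max);
    /// }
    /// for quant in &options.quantizations {
    ///     println!("available quantization: {quant}");
    /// }
    /// ```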
    pub async fn get_filter_options(&self) -> Result<crate::domain::ModelFilterOptions, CoreError> {
        use crate::domain::{ModelFilterOptions, RangeValues};
        use std::collections::HashSet;

        let models = self.repo.list().await.map_err(CoreError::from)?;

        // Collect distinct quantizations
        let mut quantizations: Vec<String> = models
            .iter()
            .filter_map(|m| m.quantization.clone())
            .filter(|q| !q.is_empty())
            .collect::<HashSet<_>>()
            .into_iter()
            .collect();
        quantizations.sort();

        // Compute param_count_b range
        let param_range = if models.is_empty() {
            None
        } else {
            let min = models
                .iter()
                .map(|m| m.param_count_b)
                .fold(f64::INFINITY, f64::min);
            let max = models
                .iter()
                .map(|m| m.param_count_b)
                .fold(f64::NEG_INFINITY, f64::max);
            if min.is_finite() && max.is_finite() {
                Some(RangeValues { min, max })
            } else {
                None
            }
        };

        // Compute context_length range (only models with context_length set)
        let context_lengths: Vec<u64> = models.iter().filter_map(|m| m.context_length).collect();
        #[allow(clippy::cast_precision_loss)]
        let context_range = if context_lengths.is_empty() {
            None
        } else {
            let min = *context_lengths.iter().min().unwrap() as f64;
            let max = *context_lengths.iter().max().unwrap() as f64;
            Some(RangeValues { min, max })
        };

        Ok(ModelFilterOptions {
            quantizations,
            param_range,
            context_range,
        })
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Capability Bootstrap
    // ─────────────────────────────────────────────────────────────────────────

    /// Backfill capabilities for models that don't have them set.
    ///
    /// This runs on startup to handle models with unknown capabilities.
    /// Only infers if capabilities are empty (0/unknown).
    ///
    /// # INVARIANT
    ///
    /// Never overwrite explicitly-set capabilities. Only infer when unknown.
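    ///
    /// # Example
    ///
    /// A sketch of the intended call site:
    ///
    /// ```ignore
    /// // Run once during application startup, after the service is constructed.
    /// service.bootstrap_capabilities().await?;
    /// ```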
    pub async fn bootstrap_capabilities(&self) -> Result<(), CoreError> {
        use crate::domain::infer_from_chat_template;

        let models = self.repo.list().await.map_err(CoreError::from)?;

        for mut model in models {
            // Only infer if capabilities are unknown (empty)
            if model.capabilities.is_empty() {
                let template = model.metadata.get("tokenizer.chat_template");
                let name = model.metadata.get("general.name");
                let inferred = infer_from_chat_template(
                    template.map(String::as_str),
                    name.map(String::as_str),
                );

                model.capabilities = inferred;
                self.repo.update(&model).await.map_err(CoreError::from)?;
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ports::{ModelRepository, RepositoryError};
    use async_trait::async_trait;
    use chrono::Utc;

    use std::path::PathBuf;
    use std::sync::Mutex;

    struct MockRepo {
        models: Mutex<Vec<Model>>,
    }

    impl MockRepo {
        fn new() -> Self {
            Self {
                models: Mutex::new(vec![]),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            Ok(self.models.lock().unwrap().clone())
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.id == id)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.name == name)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
        }

        #[allow(clippy::cast_possible_wrap, clippy::significant_drop_tightening)]
        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            let mut models = self.models.lock().unwrap();
            let id = models.len() as i64 + 1;
            let created = Model {
                id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                capabilities: model.capabilities,
                inference_defaults: model.inference_defaults.clone(),
            };
            models.push(created.clone());
            Ok(created)
        }

        async fn update(&self, model: &Model) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            models.iter_mut().find(|m| m.id == model.id).map_or_else(
                || Err(RepositoryError::NotFound(format!("id={}", model.id))),
                |m| {
                    m.clone_from(model);
                    Ok(())
                },
            )
        }

        async fn delete(&self, id: i64) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            let len_before = models.len();
            models.retain(|m| m.id != id);
            if models.len() == len_before {
                Err(RepositoryError::NotFound(format!("id={id}")))
            } else {
                Ok(())
            }
        }
    }

    #[tokio::test]
    async fn test_list_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);
        let models = service.list().await.unwrap();
        assert!(models.is_empty());
    }

    #[tokio::test]
    async fn test_add_and_get() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let new_model = NewModel::new(
            "test-model".to_string(),
            PathBuf::from("/path/to/model.gguf"),
            7.0,
            Utc::now(),
        );

        let created = service.add(new_model).await.unwrap();
        assert_eq!(created.name, "test-model");

        let found = service.get_by_name("test-model").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, created.id);
    }

    #[tokio::test]
    async fn test_find_by_identifier_not_found() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let result = service.find_by_identifier("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_filter_options_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let options = service.get_filter_options().await.unwrap();
        assert!(options.quantizations.is_empty());
        assert!(options.param_range.is_none());
        assert!(options.context_range.is_none());
    }

    #[tokio::test]
    async fn test_get_filter_options_with_models() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        // Add models with different characteristics
        let mut model1 = NewModel::new(
            "model-1".to_string(),
            PathBuf::from("/path/to/model1.gguf"),
            7.0,
            Utc::now(),
        );
        model1.quantization = Some("Q4_K_M".to_string());
        model1.context_length = Some(4096);

        let mut model2 = NewModel::new(
            "model-2".to_string(),
            PathBuf::from("/path/to/model2.gguf"),
            13.0,
            Utc::now(),
        );
        model2.quantization = Some("Q8_0".to_string());
        model2.context_length = Some(8192);

        let mut model3 = NewModel::new(
            "model-3".to_string(),
            PathBuf::from("/path/to/model3.gguf"),
            70.0,
            Utc::now(),
        );
        model3.quantization = Some("Q4_K_M".to_string()); // Duplicate quant
        // No context_length set

        service.add(model1).await.unwrap();
        service.add(model2).await.unwrap();
        service.add(model3).await.unwrap();

        let options = service.get_filter_options().await.unwrap();

        // Should have 2 distinct quantizations, sorted
        assert_eq!(options.quantizations, vec!["Q4_K_M", "Q8_0"]);

        // Param range: 7.0 to 70.0
        let param_range = options.param_range.unwrap();
        assert!((param_range.min - 7.0).abs() < 0.001);
        assert!((param_range.max - 70.0).abs() < 0.001);

        // Context range: 4096 to 8192 (model3 has no context)
        let context_range = options.context_range.unwrap();
        assert!((context_range.min - 4096.0).abs() < 0.001);
        assert!((context_range.max - 8192.0).abs() < 0.001);
    }
}