gglib_core/services/
model_service.rs

1//! Model service - orchestrates model CRUD operations.
2
3use crate::domain::{Model, NewModel};
4use crate::ports::{CoreError, GgufParserPort, ModelRepository, RepositoryError};
5use std::path::Path;
6use std::sync::Arc;
7
/// Service for model operations.
///
/// This service provides high-level model management by delegating
/// to the injected `ModelRepository`. It adds no business logic
/// beyond what the repository provides - it's a thin facade.
pub struct ModelService {
    // Injected persistence backend. Trait object behind `Arc` so the same
    // repository instance can be shared across services/tasks.
    repo: Arc<dyn ModelRepository>,
}
16
17impl ModelService {
18    /// Create a new model service with the given repository.
19    pub fn new(repo: Arc<dyn ModelRepository>) -> Self {
20        Self { repo }
21    }
22
23    /// List all models.
24    pub async fn list(&self) -> Result<Vec<Model>, CoreError> {
25        self.repo.list().await.map_err(CoreError::from)
26    }
27
28    /// Get a model by its identifier (id, name, or HF ID).
29    pub async fn get(&self, identifier: &str) -> Result<Option<Model>, CoreError> {
30        // Try by ID first
31        if let Ok(id) = identifier.parse::<i64>() {
32            match self.repo.get_by_id(id).await {
33                Ok(model) => return Ok(Some(model)),
34                Err(RepositoryError::NotFound(_)) => {}
35                Err(e) => return Err(CoreError::from(e)),
36            }
37        }
38        // Try by name
39        match self.repo.get_by_name(identifier).await {
40            Ok(model) => Ok(Some(model)),
41            Err(RepositoryError::NotFound(_)) => Ok(None),
42            Err(e) => Err(CoreError::from(e)),
43        }
44    }
45
46    /// Get a model by its database ID.
47    pub async fn get_by_id(&self, id: i64) -> Result<Option<Model>, CoreError> {
48        match self.repo.get_by_id(id).await {
49            Ok(model) => Ok(Some(model)),
50            Err(RepositoryError::NotFound(_)) => Ok(None),
51            Err(e) => Err(CoreError::from(e)),
52        }
53    }
54
55    /// Get a model by name.
56    pub async fn get_by_name(&self, name: &str) -> Result<Option<Model>, CoreError> {
57        match self.repo.get_by_name(name).await {
58            Ok(model) => Ok(Some(model)),
59            Err(RepositoryError::NotFound(_)) => Ok(None),
60            Err(e) => Err(CoreError::from(e)),
61        }
62    }
63
64    /// Find a model by identifier (id, name, or HF ID).
65    /// Returns error if not found.
66    pub async fn find_by_identifier(&self, identifier: &str) -> Result<Model, CoreError> {
67        self.get(identifier)
68            .await?
69            .ok_or_else(|| CoreError::Validation(format!("Model not found: {identifier}")))
70    }
71
72    /// Find a model by name. Returns error if not found.
73    pub async fn find_by_name(&self, name: &str) -> Result<Model, CoreError> {
74        self.get_by_name(name)
75            .await?
76            .ok_or_else(|| CoreError::Validation(format!("Model not found: {name}")))
77    }
78
79    /// Add a new model.
80    pub async fn add(&self, model: NewModel) -> Result<Model, CoreError> {
81        self.repo.insert(&model).await.map_err(CoreError::from)
82    }
83
84    /// Import a model from a local GGUF file with full metadata extraction.
85    ///
86    /// Validates file, parses GGUF metadata, detects capabilities, and registers
87    /// with rich metadata. This is the canonical way to add local models.
88    ///
89    /// # Arguments
90    ///
91    /// * `file_path` - Absolute path to the GGUF file
92    /// * `gguf_parser` - Parser implementation for metadata extraction
93    /// * `param_count_override` - Optional user override for parameter count
94    ///
95    /// # Returns
96    ///
97    /// Returns the registered `Model` with full metadata, or validation error.
98    ///
99    /// # Design
100    ///
101    /// This method orchestrates:
102    /// 1. File validation (existence, extension)
103    /// 2. GGUF metadata parsing (architecture, quantization, context)
104    /// 3. Capability detection (reasoning, tool-calling from metadata)
105    /// 4. Chat template inference (additional capability signals)
106    /// 5. Auto-tag generation from detected capabilities
107    /// 6. Model persistence with complete `NewModel` struct
108    pub async fn import_from_file(
109        &self,
110        file_path: &Path,
111        gguf_parser: &dyn GgufParserPort,
112        param_count_override: Option<f64>,
113    ) -> Result<Model, CoreError> {
114        // 1. Validate and parse GGUF file
115        let gguf_metadata = crate::utils::validation::validate_and_parse_gguf(
116            gguf_parser,
117            file_path
118                .to_str()
119                .ok_or_else(|| CoreError::Validation("Invalid file path encoding".to_string()))?,
120        )
121        .map_err(|e| CoreError::Validation(format!("GGUF validation failed: {e}")))?;
122
123        // 2. Resolve parameter count (override > metadata > 0.0 fallback)
124        let param_count_b = param_count_override
125            .or(gguf_metadata.param_count_b)
126            .unwrap_or(0.0);
127
128        // 3. Detect capabilities from GGUF metadata
129        let gguf_capabilities = gguf_parser.detect_capabilities(&gguf_metadata);
130        let auto_tags = gguf_capabilities.to_tags();
131
132        // 4. Infer additional capabilities from chat template
133        let template = gguf_metadata.metadata.get("tokenizer.chat_template");
134        let name = gguf_metadata.metadata.get("general.name");
135        let model_capabilities = crate::domain::infer_from_chat_template(
136            template.map(String::as_str),
137            name.map(String::as_str),
138        );
139
140        // 5. Construct fully-populated NewModel
141        let new_model = NewModel {
142            name: name.cloned().unwrap_or_else(|| {
143                file_path
144                    .file_stem()
145                    .and_then(|s| s.to_str())
146                    .unwrap_or("Unknown Model")
147                    .to_string()
148            }),
149            file_path: file_path.to_path_buf(),
150            param_count_b,
151            architecture: gguf_metadata.architecture,
152            quantization: gguf_metadata.quantization,
153            context_length: gguf_metadata.context_length,
154            expert_count: gguf_metadata.expert_count,
155            expert_used_count: gguf_metadata.expert_used_count,
156            expert_shared_count: gguf_metadata.expert_shared_count,
157            metadata: gguf_metadata.metadata,
158            added_at: chrono::Utc::now(),
159            hf_repo_id: None,
160            hf_commit_sha: None,
161            hf_filename: None,
162            download_date: None,
163            last_update_check: None,
164            tags: auto_tags,
165            file_paths: None,
166            capabilities: model_capabilities,
167            inference_defaults: None,
168        };
169
170        // 6. Persist to repository
171        self.repo.insert(&new_model).await.map_err(CoreError::from)
172    }
173
174    /// Update a model.
175    pub async fn update(&self, model: &Model) -> Result<(), CoreError> {
176        self.repo.update(model).await.map_err(CoreError::from)
177    }
178
179    /// Delete a model by ID.
180    pub async fn delete(&self, id: i64) -> Result<(), CoreError> {
181        self.repo.delete(id).await.map_err(CoreError::from)
182    }
183
184    /// Remove a model by identifier. Returns the removed model.
185    pub async fn remove(&self, identifier: &str) -> Result<Model, CoreError> {
186        let model = self.find_by_identifier(identifier).await?;
187        self.repo.delete(model.id).await.map_err(CoreError::from)?;
188        Ok(model)
189    }
190
191    // ─────────────────────────────────────────────────────────────────────────
192    // Tag Operations
193    // ─────────────────────────────────────────────────────────────────────────
194
195    /// List all unique tags used across all models.
196    pub async fn list_tags(&self) -> Result<Vec<String>, CoreError> {
197        let models = self.repo.list().await.map_err(CoreError::from)?;
198        let mut all_tags = std::collections::HashSet::new();
199        for model in models {
200            for tag in model.tags {
201                all_tags.insert(tag);
202            }
203        }
204        let mut tags: Vec<String> = all_tags.into_iter().collect();
205        tags.sort();
206        Ok(tags)
207    }
208
209    /// Add a tag to a model.
210    ///
211    /// If the tag already exists on the model, this is a no-op.
212    pub async fn add_tag(&self, model_id: i64, tag: String) -> Result<(), CoreError> {
213        let mut model = self
214            .repo
215            .get_by_id(model_id)
216            .await
217            .map_err(CoreError::from)?;
218        if !model.tags.contains(&tag) {
219            model.tags.push(tag);
220            model.tags.sort();
221            self.repo.update(&model).await.map_err(CoreError::from)?;
222        }
223        Ok(())
224    }
225
226    /// Remove a tag from a model.
227    ///
228    /// If the tag doesn't exist on the model, this is a no-op.
229    pub async fn remove_tag(&self, model_id: i64, tag: &str) -> Result<(), CoreError> {
230        let mut model = self
231            .repo
232            .get_by_id(model_id)
233            .await
234            .map_err(CoreError::from)?;
235        model.tags.retain(|t| t != tag);
236        self.repo.update(&model).await.map_err(CoreError::from)?;
237        Ok(())
238    }
239
240    /// Get all tags for a specific model.
241    pub async fn get_tags(&self, model_id: i64) -> Result<Vec<String>, CoreError> {
242        let model = self
243            .repo
244            .get_by_id(model_id)
245            .await
246            .map_err(CoreError::from)?;
247        Ok(model.tags)
248    }
249
250    /// Get all models that have a specific tag.
251    pub async fn get_by_tag(&self, tag: &str) -> Result<Vec<Model>, CoreError> {
252        let models = self.repo.list().await.map_err(CoreError::from)?;
253        Ok(models
254            .into_iter()
255            .filter(|m| m.tags.contains(&tag.to_string()))
256            .collect())
257    }
258
259    // ─────────────────────────────────────────────────────────────────────────
260    // Filter/Aggregate Operations
261    // ─────────────────────────────────────────────────────────────────────────
262
263    /// Get filter options aggregated from all models.
264    ///
265    /// Returns distinct quantizations, parameter count range, and context length range
266    /// for use in the GUI filter popover.
267    ///
268    /// Note: Uses in-memory aggregation for simplicity. This is acceptable for typical
269    /// model libraries (<100 models). Revisit if libraries grow large.
270    pub async fn get_filter_options(&self) -> Result<crate::domain::ModelFilterOptions, CoreError> {
271        use crate::domain::{ModelFilterOptions, RangeValues};
272        use std::collections::HashSet;
273
274        let models = self.repo.list().await.map_err(CoreError::from)?;
275
276        // Collect distinct quantizations
277        let mut quantizations: Vec<String> = models
278            .iter()
279            .filter_map(|m| m.quantization.clone())
280            .filter(|q| !q.is_empty())
281            .collect::<HashSet<_>>()
282            .into_iter()
283            .collect();
284        quantizations.sort();
285
286        // Compute param_count_b range
287        let param_range = if models.is_empty() {
288            None
289        } else {
290            let min = models
291                .iter()
292                .map(|m| m.param_count_b)
293                .fold(f64::INFINITY, f64::min);
294            let max = models
295                .iter()
296                .map(|m| m.param_count_b)
297                .fold(f64::NEG_INFINITY, f64::max);
298            if min.is_finite() && max.is_finite() {
299                Some(RangeValues { min, max })
300            } else {
301                None
302            }
303        };
304
305        // Compute context_length range (only models with context_length set)
306        let context_lengths: Vec<u64> = models.iter().filter_map(|m| m.context_length).collect();
307        #[allow(clippy::cast_precision_loss)]
308        let context_range = if context_lengths.is_empty() {
309            None
310        } else {
311            let min = *context_lengths.iter().min().unwrap() as f64;
312            let max = *context_lengths.iter().max().unwrap() as f64;
313            Some(RangeValues { min, max })
314        };
315
316        Ok(ModelFilterOptions {
317            quantizations,
318            param_range,
319            context_range,
320        })
321    }
322
323    // ─────────────────────────────────────────────────────────────────────────
324    // Capability Bootstrap
325    // ─────────────────────────────────────────────────────────────────────────
326
327    /// Backfill capabilities for models that don't have them set.
328    ///
329    /// This runs on startup to handle models with unknown capabilities.
330    /// Only infers if capabilities are empty (0/unknown).
331    ///
332    /// # INVARIANT
333    ///
334    /// Never overwrite explicitly-set capabilities. Only infer when unknown.
335    pub async fn bootstrap_capabilities(&self) -> Result<(), CoreError> {
336        use crate::domain::infer_from_chat_template;
337
338        let models = self.repo.list().await.map_err(CoreError::from)?;
339
340        for mut model in models {
341            // Only infer if capabilities are unknown (empty)
342            if model.capabilities.is_empty() {
343                let template = model.metadata.get("tokenizer.chat_template");
344                let name = model.metadata.get("general.name");
345                let inferred = infer_from_chat_template(
346                    template.map(String::as_str),
347                    name.map(String::as_str),
348                );
349
350                model.capabilities = inferred;
351                self.repo.update(&model).await.map_err(CoreError::from)?;
352            }
353        }
354
355        Ok(())
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ports::{ModelRepository, RepositoryError};
    use async_trait::async_trait;
    use chrono::Utc;

    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory `ModelRepository` test double backed by a `Mutex<Vec<Model>>`.
    struct MockRepo {
        models: Mutex<Vec<Model>>,
    }

    impl MockRepo {
        /// Create an empty mock repository.
        fn new() -> Self {
            Self {
                models: Mutex::new(vec![]),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            Ok(self.models.lock().unwrap().clone())
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.id == id)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.name == name)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
        }

        #[allow(clippy::cast_possible_wrap, clippy::significant_drop_tightening)]
        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            let mut models = self.models.lock().unwrap();
            // IDs are 1-based and assigned sequentially from the current length.
            // NOTE(review): this reuses IDs after deletes — fine for these tests,
            // but don't rely on uniqueness across a delete in new tests.
            let id = models.len() as i64 + 1;
            let created = Model {
                id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                expert_count: model.expert_count,
                expert_used_count: model.expert_used_count,
                expert_shared_count: model.expert_shared_count,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                capabilities: model.capabilities,
                inference_defaults: model.inference_defaults.clone(),
            };
            models.push(created.clone());
            Ok(created)
        }

        async fn update(&self, model: &Model) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            // Replace in place by ID; NotFound if the model was never inserted.
            models.iter_mut().find(|m| m.id == model.id).map_or_else(
                || Err(RepositoryError::NotFound(format!("id={}", model.id))),
                |m| {
                    m.clone_from(model);
                    Ok(())
                },
            )
        }

        async fn delete(&self, id: i64) -> Result<(), RepositoryError> {
            let mut models = self.models.lock().unwrap();
            // Length comparison detects whether retain actually removed anything.
            let len_before = models.len();
            models.retain(|m| m.id != id);
            if models.len() == len_before {
                Err(RepositoryError::NotFound(format!("id={id}")))
            } else {
                Ok(())
            }
        }
    }

    /// An empty repository lists no models.
    #[tokio::test]
    async fn test_list_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);
        let models = service.list().await.unwrap();
        assert!(models.is_empty());
    }

    /// Round-trip: an added model is retrievable by name with the same ID.
    #[tokio::test]
    async fn test_add_and_get() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let new_model = NewModel::new(
            "test-model".to_string(),
            PathBuf::from("/path/to/model.gguf"),
            7.0,
            Utc::now(),
        );

        let created = service.add(new_model).await.unwrap();
        assert_eq!(created.name, "test-model");

        let found = service.get_by_name("test-model").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, created.id);
    }

    /// `find_by_identifier` surfaces an error (not Ok(None)) for misses.
    #[tokio::test]
    async fn test_find_by_identifier_not_found() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let result = service.find_by_identifier("nonexistent").await;
        assert!(result.is_err());
    }

    /// With no models, every filter-option aggregate is empty/None.
    #[tokio::test]
    async fn test_get_filter_options_empty() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        let options = service.get_filter_options().await.unwrap();
        assert!(options.quantizations.is_empty());
        assert!(options.param_range.is_none());
        assert!(options.context_range.is_none());
    }

    /// Aggregates dedupe quantizations, span the full param range, and
    /// ignore models without a context_length when computing that range.
    #[tokio::test]
    async fn test_get_filter_options_with_models() {
        let repo = Arc::new(MockRepo::new());
        let service = ModelService::new(repo);

        // Add models with different characteristics
        let mut model1 = NewModel::new(
            "model-1".to_string(),
            PathBuf::from("/path/to/model1.gguf"),
            7.0,
            Utc::now(),
        );
        model1.quantization = Some("Q4_K_M".to_string());
        model1.context_length = Some(4096);

        let mut model2 = NewModel::new(
            "model-2".to_string(),
            PathBuf::from("/path/to/model2.gguf"),
            13.0,
            Utc::now(),
        );
        model2.quantization = Some("Q8_0".to_string());
        model2.context_length = Some(8192);

        let mut model3 = NewModel::new(
            "model-3".to_string(),
            PathBuf::from("/path/to/model3.gguf"),
            70.0,
            Utc::now(),
        );
        model3.quantization = Some("Q4_K_M".to_string()); // Duplicate quant
        // No context_length set

        service.add(model1).await.unwrap();
        service.add(model2).await.unwrap();
        service.add(model3).await.unwrap();

        let options = service.get_filter_options().await.unwrap();

        // Should have 2 distinct quantizations, sorted
        assert_eq!(options.quantizations, vec!["Q4_K_M", "Q8_0"]);

        // Param range: 7.0 to 70.0
        let param_range = options.param_range.unwrap();
        assert!((param_range.min - 7.0).abs() < 0.001);
        assert!((param_range.max - 70.0).abs() < 0.001);

        // Context range: 4096 to 8192 (model3 has no context)
        let context_range = options.context_range.unwrap();
        assert!((context_range.min - 4096.0).abs() < 0.001);
        assert!((context_range.max - 8192.0).abs() < 0.001);
    }
}