gglib_core/services/
model_registrar.rs

1//! Model registrar service implementation.
2//!
3//! This service implements `ModelRegistrarPort` using the `ModelRepository`
4//! and `GgufParserPort` dependencies. It's used by the download manager
5//! to register completed downloads.
6
7use std::path::Path;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12
13use crate::domain::{Model, NewModel, NewModelFile};
14use crate::download::Quantization;
15use crate::ports::{
16    CompletedDownload, GgufParserPort, ModelRegistrarPort, ModelRepository, RepositoryError,
17};
18
/// Repository trait for model files metadata.
///
/// We don't depend on `gglib_db` directly - adapters inject the implementation.
/// This type is re-exported from `gglib_db` for use in adapters.
///
/// Implementations must be `Send + Sync` because `ModelRegistrar` holds them
/// behind an `Arc<dyn ModelFilesRepositoryPort>`.
#[async_trait]
pub trait ModelFilesRepositoryPort: Send + Sync {
    /// Insert a new model file record.
    ///
    /// # Errors
    ///
    /// Returns an error when the record could not be stored. Note that
    /// `ModelRegistrar::register_model` treats such a failure as non-fatal:
    /// it logs a warning and continues with the remaining files.
    async fn insert(&self, model_file: &NewModelFile) -> anyhow::Result<()>;
}
28
/// Implementation of the model registrar port.
///
/// This service composes over `ModelRepository` for persistence and
/// `GgufParserPort` for metadata extraction. An optional
/// `ModelFilesRepositoryPort` additionally records one entry per downloaded
/// file after the model itself has been inserted.
pub struct ModelRegistrar {
    /// Repository for persisting models.
    model_repo: Arc<dyn ModelRepository>,
    /// Parser for extracting GGUF metadata.
    gguf_parser: Arc<dyn GgufParserPort>,
    /// Repository for persisting model file metadata.
    ///
    /// When this is `None`, registration skips writing model file records.
    model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
}
41
42impl ModelRegistrar {
43    /// Create a new model registrar.
44    ///
45    /// # Arguments
46    ///
47    /// * `model_repo` - Repository for persisting models
48    /// * `gguf_parser` - Parser for extracting GGUF metadata
49    /// * `model_files_repo` - Optional repository for persisting model file metadata
50    pub fn new(
51        model_repo: Arc<dyn ModelRepository>,
52        gguf_parser: Arc<dyn GgufParserPort>,
53        model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
54    ) -> Self {
55        Self {
56            model_repo,
57            gguf_parser,
58            model_files_repo,
59        }
60    }
61
62    /// Filter `HuggingFace` tags using a blocklist.
63    ///
64    /// Removes noisy tags like `gguf`, `arxiv:*`, `region:*`, `license:*`, `dataset:*`.
65    fn filter_hf_tags(tags: &[String]) -> Vec<String> {
66        tags.iter()
67            .filter(|tag| {
68                let tag_lower = tag.to_lowercase();
69                !tag_lower.starts_with("arxiv:")
70                    && !tag_lower.starts_with("region:")
71                    && !tag_lower.starts_with("license:")
72                    && !tag_lower.starts_with("dataset:")
73                    && tag_lower != "gguf"
74            })
75            .cloned()
76            .collect()
77    }
78
79    /// Merge GGUF-derived tags with filtered HF tags, removing duplicates.
80    ///
81    /// GGUF-derived tags are prioritized (appear first in the result).
82    fn merge_tags(gguf_tags: Vec<String>, hf_tags: &[String]) -> Vec<String> {
83        use std::collections::HashSet;
84
85        let mut seen = HashSet::new();
86        let mut result = Vec::new();
87
88        // Add GGUF tags first
89        for tag in gguf_tags {
90            if seen.insert(tag.clone()) {
91                result.push(tag);
92            }
93        }
94
95        // Add filtered HF tags
96        for tag in Self::filter_hf_tags(hf_tags) {
97            if seen.insert(tag.clone()) {
98                result.push(tag);
99            }
100        }
101
102        result
103    }
104}
105
106#[async_trait]
107impl ModelRegistrarPort for ModelRegistrar {
108    async fn register_model(&self, download: &CompletedDownload) -> Result<Model, RepositoryError> {
109        let file_path = download.db_path();
110
111        // Parse GGUF metadata from the downloaded file
112        let gguf_metadata = self.gguf_parser.parse(file_path).ok();
113
114        // Extract param_count_b from metadata, fall back to 0.0
115        let param_count_b = gguf_metadata
116            .as_ref()
117            .and_then(|m| m.param_count_b)
118            .unwrap_or(0.0);
119
120        let mut model = NewModel::new(
121            download.repo_id.clone(),
122            file_path.to_path_buf(),
123            param_count_b,
124            Utc::now(),
125        );
126
127        // Use extracted metadata where available, with fallbacks
128        model.quantization = gguf_metadata
129            .as_ref()
130            .and_then(|m| m.quantization.clone())
131            .or_else(|| Some(download.quantization.to_string()));
132        model.architecture = gguf_metadata.as_ref().and_then(|m| m.architecture.clone());
133        model.context_length = gguf_metadata.as_ref().and_then(|m| m.context_length);
134        model.expert_count = gguf_metadata.as_ref().and_then(|m| m.expert_count);
135        model.expert_used_count = gguf_metadata.as_ref().and_then(|m| m.expert_used_count);
136        model.expert_shared_count = gguf_metadata.as_ref().and_then(|m| m.expert_shared_count);
137        if let Some(ref meta) = gguf_metadata {
138            model.metadata.clone_from(&meta.metadata);
139        }
140        model.hf_repo_id = Some(download.repo_id.clone());
141        model.hf_commit_sha = Some(download.commit_sha.clone());
142        model.hf_filename = Some(file_path.file_name().unwrap().to_string_lossy().to_string());
143        model.download_date = Some(Utc::now());
144
145        // Pass through file_paths for sharded models
146        model.file_paths.clone_from(&download.file_paths);
147
148        // Auto-detect capabilities from metadata and merge with HF tags
149        let gguf_tags = gguf_metadata.as_ref().map_or_else(Vec::new, |meta| {
150            let capabilities = self.gguf_parser.detect_capabilities(meta);
151            capabilities.to_tags()
152        });
153
154        // Merge GGUF-derived tags with filtered HF tags (deduplicated)
155        model.tags = Self::merge_tags(gguf_tags, &download.hf_tags);
156
157        // Apply tag-based inference defaults for reasoning models.
158        // Only set when the model has no explicit defaults already — ensures
159        // that user-curated defaults are never clobbered by re-registration.
160        if model.inference_defaults.is_none()
161            && model
162                .tags
163                .iter()
164                .any(|t| t.eq_ignore_ascii_case("reasoning"))
165        {
166            model.inference_defaults = Some(crate::domain::InferenceConfig::reasoning_profile());
167        }
168
169        // Infer capabilities from chat template OR architecture — OR'd so
170        // either signal is sufficient.  Architecture is the backstop for models
171        // whose GGUF ships without a tokenizer section.
172        let template = model.metadata.get("tokenizer.chat_template");
173        let name = model.metadata.get("general.name");
174        let arch = model.metadata.get("general.architecture");
175        let from_template = crate::domain::infer_from_chat_template(
176            template.map(String::as_str),
177            name.map(String::as_str),
178        );
179        let from_arch = crate::domain::capabilities_from_architecture(arch.map(String::as_str));
180        model.capabilities = from_template | from_arch;
181
182        let registered = self.model_repo.insert(&model).await?;
183
184        // Insert model_files records with OIDs for each shard (if repo is available)
185        if let Some(ref repo) = self.model_files_repo {
186            for (file_index, file_entry) in download.hf_file_entries.iter().enumerate() {
187                if let Some(size) = file_entry.size {
188                    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
189                    let model_file = NewModelFile::new(
190                        registered.id,
191                        file_entry.path.clone(),
192                        file_index as i32,
193                        size as i64,
194                        file_entry.oid.clone(),
195                    );
196
197                    if let Err(e) = repo.insert(&model_file).await {
198                        // Soft fail - log but don't propagate error
199                        tracing::warn!(
200                            model_id = registered.id,
201                            file_path = %file_entry.path,
202                            error = %e,
203                            "Failed to insert model_files record - verification features may be unavailable"
204                        );
205                    }
206                }
207            }
208        }
209
210        Ok(registered)
211    }
212
213    async fn register_model_from_path(
214        &self,
215        repo_id: &str,
216        commit_sha: &str,
217        file_path: &Path,
218        quantization: &str,
219    ) -> Result<Model, RepositoryError> {
220        let download = CompletedDownload {
221            primary_path: file_path.to_path_buf(),
222            all_paths: vec![file_path.to_path_buf()],
223            quantization: Quantization::from_filename(quantization),
224            repo_id: repo_id.to_string(),
225            commit_sha: commit_sha.to_string(),
226            is_sharded: false,
227            total_bytes: 0,
228            file_paths: None,
229            hf_tags: vec![],
230            hf_file_entries: vec![],
231        };
232
233        self.register_model(&download).await
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Model;
    use crate::ports::NoopGgufParser;
    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory `ModelRepository` that hands out sequential ids starting at 1.
    struct MockModelRepo {
        models: Mutex<Vec<Model>>,
        next_id: Mutex<i64>,
    }

    impl MockModelRepo {
        fn new() -> Self {
            Self {
                models: Mutex::new(Vec::new()),
                next_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockModelRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            let snapshot = self.models.lock().unwrap().clone();
            Ok(snapshot)
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            let stored = self.models.lock().unwrap();
            match stored.iter().find(|m| m.id == id) {
                Some(found) => Ok(found.clone()),
                None => Err(RepositoryError::NotFound(format!("id={id}"))),
            }
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            let stored = self.models.lock().unwrap();
            match stored.iter().find(|m| m.name == name) {
                Some(found) => Ok(found.clone()),
                None => Err(RepositoryError::NotFound(format!("name={name}"))),
            }
        }

        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            // Reserve the id in its own scope so the counter lock is released
            // before the model list is locked.
            let assigned_id = {
                let mut counter = self.next_id.lock().unwrap();
                let current = *counter;
                *counter += 1;
                current
            };

            let persisted = Model {
                id: assigned_id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                expert_count: model.expert_count,
                expert_used_count: model.expert_used_count,
                expert_shared_count: model.expert_shared_count,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                capabilities: model.capabilities,
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                inference_defaults: model.inference_defaults.clone(),
            };

            self.models.lock().unwrap().push(persisted.clone());
            Ok(persisted)
        }

        async fn update(&self, _model: &Model) -> Result<(), RepositoryError> {
            Ok(())
        }

        async fn delete(&self, _id: i64) -> Result<(), RepositoryError> {
            Ok(())
        }
    }

    /// Registrar wired to the mock repository and the no-op parser, with no
    /// model files repository.
    fn make_registrar() -> ModelRegistrar {
        ModelRegistrar::new(
            Arc::new(MockModelRepo::new()),
            Arc::new(NoopGgufParser),
            None,
        )
    }

    #[tokio::test]
    async fn test_register_model_basic() {
        let registrar = make_registrar();

        let gguf_path = PathBuf::from("/models/test-model-q4_k_m.gguf");
        let download = CompletedDownload {
            primary_path: gguf_path.clone(),
            all_paths: vec![gguf_path],
            quantization: Quantization::Q4KM,
            repo_id: "test/model".to_string(),
            commit_sha: "abc123".to_string(),
            is_sharded: false,
            total_bytes: 1024,
            file_paths: None,
            hf_tags: vec![],
            hf_file_entries: vec![],
        };

        let outcome = registrar.register_model(&download).await;
        assert!(outcome.is_ok());

        let model = outcome.unwrap();
        assert_eq!(model.name, "test/model");
        assert_eq!(model.hf_repo_id.as_deref(), Some("test/model"));
        assert_eq!(model.hf_commit_sha.as_deref(), Some("abc123"));
        assert_eq!(model.quantization.as_deref(), Some("Q4_K_M"));
    }

    #[tokio::test]
    async fn test_register_sharded_model() {
        let registrar = make_registrar();

        // Yields /models/llama-00001-of-00004.gguf through ...-00004-of-00004.gguf
        let shard_paths: Vec<PathBuf> = (1..=4)
            .map(|shard| PathBuf::from(format!("/models/llama-{shard:05}-of-00004.gguf")))
            .collect();

        let download = CompletedDownload {
            primary_path: shard_paths[0].clone(),
            all_paths: shard_paths,
            quantization: Quantization::Q8_0,
            repo_id: "test/llama".to_string(),
            commit_sha: "def456".to_string(),
            is_sharded: true,
            total_bytes: 4096,
            file_paths: None,
            hf_tags: vec![],
            hf_file_entries: vec![],
        };

        let outcome = registrar.register_model(&download).await;
        assert!(outcome.is_ok());

        let model = outcome.unwrap();
        assert_eq!(model.quantization.as_deref(), Some("Q8_0"));
    }

    #[tokio::test]
    async fn test_register_model_from_path() {
        let registrar = make_registrar();

        let local_file = Path::new("/models/test-q4_0.gguf");
        let outcome = registrar
            .register_model_from_path("test/repo", "commit123", local_file, "Q4_0")
            .await;

        assert!(outcome.is_ok());
        let model = outcome.unwrap();
        assert_eq!(model.name, "test/repo");
    }
}