gglib_core/services/
model_registrar.rs

//! Model registrar service implementation.
//!
//! This service implements `ModelRegistrarPort` using the `ModelRepository`
//! and `GgufParserPort` dependencies. It is used by the download manager
//! to register completed downloads.
6
7use std::path::Path;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12
13use crate::domain::{Model, NewModel, NewModelFile};
14use crate::download::Quantization;
15use crate::ports::{
16    CompletedDownload, GgufParserPort, ModelRegistrarPort, ModelRepository, RepositoryError,
17};
18
19/// Repository trait for model files metadata.
20///
21/// We don't depend on `gglib_db` directly - adapters inject the implementation.
22/// This type is re-exported from `gglib_db` for use in adapters.
23#[async_trait]
24pub trait ModelFilesRepositoryPort: Send + Sync {
25    /// Insert a new model file record.
26    async fn insert(&self, model_file: &NewModelFile) -> anyhow::Result<()>;
27}
28
29/// Implementation of the model registrar port.
30///
31/// This service composes over `ModelRepository` for persistence and
32/// `GgufParserPort` for metadata extraction.
33pub struct ModelRegistrar {
34    /// Repository for persisting models.
35    model_repo: Arc<dyn ModelRepository>,
36    /// Parser for extracting GGUF metadata.
37    gguf_parser: Arc<dyn GgufParserPort>,
38    /// Repository for persisting model file metadata.
39    model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
40}
41
42impl ModelRegistrar {
43    /// Create a new model registrar.
44    ///
45    /// # Arguments
46    ///
47    /// * `model_repo` - Repository for persisting models
48    /// * `gguf_parser` - Parser for extracting GGUF metadata
49    /// * `model_files_repo` - Optional repository for persisting model file metadata
50    pub fn new(
51        model_repo: Arc<dyn ModelRepository>,
52        gguf_parser: Arc<dyn GgufParserPort>,
53        model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
54    ) -> Self {
55        Self {
56            model_repo,
57            gguf_parser,
58            model_files_repo,
59        }
60    }
61
62    /// Filter `HuggingFace` tags using a blocklist.
63    ///
64    /// Removes noisy tags like `gguf`, `arxiv:*`, `region:*`, `license:*`, `dataset:*`.
65    fn filter_hf_tags(tags: &[String]) -> Vec<String> {
66        tags.iter()
67            .filter(|tag| {
68                let tag_lower = tag.to_lowercase();
69                !tag_lower.starts_with("arxiv:")
70                    && !tag_lower.starts_with("region:")
71                    && !tag_lower.starts_with("license:")
72                    && !tag_lower.starts_with("dataset:")
73                    && tag_lower != "gguf"
74            })
75            .cloned()
76            .collect()
77    }
78
79    /// Merge GGUF-derived tags with filtered HF tags, removing duplicates.
80    ///
81    /// GGUF-derived tags are prioritized (appear first in the result).
82    fn merge_tags(gguf_tags: Vec<String>, hf_tags: &[String]) -> Vec<String> {
83        use std::collections::HashSet;
84
85        let mut seen = HashSet::new();
86        let mut result = Vec::new();
87
88        // Add GGUF tags first
89        for tag in gguf_tags {
90            if seen.insert(tag.clone()) {
91                result.push(tag);
92            }
93        }
94
95        // Add filtered HF tags
96        for tag in Self::filter_hf_tags(hf_tags) {
97            if seen.insert(tag.clone()) {
98                result.push(tag);
99            }
100        }
101
102        result
103    }
104}
105
106#[async_trait]
107impl ModelRegistrarPort for ModelRegistrar {
108    async fn register_model(&self, download: &CompletedDownload) -> Result<Model, RepositoryError> {
109        let file_path = download.db_path();
110
111        // Parse GGUF metadata from the downloaded file
112        let gguf_metadata = self.gguf_parser.parse(file_path).ok();
113
114        // Extract param_count_b from metadata, fall back to 0.0
115        let param_count_b = gguf_metadata
116            .as_ref()
117            .and_then(|m| m.param_count_b)
118            .unwrap_or(0.0);
119
120        let mut model = NewModel::new(
121            download.repo_id.clone(),
122            file_path.to_path_buf(),
123            param_count_b,
124            Utc::now(),
125        );
126
127        // Use extracted metadata where available, with fallbacks
128        model.quantization = gguf_metadata
129            .as_ref()
130            .and_then(|m| m.quantization.clone())
131            .or_else(|| Some(download.quantization.to_string()));
132        model.architecture = gguf_metadata.as_ref().and_then(|m| m.architecture.clone());
133        model.context_length = gguf_metadata.as_ref().and_then(|m| m.context_length);
134        model.expert_count = gguf_metadata.as_ref().and_then(|m| m.expert_count);
135        model.expert_used_count = gguf_metadata.as_ref().and_then(|m| m.expert_used_count);
136        model.expert_shared_count = gguf_metadata.as_ref().and_then(|m| m.expert_shared_count);
137        if let Some(ref meta) = gguf_metadata {
138            model.metadata.clone_from(&meta.metadata);
139        }
140        model.hf_repo_id = Some(download.repo_id.clone());
141        model.hf_commit_sha = Some(download.commit_sha.clone());
142        model.hf_filename = Some(file_path.file_name().unwrap().to_string_lossy().to_string());
143        model.download_date = Some(Utc::now());
144
145        // Pass through file_paths for sharded models
146        model.file_paths.clone_from(&download.file_paths);
147
148        // Auto-detect capabilities from metadata and merge with HF tags
149        let gguf_tags = gguf_metadata.as_ref().map_or_else(Vec::new, |meta| {
150            let capabilities = self.gguf_parser.detect_capabilities(meta);
151            capabilities.to_tags()
152        });
153
154        // Merge GGUF-derived tags with filtered HF tags (deduplicated)
155        model.tags = Self::merge_tags(gguf_tags, &download.hf_tags);
156
157        // Infer model capabilities from chat template
158        let template = model.metadata.get("tokenizer.chat_template");
159        let name = model.metadata.get("general.name");
160        model.capabilities = crate::domain::infer_from_chat_template(
161            template.map(String::as_str),
162            name.map(String::as_str),
163        );
164
165        let registered = self.model_repo.insert(&model).await?;
166
167        // Insert model_files records with OIDs for each shard (if repo is available)
168        if let Some(ref repo) = self.model_files_repo {
169            for (file_index, file_entry) in download.hf_file_entries.iter().enumerate() {
170                if let Some(size) = file_entry.size {
171                    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
172                    let model_file = NewModelFile::new(
173                        registered.id,
174                        file_entry.path.clone(),
175                        file_index as i32,
176                        size as i64,
177                        file_entry.oid.clone(),
178                    );
179
180                    if let Err(e) = repo.insert(&model_file).await {
181                        // Soft fail - log but don't propagate error
182                        tracing::warn!(
183                            model_id = registered.id,
184                            file_path = %file_entry.path,
185                            error = %e,
186                            "Failed to insert model_files record - verification features may be unavailable"
187                        );
188                    }
189                }
190            }
191        }
192
193        Ok(registered)
194    }
195
196    async fn register_model_from_path(
197        &self,
198        repo_id: &str,
199        commit_sha: &str,
200        file_path: &Path,
201        quantization: &str,
202    ) -> Result<Model, RepositoryError> {
203        let download = CompletedDownload {
204            primary_path: file_path.to_path_buf(),
205            all_paths: vec![file_path.to_path_buf()],
206            quantization: Quantization::from_filename(quantization),
207            repo_id: repo_id.to_string(),
208            commit_sha: commit_sha.to_string(),
209            is_sharded: false,
210            total_bytes: 0,
211            file_paths: None,
212            hf_tags: vec![],
213            hf_file_entries: vec![],
214        };
215
216        self.register_model(&download).await
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Model;
    use crate::ports::NoopGgufParser;
    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory `ModelRepository` used as a test double.
    struct MockModelRepo {
        /// Every model inserted so far, in insertion order.
        models: Mutex<Vec<Model>>,
        /// Id handed to the next inserted model.
        next_id: Mutex<i64>,
    }

    impl MockModelRepo {
        /// Start empty, with ids beginning at 1.
        fn new() -> Self {
            Self {
                models: Mutex::new(Vec::new()),
                next_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockModelRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            let snapshot = self.models.lock().unwrap().clone();
            Ok(snapshot)
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            let found = self
                .models
                .lock()
                .unwrap()
                .iter()
                .find(|candidate| candidate.id == id)
                .cloned();
            found.ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            let found = self
                .models
                .lock()
                .unwrap()
                .iter()
                .find(|candidate| candidate.name == name)
                .cloned();
            found.ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
        }

        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            // Reserve an id and release the counter lock before touching the list.
            let assigned_id = {
                let mut counter = self.next_id.lock().unwrap();
                let current = *counter;
                *counter += 1;
                current
            };

            let persisted = Model {
                id: assigned_id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                expert_count: model.expert_count,
                expert_used_count: model.expert_used_count,
                expert_shared_count: model.expert_shared_count,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                capabilities: model.capabilities,
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                inference_defaults: model.inference_defaults.clone(),
            };

            self.models.lock().unwrap().push(persisted.clone());
            Ok(persisted)
        }

        async fn update(&self, _model: &Model) -> Result<(), RepositoryError> {
            Ok(())
        }

        async fn delete(&self, _id: i64) -> Result<(), RepositoryError> {
            Ok(())
        }
    }

    /// Registrar backed by the mock repository and the no-op GGUF parser,
    /// without a model files repository.
    fn new_registrar() -> ModelRegistrar {
        let repo: Arc<dyn ModelRepository> = Arc::new(MockModelRepo::new());
        let parser: Arc<dyn GgufParserPort> = Arc::new(NoopGgufParser);
        ModelRegistrar::new(repo, parser, None)
    }

    #[tokio::test]
    async fn test_register_model_basic() {
        let registrar = new_registrar();

        let gguf_path = PathBuf::from("/models/test-model-q4_k_m.gguf");
        let download = CompletedDownload {
            primary_path: gguf_path.clone(),
            all_paths: vec![gguf_path],
            quantization: Quantization::Q4KM,
            repo_id: "test/model".to_string(),
            commit_sha: "abc123".to_string(),
            is_sharded: false,
            total_bytes: 1024,
            file_paths: None,
            hf_tags: Vec::new(),
            hf_file_entries: Vec::new(),
        };

        let model = registrar
            .register_model(&download)
            .await
            .expect("registration should succeed");

        assert_eq!(model.name, "test/model");
        assert_eq!(model.hf_repo_id.as_deref(), Some("test/model"));
        assert_eq!(model.hf_commit_sha.as_deref(), Some("abc123"));
        assert_eq!(model.quantization.as_deref(), Some("Q4_K_M"));
    }

    #[tokio::test]
    async fn test_register_sharded_model() {
        let registrar = new_registrar();

        // Four shards named llama-0000N-of-00004.gguf; the first is the primary.
        let shard_paths: Vec<PathBuf> = (1..=4)
            .map(|shard| PathBuf::from(format!("/models/llama-{shard:05}-of-00004.gguf")))
            .collect();
        let download = CompletedDownload {
            primary_path: shard_paths[0].clone(),
            all_paths: shard_paths,
            quantization: Quantization::Q8_0,
            repo_id: "test/llama".to_string(),
            commit_sha: "def456".to_string(),
            is_sharded: true,
            total_bytes: 4096,
            file_paths: None,
            hf_tags: Vec::new(),
            hf_file_entries: Vec::new(),
        };

        let model = registrar
            .register_model(&download)
            .await
            .expect("registration should succeed");

        assert_eq!(model.quantization.as_deref(), Some("Q8_0"));
    }

    #[tokio::test]
    async fn test_register_model_from_path() {
        let registrar = new_registrar();

        let model = registrar
            .register_model_from_path(
                "test/repo",
                "commit123",
                Path::new("/models/test-q4_0.gguf"),
                "Q4_0",
            )
            .await
            .expect("registration should succeed");

        assert_eq!(model.name, "test/repo");
    }
}