gglib_core/services/
model_registrar.rs

1//! Model registrar service implementation.
2//!
3//! This service implements `ModelRegistrarPort` using the `ModelRepository`
4//! and `GgufParserPort` dependencies. It's used by the download manager
5//! to register completed downloads.
6
7use std::path::Path;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12
13use crate::domain::{Model, NewModel};
14use crate::download::Quantization;
15use crate::ports::{
16    CompletedDownload, GgufParserPort, ModelRegistrarPort, ModelRepository, RepositoryError,
17};
18
19/// Implementation of the model registrar port.
20///
21/// This service composes over `ModelRepository` for persistence and
22/// `GgufParserPort` for metadata extraction.
23pub struct ModelRegistrar {
24    /// Repository for persisting models.
25    model_repo: Arc<dyn ModelRepository>,
26    /// Parser for extracting GGUF metadata.
27    gguf_parser: Arc<dyn GgufParserPort>,
28}
29
30impl ModelRegistrar {
31    /// Create a new model registrar.
32    ///
33    /// # Arguments
34    ///
35    /// * `model_repo` - Repository for persisting models
36    /// * `gguf_parser` - Parser for extracting GGUF metadata
37    pub fn new(model_repo: Arc<dyn ModelRepository>, gguf_parser: Arc<dyn GgufParserPort>) -> Self {
38        Self {
39            model_repo,
40            gguf_parser,
41        }
42    }
43}
44
45#[async_trait]
46impl ModelRegistrarPort for ModelRegistrar {
47    async fn register_model(&self, download: &CompletedDownload) -> Result<Model, RepositoryError> {
48        let file_path = download.db_path();
49
50        // Parse GGUF metadata from the downloaded file
51        let gguf_metadata = self.gguf_parser.parse(file_path).ok();
52
53        // Extract param_count_b from metadata, fall back to 0.0
54        let param_count_b = gguf_metadata
55            .as_ref()
56            .and_then(|m| m.param_count_b)
57            .unwrap_or(0.0);
58
59        let mut model = NewModel::new(
60            download.repo_id.clone(),
61            file_path.to_path_buf(),
62            param_count_b,
63            Utc::now(),
64        );
65
66        // Use extracted metadata where available, with fallbacks
67        model.quantization = gguf_metadata
68            .as_ref()
69            .and_then(|m| m.quantization.clone())
70            .or_else(|| Some(download.quantization.to_string()));
71        model.architecture = gguf_metadata.as_ref().and_then(|m| m.architecture.clone());
72        model.context_length = gguf_metadata.as_ref().and_then(|m| m.context_length);
73        if let Some(ref meta) = gguf_metadata {
74            model.metadata.clone_from(&meta.metadata);
75        }
76        model.hf_repo_id = Some(download.repo_id.clone());
77        model.hf_commit_sha = Some(download.commit_sha.clone());
78        model.hf_filename = Some(file_path.file_name().unwrap().to_string_lossy().to_string());
79        model.download_date = Some(Utc::now());
80
81        // Pass through file_paths for sharded models
82        model.file_paths.clone_from(&download.file_paths);
83
84        // Auto-detect capabilities from metadata
85        if let Some(ref meta) = gguf_metadata {
86            let capabilities = self.gguf_parser.detect_capabilities(meta);
87            model.tags = capabilities.to_tags();
88        }
89
90        // Infer model capabilities from chat template
91        let template = model.metadata.get("tokenizer.chat_template");
92        let name = model.metadata.get("general.name");
93        model.capabilities = crate::domain::infer_from_chat_template(
94            template.map(String::as_str),
95            name.map(String::as_str),
96        );
97
98        let registered = self.model_repo.insert(&model).await?;
99
100        Ok(registered)
101    }
102
103    async fn register_model_from_path(
104        &self,
105        repo_id: &str,
106        commit_sha: &str,
107        file_path: &Path,
108        quantization: &str,
109    ) -> Result<Model, RepositoryError> {
110        let download = CompletedDownload {
111            primary_path: file_path.to_path_buf(),
112            all_paths: vec![file_path.to_path_buf()],
113            quantization: Quantization::from_filename(quantization),
114            repo_id: repo_id.to_string(),
115            commit_sha: commit_sha.to_string(),
116            is_sharded: false,
117            total_bytes: 0,
118            file_paths: None,
119        };
120
121        self.register_model(&download).await
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Model;
    use crate::ports::NoopGgufParser;
    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory `ModelRepository` used by the tests below.
    ///
    /// Inserted models are kept in a vector and given sequential ids
    /// starting at 1; `update` and `delete` do nothing.
    struct MockModelRepo {
        models: Mutex<Vec<Model>>,
        next_id: Mutex<i64>,
    }

    impl MockModelRepo {
        /// Creates an empty repository whose first assigned id is 1.
        fn new() -> Self {
            Self {
                models: Mutex::default(),
                next_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockModelRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            let stored = self.models.lock().unwrap();
            Ok(stored.clone())
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            let stored = self.models.lock().unwrap();
            match stored.iter().find(|candidate| candidate.id == id) {
                Some(found) => Ok(found.clone()),
                None => Err(RepositoryError::NotFound(format!("id={id}"))),
            }
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            let stored = self.models.lock().unwrap();
            match stored.iter().find(|candidate| candidate.name == name) {
                Some(found) => Ok(found.clone()),
                None => Err(RepositoryError::NotFound(format!("name={name}"))),
            }
        }

        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            // Reserve the next id, releasing the counter lock before the
            // model list is touched.
            let assigned_id = {
                let mut counter = self.next_id.lock().unwrap();
                let current = *counter;
                *counter += 1;
                current
            };

            let persisted = Model {
                id: assigned_id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                capabilities: model.capabilities,
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                inference_defaults: model.inference_defaults.clone(),
            };

            self.models.lock().unwrap().push(persisted.clone());
            Ok(persisted)
        }

        async fn update(&self, _model: &Model) -> Result<(), RepositoryError> {
            Ok(())
        }

        async fn delete(&self, _id: i64) -> Result<(), RepositoryError> {
            Ok(())
        }
    }

    /// Builds a registrar backed by a fresh mock repository and the no-op parser.
    fn registrar_under_test() -> ModelRegistrar {
        ModelRegistrar::new(Arc::new(MockModelRepo::new()), Arc::new(NoopGgufParser))
    }

    #[tokio::test]
    async fn test_register_model_basic() {
        let registrar = registrar_under_test();

        let model_file = PathBuf::from("/models/test-model-q4_k_m.gguf");
        let download = CompletedDownload {
            primary_path: model_file.clone(),
            all_paths: vec![model_file],
            quantization: Quantization::Q4KM,
            repo_id: "test/model".to_string(),
            commit_sha: "abc123".to_string(),
            is_sharded: false,
            total_bytes: 1024,
            file_paths: None,
        };

        let model = registrar
            .register_model(&download)
            .await
            .expect("registering a single-file download should succeed");

        assert_eq!(model.name, "test/model");
        assert_eq!(model.hf_repo_id, Some("test/model".to_string()));
        assert_eq!(model.hf_commit_sha, Some("abc123".to_string()));
        assert_eq!(model.quantization, Some("Q4_K_M".to_string()));
    }

    #[tokio::test]
    async fn test_register_sharded_model() {
        let registrar = registrar_under_test();

        // /models/llama-00001-of-00004.gguf ... /models/llama-00004-of-00004.gguf
        let shard_paths: Vec<PathBuf> = (1..=4)
            .map(|shard| PathBuf::from(format!("/models/llama-{shard:05}-of-00004.gguf")))
            .collect();

        let download = CompletedDownload {
            primary_path: shard_paths[0].clone(),
            all_paths: shard_paths,
            quantization: Quantization::Q8_0,
            repo_id: "test/llama".to_string(),
            commit_sha: "def456".to_string(),
            is_sharded: true,
            total_bytes: 4096,
            file_paths: None,
        };

        let model = registrar
            .register_model(&download)
            .await
            .expect("registering a sharded download should succeed");

        assert_eq!(model.quantization, Some("Q8_0".to_string()));
    }

    #[tokio::test]
    async fn test_register_model_from_path() {
        let registrar = registrar_under_test();

        let model = registrar
            .register_model_from_path(
                "test/repo",
                "commit123",
                Path::new("/models/test-q4_0.gguf"),
                "Q4_0",
            )
            .await
            .expect("registering from a path should succeed");

        assert_eq!(model.name, "test/repo");
    }
}