1use std::path::Path;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12
13use crate::domain::{Model, NewModel};
14use crate::download::Quantization;
15use crate::ports::{
16 CompletedDownload, GgufParserPort, ModelRegistrarPort, ModelRepository, RepositoryError,
17};
18
19pub struct ModelRegistrar {
24 model_repo: Arc<dyn ModelRepository>,
26 gguf_parser: Arc<dyn GgufParserPort>,
28}
29
30impl ModelRegistrar {
31 pub fn new(model_repo: Arc<dyn ModelRepository>, gguf_parser: Arc<dyn GgufParserPort>) -> Self {
38 Self {
39 model_repo,
40 gguf_parser,
41 }
42 }
43}
44
45#[async_trait]
46impl ModelRegistrarPort for ModelRegistrar {
47 async fn register_model(&self, download: &CompletedDownload) -> Result<Model, RepositoryError> {
48 let file_path = download.db_path();
49
50 let gguf_metadata = self.gguf_parser.parse(file_path).ok();
52
53 let param_count_b = gguf_metadata
55 .as_ref()
56 .and_then(|m| m.param_count_b)
57 .unwrap_or(0.0);
58
59 let mut model = NewModel::new(
60 download.repo_id.clone(),
61 file_path.to_path_buf(),
62 param_count_b,
63 Utc::now(),
64 );
65
66 model.quantization = gguf_metadata
68 .as_ref()
69 .and_then(|m| m.quantization.clone())
70 .or_else(|| Some(download.quantization.to_string()));
71 model.architecture = gguf_metadata.as_ref().and_then(|m| m.architecture.clone());
72 model.context_length = gguf_metadata.as_ref().and_then(|m| m.context_length);
73 if let Some(ref meta) = gguf_metadata {
74 model.metadata.clone_from(&meta.metadata);
75 }
76 model.hf_repo_id = Some(download.repo_id.clone());
77 model.hf_commit_sha = Some(download.commit_sha.clone());
78 model.hf_filename = Some(file_path.file_name().unwrap().to_string_lossy().to_string());
79 model.download_date = Some(Utc::now());
80
81 model.file_paths.clone_from(&download.file_paths);
83
84 if let Some(ref meta) = gguf_metadata {
86 let capabilities = self.gguf_parser.detect_capabilities(meta);
87 model.tags = capabilities.to_tags();
88 }
89
90 let template = model.metadata.get("tokenizer.chat_template");
92 let name = model.metadata.get("general.name");
93 model.capabilities = crate::domain::infer_from_chat_template(
94 template.map(String::as_str),
95 name.map(String::as_str),
96 );
97
98 let registered = self.model_repo.insert(&model).await?;
99
100 Ok(registered)
101 }
102
103 async fn register_model_from_path(
104 &self,
105 repo_id: &str,
106 commit_sha: &str,
107 file_path: &Path,
108 quantization: &str,
109 ) -> Result<Model, RepositoryError> {
110 let download = CompletedDownload {
111 primary_path: file_path.to_path_buf(),
112 all_paths: vec![file_path.to_path_buf()],
113 quantization: Quantization::from_filename(quantization),
114 repo_id: repo_id.to_string(),
115 commit_sha: commit_sha.to_string(),
116 is_sharded: false,
117 total_bytes: 0,
118 file_paths: None,
119 };
120
121 self.register_model(&download).await
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Model;
    use crate::ports::NoopGgufParser;
    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory repository that assigns sequential ids (starting at 1) and records inserts.
    struct MockModelRepo {
        models: Mutex<Vec<Model>>,
        next_id: Mutex<i64>,
    }

    impl MockModelRepo {
        fn new() -> Self {
            Self {
                models: Mutex::new(Vec::new()),
                next_id: Mutex::new(1),
            }
        }

        /// Number of models persisted through `insert` so far.
        fn stored_count(&self) -> usize {
            self.models.lock().unwrap().len()
        }
    }

    #[async_trait]
    impl ModelRepository for MockModelRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            Ok(self.models.lock().unwrap().clone())
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.id == id)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.name == name)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
        }

        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            let mut id = self.next_id.lock().unwrap();
            let persisted = Model {
                id: *id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                capabilities: model.capabilities,
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                inference_defaults: model.inference_defaults.clone(),
            };
            *id += 1;
            // Release the id lock before taking the models lock.
            drop(id);
            self.models.lock().unwrap().push(persisted.clone());
            Ok(persisted)
        }

        async fn update(&self, _model: &Model) -> Result<(), RepositoryError> {
            Ok(())
        }

        async fn delete(&self, _id: i64) -> Result<(), RepositoryError> {
            Ok(())
        }
    }

    /// Builds a registrar over a fresh mock repository and the no-op GGUF parser.
    /// The repository is returned as well so tests can inspect what was persisted.
    fn make_registrar() -> (Arc<MockModelRepo>, ModelRegistrar) {
        let repo = Arc::new(MockModelRepo::new());
        let parser = Arc::new(NoopGgufParser);
        let registrar = ModelRegistrar::new(repo.clone(), parser);
        (repo, registrar)
    }

    #[tokio::test]
    async fn test_register_model_basic() {
        let (repo, registrar) = make_registrar();

        let download = CompletedDownload {
            primary_path: PathBuf::from("/models/test-model-q4_k_m.gguf"),
            all_paths: vec![PathBuf::from("/models/test-model-q4_k_m.gguf")],
            quantization: Quantization::Q4KM,
            repo_id: "test/model".to_string(),
            commit_sha: "abc123".to_string(),
            is_sharded: false,
            total_bytes: 1024,
            file_paths: None,
        };

        let model = registrar
            .register_model(&download)
            .await
            .expect("registering a basic download should succeed");

        assert_eq!(model.id, 1);
        assert_eq!(model.name, "test/model");
        assert_eq!(model.hf_repo_id, Some("test/model".to_string()));
        assert_eq!(model.hf_commit_sha, Some("abc123".to_string()));
        assert_eq!(model.quantization, Some("Q4_K_M".to_string()));
        assert_eq!(repo.stored_count(), 1);
    }

    #[tokio::test]
    async fn test_register_sharded_model() {
        let (repo, registrar) = make_registrar();

        let download = CompletedDownload {
            primary_path: PathBuf::from("/models/llama-00001-of-00004.gguf"),
            all_paths: vec![
                PathBuf::from("/models/llama-00001-of-00004.gguf"),
                PathBuf::from("/models/llama-00002-of-00004.gguf"),
                PathBuf::from("/models/llama-00003-of-00004.gguf"),
                PathBuf::from("/models/llama-00004-of-00004.gguf"),
            ],
            quantization: Quantization::Q8_0,
            repo_id: "test/llama".to_string(),
            commit_sha: "def456".to_string(),
            is_sharded: true,
            total_bytes: 4096,
            file_paths: None,
        };

        let model = registrar
            .register_model(&download)
            .await
            .expect("registering a sharded download should succeed");

        assert_eq!(model.quantization, Some("Q8_0".to_string()));
        assert_eq!(model.hf_repo_id, Some("test/llama".to_string()));
        assert_eq!(model.hf_commit_sha, Some("def456".to_string()));
        assert_eq!(repo.stored_count(), 1);
    }

    #[tokio::test]
    async fn test_register_model_from_path() {
        let (repo, registrar) = make_registrar();

        let model = registrar
            .register_model_from_path(
                "test/repo",
                "commit123",
                Path::new("/models/test-q4_0.gguf"),
                "Q4_0",
            )
            .await
            .expect("registering from a path should succeed");

        assert_eq!(model.name, "test/repo");
        assert_eq!(model.hf_repo_id, Some("test/repo".to_string()));
        assert_eq!(model.hf_commit_sha, Some("commit123".to_string()));
        assert_eq!(repo.stored_count(), 1);
    }
}