1use std::path::Path;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12
13use crate::domain::{Model, NewModel, NewModelFile};
14use crate::download::Quantization;
15use crate::ports::{
16 CompletedDownload, GgufParserPort, ModelRegistrarPort, ModelRepository, RepositoryError,
17};
18
19#[async_trait]
24pub trait ModelFilesRepositoryPort: Send + Sync {
25 async fn insert(&self, model_file: &NewModelFile) -> anyhow::Result<()>;
27}
28
29pub struct ModelRegistrar {
34 model_repo: Arc<dyn ModelRepository>,
36 gguf_parser: Arc<dyn GgufParserPort>,
38 model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
40}
41
42impl ModelRegistrar {
43 pub fn new(
51 model_repo: Arc<dyn ModelRepository>,
52 gguf_parser: Arc<dyn GgufParserPort>,
53 model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
54 ) -> Self {
55 Self {
56 model_repo,
57 gguf_parser,
58 model_files_repo,
59 }
60 }
61
62 fn filter_hf_tags(tags: &[String]) -> Vec<String> {
66 tags.iter()
67 .filter(|tag| {
68 let tag_lower = tag.to_lowercase();
69 !tag_lower.starts_with("arxiv:")
70 && !tag_lower.starts_with("region:")
71 && !tag_lower.starts_with("license:")
72 && !tag_lower.starts_with("dataset:")
73 && tag_lower != "gguf"
74 })
75 .cloned()
76 .collect()
77 }
78
79 fn merge_tags(gguf_tags: Vec<String>, hf_tags: &[String]) -> Vec<String> {
83 use std::collections::HashSet;
84
85 let mut seen = HashSet::new();
86 let mut result = Vec::new();
87
88 for tag in gguf_tags {
90 if seen.insert(tag.clone()) {
91 result.push(tag);
92 }
93 }
94
95 for tag in Self::filter_hf_tags(hf_tags) {
97 if seen.insert(tag.clone()) {
98 result.push(tag);
99 }
100 }
101
102 result
103 }
104}
105
106#[async_trait]
107impl ModelRegistrarPort for ModelRegistrar {
108 async fn register_model(&self, download: &CompletedDownload) -> Result<Model, RepositoryError> {
109 let file_path = download.db_path();
110
111 let gguf_metadata = self.gguf_parser.parse(file_path).ok();
113
114 let param_count_b = gguf_metadata
116 .as_ref()
117 .and_then(|m| m.param_count_b)
118 .unwrap_or(0.0);
119
120 let mut model = NewModel::new(
121 download.repo_id.clone(),
122 file_path.to_path_buf(),
123 param_count_b,
124 Utc::now(),
125 );
126
127 model.quantization = gguf_metadata
129 .as_ref()
130 .and_then(|m| m.quantization.clone())
131 .or_else(|| Some(download.quantization.to_string()));
132 model.architecture = gguf_metadata.as_ref().and_then(|m| m.architecture.clone());
133 model.context_length = gguf_metadata.as_ref().and_then(|m| m.context_length);
134 model.expert_count = gguf_metadata.as_ref().and_then(|m| m.expert_count);
135 model.expert_used_count = gguf_metadata.as_ref().and_then(|m| m.expert_used_count);
136 model.expert_shared_count = gguf_metadata.as_ref().and_then(|m| m.expert_shared_count);
137 if let Some(ref meta) = gguf_metadata {
138 model.metadata.clone_from(&meta.metadata);
139 }
140 model.hf_repo_id = Some(download.repo_id.clone());
141 model.hf_commit_sha = Some(download.commit_sha.clone());
142 model.hf_filename = Some(file_path.file_name().unwrap().to_string_lossy().to_string());
143 model.download_date = Some(Utc::now());
144
145 model.file_paths.clone_from(&download.file_paths);
147
148 let gguf_tags = gguf_metadata.as_ref().map_or_else(Vec::new, |meta| {
150 let capabilities = self.gguf_parser.detect_capabilities(meta);
151 capabilities.to_tags()
152 });
153
154 model.tags = Self::merge_tags(gguf_tags, &download.hf_tags);
156
157 if model.inference_defaults.is_none()
161 && model
162 .tags
163 .iter()
164 .any(|t| t.eq_ignore_ascii_case("reasoning"))
165 {
166 model.inference_defaults = Some(crate::domain::InferenceConfig::reasoning_profile());
167 }
168
169 let template = model.metadata.get("tokenizer.chat_template");
173 let name = model.metadata.get("general.name");
174 let arch = model.metadata.get("general.architecture");
175 let from_template = crate::domain::infer_from_chat_template(
176 template.map(String::as_str),
177 name.map(String::as_str),
178 );
179 let from_arch = crate::domain::capabilities_from_architecture(arch.map(String::as_str));
180 model.capabilities = from_template | from_arch;
181
182 let registered = self.model_repo.insert(&model).await?;
183
184 if let Some(ref repo) = self.model_files_repo {
186 for (file_index, file_entry) in download.hf_file_entries.iter().enumerate() {
187 if let Some(size) = file_entry.size {
188 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
189 let model_file = NewModelFile::new(
190 registered.id,
191 file_entry.path.clone(),
192 file_index as i32,
193 size as i64,
194 file_entry.oid.clone(),
195 );
196
197 if let Err(e) = repo.insert(&model_file).await {
198 tracing::warn!(
200 model_id = registered.id,
201 file_path = %file_entry.path,
202 error = %e,
203 "Failed to insert model_files record - verification features may be unavailable"
204 );
205 }
206 }
207 }
208 }
209
210 Ok(registered)
211 }
212
213 async fn register_model_from_path(
214 &self,
215 repo_id: &str,
216 commit_sha: &str,
217 file_path: &Path,
218 quantization: &str,
219 ) -> Result<Model, RepositoryError> {
220 let download = CompletedDownload {
221 primary_path: file_path.to_path_buf(),
222 all_paths: vec![file_path.to_path_buf()],
223 quantization: Quantization::from_filename(quantization),
224 repo_id: repo_id.to_string(),
225 commit_sha: commit_sha.to_string(),
226 is_sharded: false,
227 total_bytes: 0,
228 file_paths: None,
229 hf_tags: vec![],
230 hf_file_entries: vec![],
231 };
232
233 self.register_model(&download).await
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Model;
    use crate::ports::NoopGgufParser;
    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory `ModelRepository` that hands out sequential ids starting at 1.
    struct MockModelRepo {
        models: Mutex<Vec<Model>>,
        next_id: Mutex<i64>,
    }

    impl MockModelRepo {
        fn new() -> Self {
            Self {
                models: Mutex::new(Vec::new()),
                next_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockModelRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            let snapshot = self.models.lock().unwrap().clone();
            Ok(snapshot)
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            let models = self.models.lock().unwrap();
            match models.iter().find(|m| m.id == id) {
                Some(found) => Ok(found.clone()),
                None => Err(RepositoryError::NotFound(format!("id={id}"))),
            }
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            let models = self.models.lock().unwrap();
            match models.iter().find(|m| m.name == name) {
                Some(found) => Ok(found.clone()),
                None => Err(RepositoryError::NotFound(format!("name={name}"))),
            }
        }

        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            // Reserve the id in its own scope so the counter lock is released
            // before the models lock is taken.
            let assigned_id = {
                let mut next_id = self.next_id.lock().unwrap();
                let current = *next_id;
                *next_id += 1;
                current
            };

            let persisted = Model {
                id: assigned_id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                expert_count: model.expert_count,
                expert_used_count: model.expert_used_count,
                expert_shared_count: model.expert_shared_count,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                capabilities: model.capabilities,
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                inference_defaults: model.inference_defaults.clone(),
            };

            self.models.lock().unwrap().push(persisted.clone());
            Ok(persisted)
        }

        async fn update(&self, _model: &Model) -> Result<(), RepositoryError> {
            Ok(())
        }

        async fn delete(&self, _id: i64) -> Result<(), RepositoryError> {
            Ok(())
        }
    }

    /// Registrar wired to a fresh in-memory repo, the no-op parser and no
    /// per-file repository — the setup shared by every test below.
    fn noop_registrar() -> ModelRegistrar {
        ModelRegistrar::new(
            Arc::new(MockModelRepo::new()),
            Arc::new(NoopGgufParser),
            None,
        )
    }

    #[tokio::test]
    async fn test_register_model_basic() {
        let registrar = noop_registrar();

        let gguf_path = PathBuf::from("/models/test-model-q4_k_m.gguf");
        let download = CompletedDownload {
            primary_path: gguf_path.clone(),
            all_paths: vec![gguf_path],
            quantization: Quantization::Q4KM,
            repo_id: "test/model".to_string(),
            commit_sha: "abc123".to_string(),
            is_sharded: false,
            total_bytes: 1024,
            file_paths: None,
            hf_tags: Vec::new(),
            hf_file_entries: Vec::new(),
        };

        let outcome = registrar.register_model(&download).await;
        assert!(outcome.is_ok());

        let model = outcome.unwrap();
        assert_eq!(model.name, "test/model");
        assert_eq!(model.hf_repo_id.as_deref(), Some("test/model"));
        assert_eq!(model.hf_commit_sha.as_deref(), Some("abc123"));
        // With the no-op parser the quantization falls back to the download's.
        assert_eq!(model.quantization.as_deref(), Some("Q4_K_M"));
    }

    #[tokio::test]
    async fn test_register_sharded_model() {
        let registrar = noop_registrar();

        let shards: Vec<PathBuf> = (1..=4)
            .map(|n| PathBuf::from(format!("/models/llama-{n:05}-of-00004.gguf")))
            .collect();
        let download = CompletedDownload {
            primary_path: shards[0].clone(),
            all_paths: shards,
            quantization: Quantization::Q8_0,
            repo_id: "test/llama".to_string(),
            commit_sha: "def456".to_string(),
            is_sharded: true,
            total_bytes: 4096,
            file_paths: None,
            hf_tags: Vec::new(),
            hf_file_entries: Vec::new(),
        };

        let outcome = registrar.register_model(&download).await;
        assert!(outcome.is_ok());

        let model = outcome.unwrap();
        assert_eq!(model.quantization.as_deref(), Some("Q8_0"));
    }

    #[tokio::test]
    async fn test_register_model_from_path() {
        let registrar = noop_registrar();

        let local_file = Path::new("/models/test-q4_0.gguf");
        let outcome = registrar
            .register_model_from_path("test/repo", "commit123", local_file, "Q4_0")
            .await;

        assert!(outcome.is_ok());
        let model = outcome.unwrap();
        assert_eq!(model.name, "test/repo");
    }
}