1use std::path::Path;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12
13use crate::domain::{Model, NewModel, NewModelFile};
14use crate::download::Quantization;
15use crate::ports::{
16 CompletedDownload, GgufParserPort, ModelRegistrarPort, ModelRepository, RepositoryError,
17};
18
/// Persistence port for per-file records (`model_files`) of a registered model.
///
/// `ModelRegistrar` uses it, when configured, to store one record per
/// Hugging Face file entry that has a known size.
#[async_trait]
pub trait ModelFilesRepositoryPort: Send + Sync {
    /// Inserts a single model-file record.
    ///
    /// # Errors
    ///
    /// Returns an error if the record could not be persisted. The registrar
    /// logs such errors and continues; it does not fail the registration.
    async fn insert(&self, model_file: &NewModelFile) -> anyhow::Result<()>;
}
28
/// Turns completed downloads into persisted [`Model`] records.
///
/// Combines GGUF metadata (via the parser port) with Hugging Face download
/// information before inserting the model through the model repository.
pub struct ModelRegistrar {
    /// Repository the registered model rows are inserted into.
    model_repo: Arc<dyn ModelRepository>,
    /// Parser used to read GGUF metadata and detect capabilities.
    gguf_parser: Arc<dyn GgufParserPort>,
    /// Optional repository for per-file records; when `None`, no
    /// model-file records are written during registration.
    model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
}
41
42impl ModelRegistrar {
43 pub fn new(
51 model_repo: Arc<dyn ModelRepository>,
52 gguf_parser: Arc<dyn GgufParserPort>,
53 model_files_repo: Option<Arc<dyn ModelFilesRepositoryPort>>,
54 ) -> Self {
55 Self {
56 model_repo,
57 gguf_parser,
58 model_files_repo,
59 }
60 }
61
62 fn filter_hf_tags(tags: &[String]) -> Vec<String> {
66 tags.iter()
67 .filter(|tag| {
68 let tag_lower = tag.to_lowercase();
69 !tag_lower.starts_with("arxiv:")
70 && !tag_lower.starts_with("region:")
71 && !tag_lower.starts_with("license:")
72 && !tag_lower.starts_with("dataset:")
73 && tag_lower != "gguf"
74 })
75 .cloned()
76 .collect()
77 }
78
79 fn merge_tags(gguf_tags: Vec<String>, hf_tags: &[String]) -> Vec<String> {
83 use std::collections::HashSet;
84
85 let mut seen = HashSet::new();
86 let mut result = Vec::new();
87
88 for tag in gguf_tags {
90 if seen.insert(tag.clone()) {
91 result.push(tag);
92 }
93 }
94
95 for tag in Self::filter_hf_tags(hf_tags) {
97 if seen.insert(tag.clone()) {
98 result.push(tag);
99 }
100 }
101
102 result
103 }
104}
105
106#[async_trait]
107impl ModelRegistrarPort for ModelRegistrar {
108 async fn register_model(&self, download: &CompletedDownload) -> Result<Model, RepositoryError> {
109 let file_path = download.db_path();
110
111 let gguf_metadata = self.gguf_parser.parse(file_path).ok();
113
114 let param_count_b = gguf_metadata
116 .as_ref()
117 .and_then(|m| m.param_count_b)
118 .unwrap_or(0.0);
119
120 let mut model = NewModel::new(
121 download.repo_id.clone(),
122 file_path.to_path_buf(),
123 param_count_b,
124 Utc::now(),
125 );
126
127 model.quantization = gguf_metadata
129 .as_ref()
130 .and_then(|m| m.quantization.clone())
131 .or_else(|| Some(download.quantization.to_string()));
132 model.architecture = gguf_metadata.as_ref().and_then(|m| m.architecture.clone());
133 model.context_length = gguf_metadata.as_ref().and_then(|m| m.context_length);
134 model.expert_count = gguf_metadata.as_ref().and_then(|m| m.expert_count);
135 model.expert_used_count = gguf_metadata.as_ref().and_then(|m| m.expert_used_count);
136 model.expert_shared_count = gguf_metadata.as_ref().and_then(|m| m.expert_shared_count);
137 if let Some(ref meta) = gguf_metadata {
138 model.metadata.clone_from(&meta.metadata);
139 }
140 model.hf_repo_id = Some(download.repo_id.clone());
141 model.hf_commit_sha = Some(download.commit_sha.clone());
142 model.hf_filename = Some(file_path.file_name().unwrap().to_string_lossy().to_string());
143 model.download_date = Some(Utc::now());
144
145 model.file_paths.clone_from(&download.file_paths);
147
148 let gguf_tags = gguf_metadata.as_ref().map_or_else(Vec::new, |meta| {
150 let capabilities = self.gguf_parser.detect_capabilities(meta);
151 capabilities.to_tags()
152 });
153
154 model.tags = Self::merge_tags(gguf_tags, &download.hf_tags);
156
157 let template = model.metadata.get("tokenizer.chat_template");
159 let name = model.metadata.get("general.name");
160 model.capabilities = crate::domain::infer_from_chat_template(
161 template.map(String::as_str),
162 name.map(String::as_str),
163 );
164
165 let registered = self.model_repo.insert(&model).await?;
166
167 if let Some(ref repo) = self.model_files_repo {
169 for (file_index, file_entry) in download.hf_file_entries.iter().enumerate() {
170 if let Some(size) = file_entry.size {
171 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
172 let model_file = NewModelFile::new(
173 registered.id,
174 file_entry.path.clone(),
175 file_index as i32,
176 size as i64,
177 file_entry.oid.clone(),
178 );
179
180 if let Err(e) = repo.insert(&model_file).await {
181 tracing::warn!(
183 model_id = registered.id,
184 file_path = %file_entry.path,
185 error = %e,
186 "Failed to insert model_files record - verification features may be unavailable"
187 );
188 }
189 }
190 }
191 }
192
193 Ok(registered)
194 }
195
196 async fn register_model_from_path(
197 &self,
198 repo_id: &str,
199 commit_sha: &str,
200 file_path: &Path,
201 quantization: &str,
202 ) -> Result<Model, RepositoryError> {
203 let download = CompletedDownload {
204 primary_path: file_path.to_path_buf(),
205 all_paths: vec![file_path.to_path_buf()],
206 quantization: Quantization::from_filename(quantization),
207 repo_id: repo_id.to_string(),
208 commit_sha: commit_sha.to_string(),
209 is_sharded: false,
210 total_bytes: 0,
211 file_paths: None,
212 hf_tags: vec![],
213 hf_file_entries: vec![],
214 };
215
216 self.register_model(&download).await
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Model;
    use crate::ports::NoopGgufParser;
    use std::path::PathBuf;
    use std::sync::Mutex;

    /// In-memory [`ModelRepository`] that assigns sequential ids on insert.
    struct MockModelRepo {
        models: Mutex<Vec<Model>>,
        next_id: Mutex<i64>,
    }

    impl MockModelRepo {
        fn new() -> Self {
            Self {
                models: Mutex::new(Vec::new()),
                next_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl ModelRepository for MockModelRepo {
        async fn list(&self) -> Result<Vec<Model>, RepositoryError> {
            Ok(self.models.lock().unwrap().clone())
        }

        async fn get_by_id(&self, id: i64) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.id == id)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("id={id}")))
        }

        async fn get_by_name(&self, name: &str) -> Result<Model, RepositoryError> {
            self.models
                .lock()
                .unwrap()
                .iter()
                .find(|m| m.name == name)
                .cloned()
                .ok_or_else(|| RepositoryError::NotFound(format!("name={name}")))
        }

        async fn insert(&self, model: &NewModel) -> Result<Model, RepositoryError> {
            let mut id = self.next_id.lock().unwrap();
            let persisted = Model {
                id: *id,
                name: model.name.clone(),
                file_path: model.file_path.clone(),
                param_count_b: model.param_count_b,
                architecture: model.architecture.clone(),
                quantization: model.quantization.clone(),
                context_length: model.context_length,
                expert_count: model.expert_count,
                expert_used_count: model.expert_used_count,
                expert_shared_count: model.expert_shared_count,
                metadata: model.metadata.clone(),
                added_at: model.added_at,
                hf_repo_id: model.hf_repo_id.clone(),
                hf_commit_sha: model.hf_commit_sha.clone(),
                hf_filename: model.hf_filename.clone(),
                capabilities: model.capabilities,
                download_date: model.download_date,
                last_update_check: model.last_update_check,
                tags: model.tags.clone(),
                inference_defaults: model.inference_defaults.clone(),
            };
            *id += 1;
            drop(id);
            self.models.lock().unwrap().push(persisted.clone());
            Ok(persisted)
        }

        async fn update(&self, _model: &Model) -> Result<(), RepositoryError> {
            Ok(())
        }

        async fn delete(&self, _id: i64) -> Result<(), RepositoryError> {
            Ok(())
        }
    }

    /// Converts string literals into owned tags.
    fn tags(values: &[&str]) -> Vec<String> {
        values.iter().map(|v| (*v).to_string()).collect()
    }

    #[tokio::test]
    async fn test_register_model_basic() {
        let repo = Arc::new(MockModelRepo::new());
        let parser = Arc::new(NoopGgufParser);
        let registrar = ModelRegistrar::new(repo.clone(), parser, None);

        let download = CompletedDownload {
            primary_path: PathBuf::from("/models/test-model-q4_k_m.gguf"),
            all_paths: vec![PathBuf::from("/models/test-model-q4_k_m.gguf")],
            quantization: Quantization::Q4KM,
            repo_id: "test/model".to_string(),
            commit_sha: "abc123".to_string(),
            is_sharded: false,
            total_bytes: 1024,
            file_paths: None,
            hf_tags: vec![],
            hf_file_entries: vec![],
        };

        // `expect` (unlike `assert!(is_ok())`) reports the actual error.
        let model = registrar
            .register_model(&download)
            .await
            .expect("registering a single-file download should succeed");
        assert_eq!(model.name, "test/model");
        assert_eq!(model.hf_repo_id, Some("test/model".to_string()));
        assert_eq!(model.hf_commit_sha, Some("abc123".to_string()));
        assert_eq!(model.quantization, Some("Q4_K_M".to_string()));
        assert_eq!(model.hf_filename, Some("test-model-q4_k_m.gguf".to_string()));

        let stored = repo.list().await.expect("mock list never fails");
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].id, model.id);
    }

    #[tokio::test]
    async fn test_register_sharded_model() {
        let repo = Arc::new(MockModelRepo::new());
        let parser = Arc::new(NoopGgufParser);
        let registrar = ModelRegistrar::new(repo.clone(), parser, None);

        let download = CompletedDownload {
            primary_path: PathBuf::from("/models/llama-00001-of-00004.gguf"),
            all_paths: vec![
                PathBuf::from("/models/llama-00001-of-00004.gguf"),
                PathBuf::from("/models/llama-00002-of-00004.gguf"),
                PathBuf::from("/models/llama-00003-of-00004.gguf"),
                PathBuf::from("/models/llama-00004-of-00004.gguf"),
            ],
            quantization: Quantization::Q8_0,
            repo_id: "test/llama".to_string(),
            commit_sha: "def456".to_string(),
            is_sharded: true,
            total_bytes: 4096,
            file_paths: None,
            hf_tags: vec![],
            hf_file_entries: vec![],
        };

        let model = registrar
            .register_model(&download)
            .await
            .expect("registering a sharded download should succeed");
        assert_eq!(model.quantization, Some("Q8_0".to_string()));
    }

    #[tokio::test]
    async fn test_register_model_from_path() {
        let repo = Arc::new(MockModelRepo::new());
        let parser = Arc::new(NoopGgufParser);
        let registrar = ModelRegistrar::new(repo.clone(), parser, None);

        let model = registrar
            .register_model_from_path(
                "test/repo",
                "commit123",
                Path::new("/models/test-q4_0.gguf"),
                "Q4_0",
            )
            .await
            .expect("registering from a path should succeed");

        assert_eq!(model.name, "test/repo");
        assert_eq!(model.hf_commit_sha, Some("commit123".to_string()));
    }

    #[test]
    fn test_filter_hf_tags_drops_metadata_tags_case_insensitively() {
        let input = tags(&[
            "arxiv:2401.00001",
            "Region:us",
            "LICENSE:apache-2.0",
            "dataset:foo",
            "GGUF",
            "text-generation",
            "llama",
        ]);
        assert_eq!(
            ModelRegistrar::filter_hf_tags(&input),
            tags(&["text-generation", "llama"])
        );
        assert!(ModelRegistrar::filter_hf_tags(&[]).is_empty());
    }

    #[test]
    fn test_merge_tags_keeps_gguf_first_and_dedupes_exactly() {
        let merged = ModelRegistrar::merge_tags(
            tags(&["vision", "tools", "vision"]),
            &tags(&["Vision", "vision", "license:mit", "gguf", "llama"]),
        );
        assert_eq!(merged, tags(&["vision", "tools", "Vision", "llama"]));
    }
}
387}