1use std::collections::HashMap;
10use std::fs::File;
11use std::io::Read;
12use std::path::Path;
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use chrono::Utc;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use tokio::sync::{RwLock, mpsc};
20use tokio::task::JoinHandle;
21
22use crate::domain::ModelFile;
23use crate::ports::{HfClientPort, ModelRepository, RepositoryError};
24
/// Per-shard progress event streamed over the verification channel.
///
/// Serialized with an internal `status` tag in `snake_case`
/// (e.g. `{"status": "hashing", ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ShardProgress {
    /// Verification of this shard has begun.
    Starting,
    /// SHA-256 hashing of the shard file is in progress.
    Hashing {
        /// Completion percentage in the range 0..=100.
        percent: u8,
        /// Bytes hashed so far.
        bytes_processed: u64,
        /// Total size of the shard file in bytes.
        total_bytes: u64,
    },
    /// Verification of this shard finished with the given verdict.
    Completed {
        health: ShardHealth,
    },
}
50
/// Health verdict for a single shard file.
///
/// Serialized with an internal `type` tag in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ShardHealth {
    /// The computed SHA-256 matches the stored OID.
    Healthy,
    /// The computed SHA-256 differs from the stored OID.
    Corrupt {
        /// OID recorded in the database (expected hash).
        expected: String,
        /// Hash actually computed from the file on disk.
        actual: String,
    },
    /// The shard file does not exist on disk
    /// (also used when hashing fails — see `verify_shard`).
    Missing,
    /// No OID is stored (or it is not a 64-char SHA-256), so the shard
    /// cannot be verified.
    NoOid,
}
69
/// Progress update for one shard of a model verification run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationProgress {
    /// Database id of the model being verified.
    pub model_id: i64,
    /// Zero-based index of the shard this event refers to.
    pub shard_index: usize,
    /// Total number of shards in the run.
    pub total_shards: usize,
    /// Current state of the shard.
    pub shard_progress: ShardProgress,
}
82
/// Final result of a full model verification run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationReport {
    /// Database id of the verified model.
    pub model_id: i64,
    /// Aggregate verdict across all shards.
    pub overall_health: OverallHealth,
    /// Per-shard verdicts, in shard order.
    pub shards: Vec<ShardHealthReport>,
    /// UTC timestamp taken when the report was assembled.
    pub verified_at: chrono::DateTime<Utc>,
}
95
/// Aggregate health verdict for a whole model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OverallHealth {
    /// No shard was corrupt or missing.
    Healthy,
    /// At least one shard is corrupt or missing.
    Unhealthy,
    /// No shard could be checked (no usable OIDs).
    Unverifiable,
}
107
/// Health verdict for one shard, as included in a [`VerificationReport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardHealthReport {
    /// Zero-based shard index.
    pub index: usize,
    /// Shard path relative to the model directory.
    pub file_path: String,
    /// Verdict for this shard.
    pub health: ShardHealth,
}
118
/// Result of comparing local shard OIDs against the remote repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateCheckResult {
    /// Database id of the checked model.
    pub model_id: i64,
    /// True when at least one shard's remote OID differs from the local one.
    pub update_available: bool,
    /// Per-shard change details; `None` when no update is available.
    pub details: Option<UpdateDetails>,
}
129
/// Details of an available model update.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateDetails {
    /// Number of shards whose OID changed (equals `changes.len()`).
    pub changed_shards: usize,
    /// One entry per changed shard.
    pub changes: Vec<ShardUpdate>,
}
138
/// A single shard whose remote content differs from the local copy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardUpdate {
    /// Zero-based shard index (from the file's `file_index`).
    pub index: usize,
    /// Shard path relative to the model directory.
    pub file_path: String,
    /// OID currently stored locally.
    pub old_oid: String,
    /// OID reported by the remote repository.
    pub new_oid: String,
}
151
/// Kind of exclusive operation that may hold a model's lock.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperationType {
    /// Integrity verification is running.
    Verifying,
    /// A (repair) download is queued/running.
    Downloading,
}
160
/// RAII guard that releases a model's slot in the shared lock map on drop.
pub struct OperationGuard {
    // Model whose entry is removed when the guard drops.
    model_id: i64,
    // Shared map also owned by `ModelOperationLock`.
    lock_map: Arc<RwLock<HashMap<i64, OperationType>>>,
}
170
171impl Drop for OperationGuard {
172 fn drop(&mut self) {
173 let model_id = self.model_id;
174 let lock_map: Arc<RwLock<HashMap<i64, OperationType>>> = Arc::clone(&self.lock_map);
175
176 tokio::spawn(async move {
178 let mut map = lock_map.write().await;
179 map.remove(&model_id);
180 });
181 }
182}
183
/// Table of in-flight exclusive operations, keyed by model id.
pub struct ModelOperationLock {
    // Shared with every outstanding `OperationGuard`.
    locks: Arc<RwLock<HashMap<i64, OperationType>>>,
}
190
191impl ModelOperationLock {
192 pub fn new() -> Self {
194 Self {
195 locks: Arc::new(RwLock::new(HashMap::new())),
196 }
197 }
198
199 pub async fn try_acquire(
204 &self,
205 model_id: i64,
206 operation: OperationType,
207 ) -> Result<OperationGuard, String> {
208 let mut map = self.locks.write().await;
209
210 if let Some(existing) = map.get(&model_id) {
211 return Err(format!(
212 "Model {model_id} is already locked for {existing:?} operation"
213 ));
214 }
215
216 map.insert(model_id, operation);
217 drop(map);
218
219 Ok(OperationGuard {
220 model_id,
221 lock_map: Arc::clone(&self.locks),
222 })
223 }
224}
225
impl Default for ModelOperationLock {
    /// Equivalent to [`ModelOperationLock::new`].
    fn default() -> Self {
        Self::new()
    }
}
231
/// Read/update access to a model's shard file records.
#[async_trait]
pub trait ModelFilesReaderPort: Send + Sync {
    /// Returns all shard file records belonging to `model_id`.
    async fn get_by_model_id(&self, model_id: i64) -> anyhow::Result<Vec<ModelFile>>;

    /// Stores the time at which the file with database `id` was last verified.
    async fn update_verification_time(
        &self,
        id: i64,
        verified_at: chrono::DateTime<Utc>,
    ) -> anyhow::Result<()>;
}
252
/// Entry point for queueing a model (re-)download.
#[async_trait]
pub trait DownloadTriggerPort: Send + Sync {
    /// Queues a download for `repo_id` (optionally a specific quantization)
    /// and returns an identifier for the queued download.
    async fn queue_download(
        &self,
        repo_id: String,
        quantization: Option<String>,
    ) -> anyhow::Result<String>;
}
265
/// Verifies model shard integrity, checks for remote updates, and triggers
/// repair downloads, serialized per model via an operation lock.
pub struct ModelVerificationService {
    // Model metadata lookup.
    model_repo: Arc<dyn ModelRepository>,
    // Shard file records + verification timestamps.
    model_files_repo: Arc<dyn ModelFilesReaderPort>,
    // Remote HuggingFace file listing.
    hf_client: Arc<dyn HfClientPort>,
    // Queue for repair downloads.
    download_trigger: Arc<dyn DownloadTriggerPort>,
    // Ensures at most one verify/download operation per model.
    operation_lock: ModelOperationLock,
}
279
280impl ModelVerificationService {
281 pub fn new(
283 model_repo: Arc<dyn ModelRepository>,
284 model_files_repo: Arc<dyn ModelFilesReaderPort>,
285 hf_client: Arc<dyn HfClientPort>,
286 download_trigger: Arc<dyn DownloadTriggerPort>,
287 ) -> Self {
288 Self {
289 model_repo,
290 model_files_repo,
291 hf_client,
292 download_trigger,
293 operation_lock: ModelOperationLock::new(),
294 }
295 }
296
297 pub async fn verify_model_integrity(
310 &self,
311 model_id: i64,
312 ) -> Result<
313 (
314 mpsc::Receiver<VerificationProgress>,
315 JoinHandle<Result<VerificationReport, RepositoryError>>,
316 ),
317 String,
318 > {
319 let guard = self
321 .operation_lock
322 .try_acquire(model_id, OperationType::Verifying)
323 .await?;
324
325 let model = self
327 .model_repo
328 .get_by_id(model_id)
329 .await
330 .map_err(|e| format!("Failed to get model: {e}"))?;
331
332 let model_files = self
333 .model_files_repo
334 .get_by_model_id(model_id)
335 .await
336 .map_err(|e| format!("Failed to get model files: {e}"))?;
337
338 if model_files.is_empty() {
339 return Err("No model files found for verification".to_string());
340 }
341
342 let base_dir = model
344 .file_path
345 .parent()
346 .ok_or_else(|| "Failed to get model directory".to_string())?
347 .to_path_buf();
348
349 let total_shards = model_files.len();
350
351 let (tx, rx) = mpsc::channel(100);
353
354 let model_files_repo = Arc::clone(&self.model_files_repo);
356 let _model_repo = Arc::clone(&self.model_repo);
357
358 let handle = tokio::spawn(async move {
360 let _guard = guard;
362 let mut shard_reports = Vec::new();
363 let mut has_unhealthy = false;
364 let mut has_healthy_or_no_oid = false;
365
366 for (index, file) in model_files.iter().enumerate() {
367 let _ = tx
369 .send(VerificationProgress {
370 model_id,
371 shard_index: index,
372 total_shards,
373 shard_progress: ShardProgress::Starting,
374 })
375 .await;
376
377 let resolved_path = base_dir.join(&file.file_path);
379 let health =
380 Self::verify_shard(file, &resolved_path, model_id, index, total_shards, &tx)
381 .await;
382
383 if let Err(e) = model_files_repo
385 .update_verification_time(file.id, Utc::now())
386 .await
387 {
388 tracing::warn!(
389 model_id = model_id,
390 file_id = file.id,
391 error = %e,
392 "Failed to update verification timestamp"
393 );
394 }
395
396 match &health {
398 ShardHealth::Corrupt { .. } | ShardHealth::Missing => has_unhealthy = true,
399 ShardHealth::Healthy | ShardHealth::NoOid => has_healthy_or_no_oid = true,
400 }
401
402 shard_reports.push(ShardHealthReport {
403 index,
404 file_path: file.file_path.clone(),
405 health: health.clone(),
406 });
407
408 let _ = tx
410 .send(VerificationProgress {
411 model_id,
412 shard_index: index,
413 total_shards,
414 shard_progress: ShardProgress::Completed { health },
415 })
416 .await;
417 }
418
419 let overall_health = if has_unhealthy {
420 OverallHealth::Unhealthy
421 } else if !has_healthy_or_no_oid {
422 OverallHealth::Unverifiable
423 } else {
424 OverallHealth::Healthy
425 };
426
427 Ok(VerificationReport {
428 model_id,
429 overall_health,
430 shards: shard_reports,
431 verified_at: Utc::now(),
432 })
433 });
434
435 Ok((rx, handle))
436 }
437
    /// Verifies one shard: streams its file through SHA-256 on a blocking
    /// task, emitting `Hashing` progress events, and compares the digest to
    /// the stored OID.
    ///
    /// Returns `NoOid` when no usable SHA-256 OID is stored, `Missing` when
    /// the file is absent (or hashing fails — see notes below), `Corrupt`
    /// on mismatch, and `Healthy` on match.
    #[allow(clippy::cognitive_complexity)]
    async fn verify_shard(
        file: &ModelFile,
        resolved_path: &Path,
        model_id: i64,
        index: usize,
        total_shards: usize,
        tx: &mpsc::Sender<VerificationProgress>,
    ) -> ShardHealth {
        let Some(ref expected_oid) = file.hf_oid else {
            return ShardHealth::NoOid;
        };

        // A SHA-256 digest is exactly 64 hex chars; anything else (e.g. a
        // git blob id) cannot be compared against our computed hash.
        if expected_oid.len() != 64 {
            tracing::warn!(
                model_id = model_id,
                file_path = %file.file_path,
                oid_len = expected_oid.len(),
                "Stored OID is not a SHA256 hash (expected 64 hex chars). \
                 Re-download or update model metadata to fix."
            );
            return ShardHealth::NoOid;
        }

        let file_path = resolved_path;

        if !file_path.exists() {
            return ShardHealth::Missing;
        }

        // Hashing is CPU/IO heavy; run it off the async executor.
        let path_owned = file_path.to_path_buf();
        let tx_clone = tx.clone();

        let result = tokio::task::spawn_blocking(move || -> anyhow::Result<String> {
            let mut file = File::open(&path_owned)?;
            let total_bytes = file.metadata()?.len();

            let mut hasher = Sha256::new();
            // 1 MiB read buffer.
            let mut buffer = vec![0u8; 1024 * 1024];
            let mut bytes_processed = 0u64;

            // Initial 0% event; send failures are ignored (receiver may be gone).
            let _ = tx_clone.blocking_send(VerificationProgress {
                model_id,
                shard_index: index,
                total_shards,
                shard_progress: ShardProgress::Hashing {
                    percent: 0,
                    bytes_processed: 0,
                    total_bytes,
                },
            });

            loop {
                let n = file.read(&mut buffer)?;
                if n == 0 {
                    break;
                }

                hasher.update(&buffer[..n]);
                bytes_processed += n as u64;

                // Throttle progress: emit roughly once per 100 MiB (the modulo
                // window matches the 1 MiB read size), plus a final event when
                // the byte count reaches the recorded file size.
                if bytes_processed % (100 * 1024 * 1024) < (1024 * 1024)
                    || bytes_processed == total_bytes
                {
                    #[allow(
                        clippy::cast_possible_truncation,
                        clippy::cast_precision_loss,
                        clippy::cast_sign_loss
                    )]
                    let percent = ((bytes_processed as f64 / total_bytes as f64) * 100.0) as u8;

                    let _ = tx_clone.blocking_send(VerificationProgress {
                        model_id,
                        shard_index: index,
                        total_shards,
                        shard_progress: ShardProgress::Hashing {
                            percent,
                            bytes_processed,
                            total_bytes,
                        },
                    });
                }
            }

            // Lowercase hex digest, matching HuggingFace OID format.
            Ok(format!("{:x}", hasher.finalize()))
        })
        .await;

        match result {
            Ok(Ok(computed_hash)) => {
                if computed_hash == *expected_oid {
                    ShardHealth::Healthy
                } else {
                    ShardHealth::Corrupt {
                        expected: expected_oid.clone(),
                        actual: computed_hash,
                    }
                }
            }
            // NOTE(review): read/open failures are reported as `Missing`
            // even though the file exists — consider a dedicated variant.
            Ok(Err(e)) => {
                tracing::error!(
                    model_id = model_id,
                    file_path = %file.file_path,
                    error = %e,
                    "Failed to compute hash"
                );
                ShardHealth::Missing
            }
            // Join error: the blocking task panicked.
            Err(e) => {
                tracing::error!(
                    model_id = model_id,
                    file_path = %file.file_path,
                    error = %e,
                    "Task panicked during hash computation"
                );
                ShardHealth::Missing
            }
        }
    }
566
567 pub async fn check_for_updates(
571 &self,
572 model_id: i64,
573 ) -> Result<UpdateCheckResult, RepositoryError> {
574 let model = self.model_repo.get_by_id(model_id).await?;
576
577 let Some(ref repo_id) = model.hf_repo_id else {
578 return Ok(UpdateCheckResult {
579 model_id,
580 update_available: false,
581 details: None,
582 });
583 };
584
585 let Some(ref quantization) = model.quantization else {
586 return Ok(UpdateCheckResult {
587 model_id,
588 update_available: false,
589 details: None,
590 });
591 };
592
593 let local_files = self
595 .model_files_repo
596 .get_by_model_id(model_id)
597 .await
598 .map_err(|e| RepositoryError::Storage(e.to_string()))?;
599
600 if local_files.is_empty() {
601 return Ok(UpdateCheckResult {
602 model_id,
603 update_available: false,
604 details: None,
605 });
606 }
607
608 let remote_files = self
610 .hf_client
611 .get_quantization_files(repo_id, quantization)
612 .await
613 .map_err(|e| RepositoryError::Storage(format!("Failed to fetch remote files: {e}")))?;
614
615 let mut changes = Vec::new();
617
618 for local_file in &local_files {
619 let Some(ref local_oid) = local_file.hf_oid else {
620 continue;
621 };
622
623 if let Some(remote_file) = remote_files.iter().find(|f| f.path == local_file.file_path)
625 {
626 if let Some(ref remote_oid) = remote_file.oid {
627 if local_oid != remote_oid {
628 let old_oid_str: String = local_oid.clone();
629 let new_oid_str: String = remote_oid.clone();
630 #[allow(clippy::cast_sign_loss)]
631 let index = local_file.file_index as usize;
632 changes.push(ShardUpdate {
633 index,
634 file_path: local_file.file_path.clone(),
635 old_oid: old_oid_str,
636 new_oid: new_oid_str,
637 });
638 }
639 }
640 }
641 }
642
643 let update_available = !changes.is_empty();
644 let details = if update_available {
645 Some(UpdateDetails {
646 changed_shards: changes.len(),
647 changes,
648 })
649 } else {
650 None
651 };
652
653 Ok(UpdateCheckResult {
654 model_id,
655 update_available,
656 details,
657 })
658 }
659
    /// Repairs a model by deleting its unhealthy shard files and queueing a
    /// fresh download of the quantization.
    ///
    /// When `shard_indices` is `Some`, only those shard indices are
    /// considered; otherwise every shard is re-verified and the corrupt or
    /// missing ones are repaired.
    ///
    /// Returns the download id from the download trigger.
    ///
    /// # Errors
    ///
    /// Returns an error string when the model is already locked, cannot be
    /// loaded, lacks repo/quantization metadata, when no unhealthy shard is
    /// found, or when queueing the download fails.
    pub async fn repair_model(
        &self,
        model_id: i64,
        shard_indices: Option<Vec<usize>>,
    ) -> Result<String, String> {
        // Held until this method returns; the queued download itself runs
        // elsewhere and is not covered by this guard.
        let _guard = self
            .operation_lock
            .try_acquire(model_id, OperationType::Downloading)
            .await?;

        let model = self
            .model_repo
            .get_by_id(model_id)
            .await
            .map_err(|e| format!("Failed to get model: {e}"))?;

        let Some(ref repo_id) = model.hf_repo_id else {
            return Err("Model does not have HuggingFace repository information".to_string());
        };

        let Some(ref quantization) = model.quantization else {
            return Err("Model does not have quantization information".to_string());
        };

        let model_files = self
            .model_files_repo
            .get_by_model_id(model_id)
            .await
            .map_err(|e| format!("Failed to get model files: {e}"))?;

        // Shard paths are stored relative to the model's directory.
        let base_dir = model
            .file_path
            .parent()
            .ok_or_else(|| "Failed to get model directory".to_string())?
            .to_path_buf();

        let shards_to_repair: Vec<&ModelFile> = if let Some(indices) = shard_indices {
            // Caller-selected shards: trust the indices, no re-verification.
            #[allow(clippy::cast_sign_loss)]
            let filter_fn = |f: &&ModelFile| indices.contains(&(f.file_index as usize));
            model_files.iter().filter(filter_fn).collect()
        } else {
            // Re-verify everything and collect only broken shards. The
            // progress channel is a throwaway (receiver dropped immediately);
            // verify_shard ignores send failures.
            let mut unhealthy = Vec::new();
            for file in &model_files {
                let (tx, _rx) = mpsc::channel(1);
                let resolved_path = base_dir.join(&file.file_path);
                let health = Self::verify_shard(file, &resolved_path, model_id, 0, 1, &tx).await;
                match health {
                    ShardHealth::Corrupt { .. } | ShardHealth::Missing => {
                        unhealthy.push(file);
                    }
                    _ => {}
                }
            }
            unhealthy
        };

        if shards_to_repair.is_empty() {
            return Err("No unhealthy shards found to repair".to_string());
        }

        // Delete the broken files so the downloader re-fetches them; a failed
        // delete is logged but does not abort the repair.
        for file in &shards_to_repair {
            let resolved_path = base_dir.join(&file.file_path);
            if resolved_path.exists() {
                if let Err(e) = tokio::fs::remove_file(&resolved_path).await {
                    tracing::warn!(
                        model_id = model_id,
                        file_path = %file.file_path,
                        error = %e,
                        "Failed to delete corrupt file"
                    );
                }
            }
        }

        let download_id = self
            .download_trigger
            .queue_download(repo_id.clone(), Some(quantization.clone()))
            .await
            .map_err(|e| format!("Failed to queue download: {e}"))?;

        Ok(download_id)
    }
757}
758
#[cfg(test)]
mod tests {
    use super::*;

    // A free model id can be locked.
    #[tokio::test]
    async fn test_operation_lock_single_acquire() {
        let lock = ModelOperationLock::new();
        let guard = lock.try_acquire(1, OperationType::Verifying).await;
        assert!(guard.is_ok());
    }

    // A second acquire on the same model fails while the first guard lives.
    #[tokio::test]
    async fn test_operation_lock_double_acquire_fails() {
        let lock = ModelOperationLock::new();
        let _guard1 = lock.try_acquire(1, OperationType::Verifying).await.unwrap();
        let guard2 = lock.try_acquire(1, OperationType::Downloading).await;
        assert!(guard2.is_err());
    }

    // Dropping the guard releases the slot (release may happen on a spawned
    // task, hence the short sleep before re-acquiring).
    #[tokio::test]
    async fn test_operation_lock_release_on_drop() {
        let lock = ModelOperationLock::new();
        {
            let _guard = lock.try_acquire(1, OperationType::Verifying).await.unwrap();
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let guard2 = lock.try_acquire(1, OperationType::Downloading).await;
        assert!(guard2.is_ok());
    }

    // Locks on different model ids are independent.
    #[tokio::test]
    async fn test_operation_lock_different_models() {
        let lock = ModelOperationLock::new();
        let guard1 = lock.try_acquire(1, OperationType::Verifying).await;
        let guard2 = lock.try_acquire(2, OperationType::Verifying).await;
        assert!(guard1.is_ok());
        assert!(guard2.is_ok());
    }
}