gglib_core/services/model_verification.rs

1//! Model verification service for integrity checking and update detection.
2//!
3//! This service provides:
4//! - Integrity verification via SHA256 hash comparison against `HuggingFace` OIDs
5//! - Update detection by comparing local OIDs with remote repository state
6//! - Model repair by re-downloading corrupt or missing shards
7//! - Concurrency control to prevent conflicting operations on the same model
8
9use std::collections::HashMap;
10use std::fs::File;
11use std::io::Read;
12use std::path::Path;
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use chrono::Utc;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use tokio::sync::{RwLock, mpsc};
20use tokio::task::JoinHandle;
21
22use crate::domain::ModelFile;
23use crate::ports::{HfClientPort, ModelRepository, RepositoryError};
24
25// ============================================================================
26// Domain Types
27// ============================================================================
28
/// Progress status for an individual shard verification.
///
/// Serialized with an internal `status` tag in snake_case so consumers can
/// discriminate on a single field (e.g. `{"status": "hashing", ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ShardProgress {
    /// Verification starting for this shard.
    Starting,
    /// Currently hashing the file.
    Hashing {
        /// Percentage complete (0-100).
        percent: u8,
        /// Bytes processed so far.
        bytes_processed: u64,
        /// Total bytes in the file.
        total_bytes: u64,
    },
    /// Verification completed for this shard.
    Completed {
        /// Health status of this shard.
        health: ShardHealth,
    },
}
50
/// Health status of an individual shard after verification.
///
/// Serialized with an internal `type` tag in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ShardHealth {
    /// File is healthy - hash matches expected OID.
    Healthy,
    /// File is corrupt - hash doesn't match expected OID.
    Corrupt {
        /// Expected SHA256 hash (from `HuggingFace` OID).
        expected: String,
        /// Actual computed SHA256 hash.
        actual: String,
    },
    /// File is missing from disk.
    ///
    /// Note: the verifier also reports this variant when the file exists but
    /// cannot be read or hashed (I/O error or hashing-task panic).
    Missing,
    /// No OID available to verify against (also reported when the stored OID
    /// is not a 64-char SHA256 hash, e.g. a Git SHA-1).
    NoOid,
}
69
/// Progress update during model verification.
///
/// Emitted on the progress channel returned by
/// `ModelVerificationService::verify_model_integrity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationProgress {
    /// Model ID being verified.
    pub model_id: i64,
    /// Current shard index being verified (0-based).
    pub shard_index: usize,
    /// Total number of shards.
    pub total_shards: usize,
    /// Progress status for this shard.
    pub shard_progress: ShardProgress,
}
82
/// Complete verification report for a model.
///
/// Produced by the verification task once every shard has been processed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationReport {
    /// Model ID that was verified.
    pub model_id: i64,
    /// Overall health status.
    pub overall_health: OverallHealth,
    /// Health status for each shard, in shard-index order.
    pub shards: Vec<ShardHealthReport>,
    /// When the verification was performed (UTC).
    pub verified_at: chrono::DateTime<Utc>,
}
95
/// Overall health status for a model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OverallHealth {
    /// All shards are healthy.
    Healthy,
    /// One or more shards are corrupt or missing.
    Unhealthy,
    /// No OIDs available for verification (no shard could be checked).
    Unverifiable,
}
107
/// Health report for a single shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardHealthReport {
    /// Shard index (0-based).
    pub index: usize,
    /// File path, relative to the model's directory.
    pub file_path: String,
    /// Health status.
    pub health: ShardHealth,
}
118
/// Result of checking for model updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateCheckResult {
    /// Model ID that was checked.
    pub model_id: i64,
    /// Whether an update is available.
    pub update_available: bool,
    /// Details about what changed. `Some` if and only if
    /// `update_available` is true.
    pub details: Option<UpdateDetails>,
}
129
/// Details about available updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateDetails {
    /// Number of shards that have changed (equals `changes.len()`).
    pub changed_shards: usize,
    /// OID changes per shard.
    pub changes: Vec<ShardUpdate>,
}
138
/// Update information for a single shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardUpdate {
    /// Shard index (0-based).
    pub index: usize,
    /// File path, relative to the model's directory.
    pub file_path: String,
    /// Old OID (local).
    pub old_oid: String,
    /// New OID (remote).
    pub new_oid: String,
}
151
/// Type of operation being performed on a model.
///
/// The `Debug` representation is embedded in lock-conflict error messages,
/// so variant names are user-visible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OperationType {
    /// Model is being verified.
    Verifying,
    /// Model is being downloaded/repaired.
    Downloading,
}
160
161// ============================================================================
162// Concurrency Control
163// ============================================================================
164
165/// RAII guard that automatically releases the operation lock when dropped.
166pub struct OperationGuard {
167    model_id: i64,
168    lock_map: Arc<RwLock<HashMap<i64, OperationType>>>,
169}
170
171impl Drop for OperationGuard {
172    fn drop(&mut self) {
173        let model_id = self.model_id;
174        let lock_map: Arc<RwLock<HashMap<i64, OperationType>>> = Arc::clone(&self.lock_map);
175
176        // Spawn a task to release the lock asynchronously
177        tokio::spawn(async move {
178            let mut map = lock_map.write().await;
179            map.remove(&model_id);
180        });
181    }
182}
183
/// Concurrency control for model operations.
///
/// Ensures only one operation of each type can run on a model at a time.
/// Cloneless: the map is shared with guards via `Arc`.
pub struct ModelOperationLock {
    /// In-flight operations, keyed by model ID. Shared with `OperationGuard`s
    /// so they can remove their entry on drop.
    locks: Arc<RwLock<HashMap<i64, OperationType>>>,
}
190
191impl ModelOperationLock {
192    /// Create a new operation lock manager.
193    pub fn new() -> Self {
194        Self {
195            locks: Arc::new(RwLock::new(HashMap::new())),
196        }
197    }
198
199    /// Try to acquire a lock for the specified operation.
200    ///
201    /// Returns `Ok(guard)` if the lock was acquired, or `Err` if another
202    /// operation is already in progress for this model.
203    pub async fn try_acquire(
204        &self,
205        model_id: i64,
206        operation: OperationType,
207    ) -> Result<OperationGuard, String> {
208        let mut map = self.locks.write().await;
209
210        if let Some(existing) = map.get(&model_id) {
211            return Err(format!(
212                "Model {model_id} is already locked for {existing:?} operation"
213            ));
214        }
215
216        map.insert(model_id, operation);
217        drop(map);
218
219        Ok(OperationGuard {
220            model_id,
221            lock_map: Arc::clone(&self.locks),
222        })
223    }
224}
225
impl Default for ModelOperationLock {
    /// Equivalent to [`ModelOperationLock::new`].
    fn default() -> Self {
        Self::new()
    }
}
231
232// ============================================================================
233// Service
234// ============================================================================
235
/// Port trait for accessing model files repository.
///
/// This is a minimal trait that wraps the concrete `ModelFilesRepository`
/// to avoid circular dependencies.
#[async_trait]
pub trait ModelFilesReaderPort: Send + Sync {
    /// Get all model files for a specific model.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying repository query fails.
    async fn get_by_model_id(&self, model_id: i64) -> anyhow::Result<Vec<ModelFile>>;

    /// Update the last verified timestamp for a model file.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying repository update fails.
    async fn update_verification_time(
        &self,
        id: i64,
        verified_at: chrono::DateTime<Utc>,
    ) -> anyhow::Result<()>;
}
252
/// Port trait for triggering downloads.
///
/// This abstracts the download manager to avoid tight coupling.
#[async_trait]
pub trait DownloadTriggerPort: Send + Sync {
    /// Queue a download for a specific model by repo ID and quantization.
    ///
    /// Returns an identifier for the queued download.
    ///
    /// # Errors
    ///
    /// Returns an error if the download cannot be queued.
    async fn queue_download(
        &self,
        repo_id: String,
        quantization: Option<String>,
    ) -> anyhow::Result<String>;
}
265
/// Model verification service.
///
/// Coordinates integrity verification, update checks, and repairs for
/// downloaded models, serializing conflicting operations per model via
/// [`ModelOperationLock`].
pub struct ModelVerificationService {
    /// Repository for model metadata.
    model_repo: Arc<dyn ModelRepository>,
    /// Repository for model file metadata.
    model_files_repo: Arc<dyn ModelFilesReaderPort>,
    /// `HuggingFace` client for update checks.
    hf_client: Arc<dyn HfClientPort>,
    /// Download trigger for repairs.
    download_trigger: Arc<dyn DownloadTriggerPort>,
    /// Concurrency control.
    operation_lock: ModelOperationLock,
}
279
280impl ModelVerificationService {
281    /// Create a new verification service.
282    pub fn new(
283        model_repo: Arc<dyn ModelRepository>,
284        model_files_repo: Arc<dyn ModelFilesReaderPort>,
285        hf_client: Arc<dyn HfClientPort>,
286        download_trigger: Arc<dyn DownloadTriggerPort>,
287    ) -> Self {
288        Self {
289            model_repo,
290            model_files_repo,
291            hf_client,
292            download_trigger,
293            operation_lock: ModelOperationLock::new(),
294        }
295    }
296
297    /// Verify the integrity of a model by computing SHA256 hashes.
298    ///
299    /// Returns a channel for progress updates and a handle to the verification task.
300    ///
301    /// # Arguments
302    ///
303    /// * `model_id` - ID of the model to verify
304    ///
305    /// # Returns
306    ///
307    /// * `receiver` - Channel for receiving progress updates
308    /// * `handle` - Join handle for the verification task
309    pub async fn verify_model_integrity(
310        &self,
311        model_id: i64,
312    ) -> Result<
313        (
314            mpsc::Receiver<VerificationProgress>,
315            JoinHandle<Result<VerificationReport, RepositoryError>>,
316        ),
317        String,
318    > {
319        // Acquire lock
320        let guard = self
321            .operation_lock
322            .try_acquire(model_id, OperationType::Verifying)
323            .await?;
324
325        // Get model and file metadata
326        let model = self
327            .model_repo
328            .get_by_id(model_id)
329            .await
330            .map_err(|e| format!("Failed to get model: {e}"))?;
331
332        let model_files = self
333            .model_files_repo
334            .get_by_model_id(model_id)
335            .await
336            .map_err(|e| format!("Failed to get model files: {e}"))?;
337
338        if model_files.is_empty() {
339            return Err("No model files found for verification".to_string());
340        }
341
342        // Get base directory from model's file path
343        let base_dir = model
344            .file_path
345            .parent()
346            .ok_or_else(|| "Failed to get model directory".to_string())?
347            .to_path_buf();
348
349        let total_shards = model_files.len();
350
351        // Create progress channel
352        let (tx, rx) = mpsc::channel(100);
353
354        // Clone dependencies for the async task
355        let model_files_repo = Arc::clone(&self.model_files_repo);
356        let _model_repo = Arc::clone(&self.model_repo);
357
358        // Spawn verification task
359        let handle = tokio::spawn(async move {
360            // Hold the operation lock for the duration of the verification task
361            let _guard = guard;
362            let mut shard_reports = Vec::new();
363            let mut has_unhealthy = false;
364            let mut has_healthy_or_no_oid = false;
365
366            for (index, file) in model_files.iter().enumerate() {
367                // Send starting progress
368                let _ = tx
369                    .send(VerificationProgress {
370                        model_id,
371                        shard_index: index,
372                        total_shards,
373                        shard_progress: ShardProgress::Starting,
374                    })
375                    .await;
376
377                // Resolve file path relative to base directory
378                let resolved_path = base_dir.join(&file.file_path);
379                let health =
380                    Self::verify_shard(file, &resolved_path, model_id, index, total_shards, &tx)
381                        .await;
382
383                // Update verification timestamp
384                if let Err(e) = model_files_repo
385                    .update_verification_time(file.id, Utc::now())
386                    .await
387                {
388                    tracing::warn!(
389                        model_id = model_id,
390                        file_id = file.id,
391                        error = %e,
392                        "Failed to update verification timestamp"
393                    );
394                }
395
396                // Track overall health
397                match &health {
398                    ShardHealth::Corrupt { .. } | ShardHealth::Missing => has_unhealthy = true,
399                    ShardHealth::Healthy | ShardHealth::NoOid => has_healthy_or_no_oid = true,
400                }
401
402                shard_reports.push(ShardHealthReport {
403                    index,
404                    file_path: file.file_path.clone(),
405                    health: health.clone(),
406                });
407
408                // Send completion progress
409                let _ = tx
410                    .send(VerificationProgress {
411                        model_id,
412                        shard_index: index,
413                        total_shards,
414                        shard_progress: ShardProgress::Completed { health },
415                    })
416                    .await;
417            }
418
419            let overall_health = if has_unhealthy {
420                OverallHealth::Unhealthy
421            } else if !has_healthy_or_no_oid {
422                OverallHealth::Unverifiable
423            } else {
424                OverallHealth::Healthy
425            };
426
427            Ok(VerificationReport {
428                model_id,
429                overall_health,
430                shards: shard_reports,
431                verified_at: Utc::now(),
432            })
433        });
434
435        Ok((rx, handle))
436    }
437
438    /// Verify a single shard by computing its SHA256 and comparing with OID.
439    #[allow(clippy::cognitive_complexity)]
440    async fn verify_shard(
441        file: &ModelFile,
442        resolved_path: &Path,
443        model_id: i64,
444        index: usize,
445        total_shards: usize,
446        tx: &mpsc::Sender<VerificationProgress>,
447    ) -> ShardHealth {
448        // Check if OID is available
449        let Some(ref expected_oid) = file.hf_oid else {
450            return ShardHealth::NoOid;
451        };
452
453        // SHA256 hashes are 64 hex characters. If the stored OID is shorter
454        // (e.g. 40 chars = Git SHA-1), it's the wrong hash type and can't be
455        // used for verification.
456        if expected_oid.len() != 64 {
457            tracing::warn!(
458                model_id = model_id,
459                file_path = %file.file_path,
460                oid_len = expected_oid.len(),
461                "Stored OID is not a SHA256 hash (expected 64 hex chars). \
462                 Re-download or update model metadata to fix."
463            );
464            return ShardHealth::NoOid;
465        }
466
467        let file_path = resolved_path;
468
469        // Check if file exists
470        if !file_path.exists() {
471            return ShardHealth::Missing;
472        }
473
474        // Compute SHA256 in a blocking task
475        let path_owned = file_path.to_path_buf();
476        let tx_clone = tx.clone();
477
478        let result = tokio::task::spawn_blocking(move || -> anyhow::Result<String> {
479            let mut file = File::open(&path_owned)?;
480            let total_bytes = file.metadata()?.len();
481
482            let mut hasher = Sha256::new();
483            let mut buffer = vec![0u8; 1024 * 1024]; // 1MB chunks
484            let mut bytes_processed = 0u64;
485
486            // Initial progress
487            let _ = tx_clone.blocking_send(VerificationProgress {
488                model_id,
489                shard_index: index,
490                total_shards,
491                shard_progress: ShardProgress::Hashing {
492                    percent: 0,
493                    bytes_processed: 0,
494                    total_bytes,
495                },
496            });
497
498            loop {
499                let n = file.read(&mut buffer)?;
500                if n == 0 {
501                    break;
502                }
503
504                hasher.update(&buffer[..n]);
505                bytes_processed += n as u64;
506
507                // Report progress every ~100MB or at end
508                if bytes_processed % (100 * 1024 * 1024) < (1024 * 1024)
509                    || bytes_processed == total_bytes
510                {
511                    #[allow(
512                        clippy::cast_possible_truncation,
513                        clippy::cast_precision_loss,
514                        clippy::cast_sign_loss
515                    )]
516                    let percent = ((bytes_processed as f64 / total_bytes as f64) * 100.0) as u8;
517
518                    let _ = tx_clone.blocking_send(VerificationProgress {
519                        model_id,
520                        shard_index: index,
521                        total_shards,
522                        shard_progress: ShardProgress::Hashing {
523                            percent,
524                            bytes_processed,
525                            total_bytes,
526                        },
527                    });
528                }
529            }
530
531            Ok(format!("{:x}", hasher.finalize()))
532        })
533        .await;
534
535        match result {
536            Ok(Ok(computed_hash)) => {
537                if computed_hash == *expected_oid {
538                    ShardHealth::Healthy
539                } else {
540                    ShardHealth::Corrupt {
541                        expected: expected_oid.clone(),
542                        actual: computed_hash,
543                    }
544                }
545            }
546            Ok(Err(e)) => {
547                tracing::error!(
548                    model_id = model_id,
549                    file_path = %file.file_path,
550                    error = %e,
551                    "Failed to compute hash"
552                );
553                ShardHealth::Missing
554            }
555            Err(e) => {
556                tracing::error!(
557                    model_id = model_id,
558                    file_path = %file.file_path,
559                    error = %e,
560                    "Task panicked during hash computation"
561                );
562                ShardHealth::Missing
563            }
564        }
565    }
566
567    /// Check if updates are available for a model.
568    ///
569    /// Compares local OIDs with remote OIDs from `HuggingFace`.
570    pub async fn check_for_updates(
571        &self,
572        model_id: i64,
573    ) -> Result<UpdateCheckResult, RepositoryError> {
574        // Get model metadata
575        let model = self.model_repo.get_by_id(model_id).await?;
576
577        let Some(ref repo_id) = model.hf_repo_id else {
578            return Ok(UpdateCheckResult {
579                model_id,
580                update_available: false,
581                details: None,
582            });
583        };
584
585        let Some(ref quantization) = model.quantization else {
586            return Ok(UpdateCheckResult {
587                model_id,
588                update_available: false,
589                details: None,
590            });
591        };
592
593        // Get local file metadata
594        let local_files = self
595            .model_files_repo
596            .get_by_model_id(model_id)
597            .await
598            .map_err(|e| RepositoryError::Storage(e.to_string()))?;
599
600        if local_files.is_empty() {
601            return Ok(UpdateCheckResult {
602                model_id,
603                update_available: false,
604                details: None,
605            });
606        }
607
608        // Get remote file metadata from HuggingFace
609        let remote_files = self
610            .hf_client
611            .get_quantization_files(repo_id, quantization)
612            .await
613            .map_err(|e| RepositoryError::Storage(format!("Failed to fetch remote files: {e}")))?;
614
615        // Compare OIDs
616        let mut changes = Vec::new();
617
618        for local_file in &local_files {
619            let Some(ref local_oid) = local_file.hf_oid else {
620                continue;
621            };
622
623            // Find matching remote file by path
624            if let Some(remote_file) = remote_files.iter().find(|f| f.path == local_file.file_path)
625            {
626                if let Some(ref remote_oid) = remote_file.oid {
627                    if local_oid != remote_oid {
628                        let old_oid_str: String = local_oid.clone();
629                        let new_oid_str: String = remote_oid.clone();
630                        #[allow(clippy::cast_sign_loss)]
631                        let index = local_file.file_index as usize;
632                        changes.push(ShardUpdate {
633                            index,
634                            file_path: local_file.file_path.clone(),
635                            old_oid: old_oid_str,
636                            new_oid: new_oid_str,
637                        });
638                    }
639                }
640            }
641        }
642
643        let update_available = !changes.is_empty();
644        let details = if update_available {
645            Some(UpdateDetails {
646                changed_shards: changes.len(),
647                changes,
648            })
649        } else {
650            None
651        };
652
653        Ok(UpdateCheckResult {
654            model_id,
655            update_available,
656            details,
657        })
658    }
659
660    /// Repair a model by re-downloading corrupt or missing shards.
661    ///
662    /// # Arguments
663    ///
664    /// * `model_id` - ID of the model to repair
665    /// * `shard_indices` - Optional list of specific shard indices to repair.
666    ///   If `None`, all unhealthy shards will be repaired.
667    pub async fn repair_model(
668        &self,
669        model_id: i64,
670        shard_indices: Option<Vec<usize>>,
671    ) -> Result<String, String> {
672        // Acquire downloading lock
673        let _guard = self
674            .operation_lock
675            .try_acquire(model_id, OperationType::Downloading)
676            .await?;
677
678        // Get model metadata
679        let model = self
680            .model_repo
681            .get_by_id(model_id)
682            .await
683            .map_err(|e| format!("Failed to get model: {e}"))?;
684
685        let Some(ref repo_id) = model.hf_repo_id else {
686            return Err("Model does not have HuggingFace repository information".to_string());
687        };
688
689        let Some(ref quantization) = model.quantization else {
690            return Err("Model does not have quantization information".to_string());
691        };
692
693        // Get file metadata
694        let model_files = self
695            .model_files_repo
696            .get_by_model_id(model_id)
697            .await
698            .map_err(|e| format!("Failed to get model files: {e}"))?;
699
700        // Get base directory from model's file path
701        let base_dir = model
702            .file_path
703            .parent()
704            .ok_or_else(|| "Failed to get model directory".to_string())?
705            .to_path_buf();
706
707        // Determine which shards to repair
708        let shards_to_repair: Vec<&ModelFile> = if let Some(indices) = shard_indices {
709            #[allow(clippy::cast_sign_loss)]
710            let filter_fn = |f: &&ModelFile| indices.contains(&(f.file_index as usize));
711            model_files.iter().filter(filter_fn).collect()
712        } else {
713            // Verify all shards to find unhealthy ones
714            let mut unhealthy = Vec::new();
715            for file in &model_files {
716                let (tx, _rx) = mpsc::channel(1);
717                let resolved_path = base_dir.join(&file.file_path);
718                let health = Self::verify_shard(file, &resolved_path, model_id, 0, 1, &tx).await;
719                match health {
720                    ShardHealth::Corrupt { .. } | ShardHealth::Missing => {
721                        unhealthy.push(file);
722                    }
723                    _ => {}
724                }
725            }
726            unhealthy
727        };
728
729        if shards_to_repair.is_empty() {
730            return Err("No unhealthy shards found to repair".to_string());
731        }
732
733        // Delete corrupt/missing files
734        for file in &shards_to_repair {
735            let resolved_path = base_dir.join(&file.file_path);
736            if resolved_path.exists() {
737                if let Err(e) = tokio::fs::remove_file(&resolved_path).await {
738                    tracing::warn!(
739                        model_id = model_id,
740                        file_path = %file.file_path,
741                        error = %e,
742                        "Failed to delete corrupt file"
743                    );
744                }
745            }
746        }
747
748        // Trigger re-download
749        let download_id = self
750            .download_trigger
751            .queue_download(repo_id.clone(), Some(quantization.clone()))
752            .await
753            .map_err(|e| format!("Failed to queue download: {e}"))?;
754
755        Ok(download_id)
756    }
757}
758
#[cfg(test)]
mod tests {
    use super::*;

    /// Acquiring a lock on an unlocked model succeeds.
    #[tokio::test]
    async fn test_operation_lock_single_acquire() {
        let lock = ModelOperationLock::new();
        let guard = lock.try_acquire(1, OperationType::Verifying).await;
        assert!(guard.is_ok());
    }

    /// A second acquire on the same model fails while the first guard lives,
    /// regardless of operation type.
    #[tokio::test]
    async fn test_operation_lock_double_acquire_fails() {
        let lock = ModelOperationLock::new();
        let _guard1 = lock.try_acquire(1, OperationType::Verifying).await.unwrap();
        let guard2 = lock.try_acquire(1, OperationType::Downloading).await;
        assert!(guard2.is_err());
    }

    /// Dropping the guard releases the lock so the model can be re-acquired.
    #[tokio::test]
    async fn test_operation_lock_release_on_drop() {
        let lock = ModelOperationLock::new();
        {
            let _guard = lock.try_acquire(1, OperationType::Verifying).await.unwrap();
        }
        // Give the drop task time to complete
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let guard2 = lock.try_acquire(1, OperationType::Downloading).await;
        assert!(guard2.is_ok());
    }

    /// Locks are per-model: different model IDs don't conflict.
    #[tokio::test]
    async fn test_operation_lock_different_models() {
        let lock = ModelOperationLock::new();
        let guard1 = lock.try_acquire(1, OperationType::Verifying).await;
        let guard2 = lock.try_acquire(2, OperationType::Verifying).await;
        assert!(guard1.is_ok());
        assert!(guard2.is_ok());
    }
}