Source code for spine.config.download

"""File download utilities for SPINE configuration system.

This module provides utilities for downloading model weights and other files
referenced in YAML configurations.
"""

import fcntl
import hashlib
import os
import sys
import tempfile
import time
import urllib.parse
import urllib.request
import warnings
from pathlib import Path
from typing import Optional
from urllib.error import HTTPError, URLError


[docs] def get_cache_dir() -> Path: """Get the directory for caching downloaded files. By default, creates a '.cache/weights/' directory in SPINE_BASEDIR or SPINE_PROD_BASEDIR (if running from spine-prod). This ensures downloads are cached centrally regardless of execution directory. Can be overridden with the SPINE_CACHE_DIR environment variable. Returns ------- Path Path to cache directory """ cache_dir = os.environ.get("SPINE_CACHE_DIR") if cache_dir: return Path(cache_dir) # Try SPINE_PROD_BASEDIR first (if running from spine-prod) spine_prod_base = os.environ.get("SPINE_PROD_BASEDIR") if spine_prod_base: return Path(spine_prod_base) / ".cache" / "weights" # Fall back to SPINE_BASEDIR spine_base = os.environ.get("SPINE_BASEDIR") if spine_base: return Path(spine_base) / ".cache" / "weights" # Last resort: use current directory (with warning) warnings.warn( "SPINE_BASEDIR and SPINE_PROD_BASEDIR not set. " "Using current directory for cache. " "Please source configure.sh for proper caching.", UserWarning, stacklevel=2, ) return Path.cwd() / "weights"
def compute_file_hash(filepath: Path, algorithm: str = "sha256") -> str: """Compute hash of a file. Parameters ---------- filepath : Path Path to file to hash algorithm : str, optional Hash algorithm to use (default: sha256) Returns ------- str Hex digest of file hash """ hasher = hashlib.new(algorithm) with open(filepath, "rb") as f: while True: chunk = f.read(8192) if not chunk: break hasher.update(chunk) return hasher.hexdigest() def url_to_filename(url: str) -> str: """Convert URL to a safe filename using hash. Preserves the original extension if present. Parameters ---------- url : str URL to convert Returns ------- str Safe filename """ # Extract extension from URL url_path = urllib.parse.urlparse(url).path ext = Path(url_path).suffix # Hash the URL to create a unique filename url_hash = hashlib.sha256(url.encode()).hexdigest()[:16] # Combine hash with extension return f"{url_hash}{ext}" def download_file( url: str, output_path: Path, expected_hash: Optional[str] = None, hash_algorithm: str = "sha256", ) -> None: """Download a file from a URL with progress reporting. Parameters ---------- url : str URL to download from output_path : Path Path where file should be saved expected_hash : str, optional Expected hash of downloaded file for validation hash_algorithm : str, optional Hash algorithm to use for validation (default: sha256) Raises ------ HTTPError If download fails with HTTP error URLError If download fails with URL error ValueError If downloaded file hash doesn't match expected hash """ print(f"Downloading: {url}") print(f"Destination: {output_path}") try: # Download with progress reporting def progress_hook(count, block_size, total_size): if total_size > 0: percent = min(100, count * block_size * 100 // total_size) sys.stdout.write(f"\rProgress: {percent}% ") sys.stdout.flush() urllib.request.urlretrieve(url, output_path, reporthook=progress_hook) print() # New line after progress except HTTPError as e: raise HTTPError( url, e.code, f"HTTP Error {e.code}: {e.reason}", e.hdrs, e.fp ) from e except URLError as e: raise URLError(f"Failed to download {url}: {e.reason}") from e # Validate hash if provided if expected_hash: actual_hash = compute_file_hash(output_path, hash_algorithm) if actual_hash != expected_hash: output_path.unlink() # Remove corrupted file raise ValueError( f"Downloaded file hash mismatch!\n" f"Expected: {expected_hash}\n" f"Got: {actual_hash}\n" f"File has been removed. Please try again." ) print(f"✓ Hash validated: {actual_hash[:16]}...")
[docs] def download_from_url( url: str, expected_hash: Optional[str] = None, cache_dir: Optional[Path] = None, max_wait_seconds: int = 3600, ) -> str: """Download a file from URL and return the cached path. If the file already exists in cache and passes hash validation (if provided), returns the cached path without re-downloading. This function is safe for concurrent access from multiple processes. It uses file locking to ensure only one process downloads at a time, while others wait. Parameters ---------- url : str URL to download from expected_hash : str, optional Expected SHA256 hash of file for validation cache_dir : Path, optional Directory to cache downloaded files (default: from get_cache_dir()) max_wait_seconds : int, optional Maximum time to wait for lock acquisition in seconds (default: 3600) Returns ------- str Absolute path to cached file Raises ------ HTTPError If download fails with HTTP error URLError If download fails with URL error ValueError If hash validation fails TimeoutError If unable to acquire lock within max_wait_seconds Examples -------- >>> path = download_from_url( ... "https://example.com/model.ckpt", ... expected_hash="abc123..." ... ) >>> print(f"Model downloaded to: {path}") """ if cache_dir is None: cache_dir = get_cache_dir() # Create cache directory if it doesn't exist cache_dir.mkdir(parents=True, exist_ok=True) # Generate filename from URL filename = url_to_filename(url) output_path = cache_dir / filename lock_path = cache_dir / f".{filename}.lock" # Quick check without lock (optimization) if output_path.exists() and _validate_cached_file(output_path, expected_hash): return str(output_path.absolute()) # Acquire lock to prevent concurrent downloads lock_file = None try: lock_file = open(lock_path, "w", encoding="utf-8") # Try to acquire lock with timeout start_time = time.time() while True: try: fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) break # Lock acquired except BlockingIOError as exc: # Lock is held by another process elapsed = time.time() - start_time if elapsed > max_wait_seconds: raise TimeoutError( f"Timeout waiting for download lock after {max_wait_seconds}s. " f"Another process may be downloading {url}" ) from exc # Wait and retry if elapsed < 10: # During first 10 seconds, print immediately print( f"Waiting for another process to finish downloading {filename}..." ) elif int(elapsed) % 60 == 0: # After that, print every minute print(f"Still waiting ({int(elapsed)}s elapsed)...") time.sleep(1) # Double-check if file exists (another process may have downloaded it) if output_path.exists() and _validate_cached_file(output_path, expected_hash): return str(output_path.absolute()) # Download to a temporary file first (atomic operation) temp_fd, temp_path = tempfile.mkstemp( dir=cache_dir, prefix=f".{filename}.", suffix=".tmp" ) temp_path = Path(temp_path) try: os.close(temp_fd) # Close fd, we'll use the path # Download the file download_file(url, temp_path, expected_hash) # Atomically move to final location (overwrites if exists) temp_path.replace(output_path) print(f"✓ Download complete: {output_path}") return str(output_path.absolute()) except Exception: # Clean up temp file on error if temp_path.exists(): temp_path.unlink() raise finally: # Release lock and cleanup if lock_file: try: fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) lock_file.close() except OSError: pass # Best effort cleanup # Remove lock file (best effort - may fail if other processes waiting) try: lock_path.unlink() except OSError: pass
def _validate_cached_file( file_path: Path, expected_hash: Optional[str] = None, ) -> bool: """Validate a cached file exists and has correct hash. Parameters ---------- file_path : Path Path to file to validate expected_hash : str, optional Expected hash to validate against Returns ------- bool True if file is valid, False otherwise """ try: if not file_path.exists(): return False # If no hash expected, just check existence if not expected_hash: print(f"✓ Using cached file: {file_path}") return True # Validate hash actual_hash = compute_file_hash(file_path) if actual_hash == expected_hash: print(f"✓ Using cached file: {file_path}") return True else: print("⚠ Cached file hash mismatch, will re-download...") try: file_path.unlink() except OSError: pass # Best effort cleanup return False except OSError as e: print(f"⚠ Error validating cached file: {e}, will re-download...") return False