
Training API

Complete training pipeline for VAE models with checkpointing and monitoring.

Overview

The training module provides high-level functions for:

  • Complete VAE training workflow
  • Automatic checkpointing
  • Training history visualization
  • Early stopping
  • Learning rate scheduling

Main Training Function

train_vae

The main training function that orchestrates the entire VAE training process.

train_vae

train_vae(
    X_train: ndarray,
    X_test: ndarray,
    y_train: Optional[ndarray] = None,
    y_test: Optional[ndarray] = None,
    config: Optional[VAEConfig] = None,
    save_dir: Optional[Path] = None,
    resume_from: Optional[Path] = None,
    force_cpu: bool = False,
) -> Tuple[nn.Module, Dict[str, list]]

Train a VAE model with full checkpointing support.

Args:

  • X_train: Training data (samples × features) - numpy array or pandas DataFrame
  • X_test: Test data (samples × features) - numpy array or pandas DataFrame
  • y_train: Optional training labels for CVAE
  • y_test: Optional test labels for CVAE
  • config: Training configuration
  • save_dir: Directory to save checkpoints
  • resume_from: Optional checkpoint path to resume training
  • force_cpu: Force CPU usage even if CUDA is available (for compatibility)

Returns: Tuple of (trained_model, training_history)
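
A minimal, hedged quick-start based on the signature above; the random arrays and the save_dir path are illustrative only, and a fuller workflow appears under Complete Example below.

import numpy as np
from pathlib import Path
from renalprog.modeling.train import train_vae

# Illustrative data only: 100 training / 20 test samples, 500 features each.
X_train = np.random.rand(100, 500).astype(np.float32)
X_test = np.random.rand(20, 500).astype(np.float32)

# config=None builds a default VAEConfig sized to X_train.shape[1].
model, history = train_vae(X_train, X_test, save_dir=Path("models/demo_vae"))
print(f"Final validation loss: {history['val_loss'][-1]:.4f}")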

Source code in renalprog/modeling/train.py
def train_vae(
    X_train: np.ndarray,
    X_test: np.ndarray,
    y_train: Optional[np.ndarray] = None,
    y_test: Optional[np.ndarray] = None,
    config: Optional[VAEConfig] = None,
    save_dir: Optional[Path] = None,
    resume_from: Optional[Path] = None,
    force_cpu: bool = False,
) -> Tuple[nn.Module, Dict[str, list]]:
    """Train a VAE model with full checkpointing support.

    Args:
        X_train: Training data (samples × features) - numpy array or pandas DataFrame
        X_test: Test data (samples × features) - numpy array or pandas DataFrame
        y_train: Optional training labels for CVAE
        y_test: Optional test labels for CVAE
        config: Training configuration
        save_dir: Directory to save checkpoints
        resume_from: Optional checkpoint path to resume training
        force_cpu: Force CPU usage even if CUDA is available (for compatibility)

    Returns:
        Tuple of (trained_model, training_history)
    """
    # Convert DataFrames to numpy arrays if needed
    if hasattr(X_train, 'values'):  # Check if it's a DataFrame
        X_train = X_train.values
    if hasattr(X_test, 'values'):  # Check if it's a DataFrame
        X_test = X_test.values
    if y_train is not None and hasattr(y_train, 'values'):
        y_train = y_train.values
    if y_test is not None and hasattr(y_test, 'values'):
        y_test = y_test.values

    if config is None:
        config = VAEConfig()
        config.INPUT_DIM = X_train.shape[1]

    set_seed(config.SEED)

    # Setup save directory
    if save_dir is None:
        timestamp = datetime.now().strftime('%Y%m%d')
        save_dir = Path(f"models/{timestamp}_VAE_KIRC")
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Save config
    save_model_config(config, save_dir / 'config.json')

    # Setup device
    device = get_device(force_cpu=force_cpu)
    logger.info(f"Using device: {device}")

    # Initialize model
    model = VAE(
        input_dim=config.INPUT_DIM,
        mid_dim=config.MID_DIM,
        features=config.LATENT_DIM,
    ).to(device)

    logger.info(f"Model: VAE(input_dim={config.INPUT_DIM}, mid_dim={config.MID_DIM}, latent_dim={config.LATENT_DIM})")
    logger.info(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    # Setup checkpointer
    checkpointer = ModelCheckpointer(
        save_dir=save_dir,
        monitor='val_loss',
        mode='min',
        save_freq=config.CHECKPOINT_FREQ,
        keep_last_n=3,
    )

    # Resume from checkpoint if provided
    start_epoch = 0
    if resume_from is not None:
        checkpoint_info = checkpointer.load_checkpoint(
            resume_from, model, optimizer, device=str(device)
        )
        start_epoch = checkpoint_info['epoch'] + 1
        logger.info(f"Resuming training from epoch {start_epoch}")

    # Create dataloaders
    train_loader = create_dataloader(X_train, y_train, config.BATCH_SIZE, shuffle=True)
    test_loader = create_dataloader(X_test, y_test, config.BATCH_SIZE, shuffle=False)

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_recon_loss': [],
        'train_kl_loss': [],
        'val_recon_loss': [],
        'val_kl_loss': [],
        'beta_schedule': [],  # Track beta values
    }

    # Setup beta annealing schedule
    if config.USE_BETA_ANNEALING:
        beta_schedule = frange_cycle_linear(
            start=config.BETA_START,
            stop=config.BETA,
            n_epoch=config.EPOCHS,
            n_cycle=config.BETA_CYCLES,
            ratio=config.BETA_RATIO
        )
        logger.info(
            f"Using cyclical beta annealing: "
            f"{config.BETA_START} -> {config.BETA} over {config.BETA_CYCLES} cycles"
        )
    else:
        # Constant beta
        beta_schedule = np.ones(config.EPOCHS) * config.BETA
        logger.info(f"Using constant beta: {config.BETA}")

    # Training loop
    logger.info(f"Starting training for {config.EPOCHS} epochs")

    # Add epoch progress bar
    epoch_pbar = tqdm(range(start_epoch, config.EPOCHS), desc='Epochs', position=0)
    for epoch in epoch_pbar:
        # Get beta for this epoch from schedule
        current_beta = beta_schedule[epoch]

        # Train
        train_metrics = train_epoch(model, train_loader, optimizer, device, config, beta=current_beta)

        # Validate
        val_metrics = evaluate_model(model, test_loader, device, config, beta=current_beta)

        # Update history
        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_recon_loss'].append(train_metrics['recon_loss'])
        history['train_kl_loss'].append(train_metrics['kl_loss'])
        history['val_recon_loss'].append(val_metrics['recon_loss'])
        history['val_kl_loss'].append(val_metrics['kl_loss'])
        history['beta_schedule'].append(float(current_beta))

        # Update epoch progress bar
        epoch_pbar.set_postfix({
            'train_loss': f"{train_metrics['loss']:.4f}",
            'val_loss': f"{val_metrics['loss']:.4f}",
            'beta': f"{current_beta:.3f}"
        })

        # Log progress
        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                f"Epoch {epoch+1}/{config.EPOCHS} - "
                f"train_loss: {train_metrics['loss']:.4f}, "
                f"val_loss: {val_metrics['loss']:.4f}"
            )

        # Combine metrics for checkpointing
        current_metrics = {
            'train_loss': train_metrics['loss'],
            'val_loss': val_metrics['loss'],
            'train_recon': train_metrics['recon_loss'],
            'train_kl': train_metrics['kl_loss'],
            'val_recon': val_metrics['recon_loss'],
            'val_kl': val_metrics['kl_loss'],
        }

        # # Save periodic checkpoint
        # if checkpointer.should_save_checkpoint(epoch):
        #     checkpointer.save_checkpoint(
        #         epoch, model, optimizer, current_metrics, config
        #     )

    # Save final model
    checkpointer.save_checkpoint(
        config.EPOCHS - 1, model, optimizer, current_metrics, config, is_final=True
    )

    logger.info("Training complete!")

    return model, history

Key Features

Automatic Checkpointing

The training function automatically saves:

  • Model state dict
  • Optimizer state
  • Training history
  • Configuration parameters

Checkpoints are saved when:

  • Validation loss improves (best model)
  • At regular intervals (every save_freq epochs, when periodic saving is enabled)
  • After training completes (final model)

A sketch of the resulting run directory follows.
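
Assuming the filenames used by ModelCheckpointer further down this page (best_model.pth, checkpoint_epoch_XXXX.pth, final_model.pth) and the config.json written by train_vae, a hedged way to inspect a run directory:

from pathlib import Path

# The directory path is illustrative; the filenames come from the
# checkpointing source documented below.
for p in sorted(Path("models/demo_vae").iterdir()):
    print(p.name)
# Typical entries: best_model.pth, checkpoint_epoch_0010.pth, ...,
# config.json, final_model.pth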

Early Stopping

Training stops early if the validation loss does not improve for a configured number of epochs (the patience). This prevents overfitting and saves computation time; a generic sketch of the pattern follows.
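
The loop below is a minimal, generic sketch of that patience pattern; it is not taken from the renalprog source, and the variable names and dummy losses are purely illustrative.

# Dummy per-epoch validation losses for illustration only.
val_losses = [0.90, 0.80, 0.75, 0.76, 0.77, 0.78, 0.79] * 20

patience = 20
best_val = float("inf")
epochs_without_improvement = 0

for epoch, val_loss in enumerate(val_losses):
    if val_loss < best_val:
        best_val = val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
    if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch}")
        break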

Learning Rate Scheduling

When learning rate scheduling is enabled, the learning rate is reduced when the validation loss plateaus:

scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    verbose=True
)
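
A self-contained sketch of how such a scheduler is stepped on the monitored metric each epoch; the stand-in model and the constant dummy loss are illustrative, not renalprog code.

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(16, 16)                       # stand-in model
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

for epoch in range(60):
    val_loss = 1.0                              # dummy, permanently plateaued
    scheduler.step(val_loss)                    # halves the LR after each 10-epoch plateau
print(optimizer.param_groups[0]['lr'])          # well below the initial 1e-3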

Cyclical KL Annealing

The KL divergence weight β is gradually increased using cyclical annealing to prevent posterior collapse:

frange_cycle_linear

frange_cycle_linear(
    start: float,
    stop: float,
    n_epoch: int,
    n_cycle: int = 4,
    ratio: float = 0.5,
) -> np.ndarray

Generate a linear cyclical schedule for beta hyperparameter.

This creates a cyclical annealing schedule where beta increases linearly from start to stop over a portion of each cycle (controlled by ratio), then stays constant at stop for the remainder of the cycle.

Args:

  • start: Initial value of beta (typically 0.0)
  • stop: Final/maximum value of beta (typically 1.0)
  • n_epoch: Total number of epochs
  • n_cycle: Number of cycles (default: 4)
  • ratio: Ratio of cycle spent increasing beta (default: 0.5); 0.5 means half the cycle increasing and half constant, 1.0 means the entire cycle increasing

Returns: Array of beta values for each epoch

Example:

>>> # 3 cycles over 300 epochs, beta increases from 0 to 1 over first half of each cycle
>>> beta_schedule = frange_cycle_linear(0.0, 1.0, 300, n_cycle=3, ratio=0.5)
>>> # Epoch 0-50: beta increases 0.0 -> 1.0
>>> # Epoch 50-100: beta stays at 1.0
>>> # Epoch 100-150: beta increases 0.0 -> 1.0
>>> # Epoch 150-200: beta stays at 1.0
>>> # Epoch 200-250: beta increases 0.0 -> 1.0
>>> # Epoch 250-300: beta stays at 1.0

Source code in renalprog/modeling/train.py
def frange_cycle_linear(
    start: float,
    stop: float,
    n_epoch: int,
    n_cycle: int = 4,
    ratio: float = 0.5
) -> np.ndarray:
    """
    Generate a linear cyclical schedule for beta hyperparameter.

    This creates a cyclical annealing schedule where beta increases linearly
    from start to stop over a portion of each cycle (controlled by ratio),
    then stays constant at stop for the remainder of the cycle.

    Args:
        start: Initial value of beta (typically 0.0)
        stop: Final/maximum value of beta (typically 1.0)
        n_epoch: Total number of epochs
        n_cycle: Number of cycles (default: 4)
        ratio: Ratio of cycle spent increasing beta (default: 0.5)
               - 0.5 means half cycle increasing, half constant
               - 1.0 means entire cycle increasing

    Returns:
        Array of beta values for each epoch

    Example:
        >>> # 3 cycles over 300 epochs, beta increases from 0 to 1 over first half of each cycle
        >>> beta_schedule = frange_cycle_linear(0.0, 1.0, 300, n_cycle=3, ratio=0.5)
        >>> # Epoch 0-50: beta increases 0.0 -> 1.0
        >>> # Epoch 50-100: beta stays at 1.0
        >>> # Epoch 100-150: beta increases 0.0 -> 1.0
        >>> # Epoch 150-200: beta stays at 1.0
        >>> # Epoch 200-250: beta increases 0.0 -> 1.0
        >>> # Epoch 250-300: beta stays at 1.0
    """
    L = np.ones(n_epoch) * stop  # Initialize all to stop value
    period = n_epoch / n_cycle
    step = (stop - start) / (period * ratio)  # Linear schedule

    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and (int(i + c * period) < n_epoch):
            L[int(i + c * period)] = v
            v += step
            i += 1

    return L
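
A short sketch that reproduces the docstring example and spot-checks a few values of the returned schedule (the printed values assume the implementation shown above):

from renalprog.modeling.train import frange_cycle_linear

schedule = frange_cycle_linear(0.0, 1.0, n_epoch=300, n_cycle=3, ratio=0.5)
print(schedule.shape)    # (300,)
print(schedule[:3])      # ramp-up of the first cycle: [0.   0.02 0.04]
print(schedule[75])      # constant half of the first cycle: 1.0
print(schedule[100])     # start of the second cycle: 0.0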

Training History

The training function returns a dictionary with:

Key Description
train_loss Training loss per epoch
val_loss Validation loss per epoch
train_recon_loss Training reconstruction loss per epoch
val_recon_loss Validation reconstruction loss per epoch
train_kl_loss Training KL divergence per epoch
val_kl_loss Validation KL divergence per epoch
beta_schedule Beta value used at each epoch
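
As a hedged sketch, the loss curves can be plotted straight from this dictionary with matplotlib; renalprog's packaged plot_training_history (used in the Complete Example below) is the built-in alternative. The plot_losses helper here is illustrative, not part of the package.

import matplotlib.pyplot as plt

def plot_losses(history, out_path="loss_curves.png"):
    """Plot total train/validation loss from the history returned by train_vae."""
    fig, ax = plt.subplots()
    ax.plot(history["train_loss"], label="train")
    ax.plot(history["val_loss"], label="validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    fig.savefig(out_path, dpi=150)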

Complete Example

import pandas as pd
from pathlib import Path
from renalprog.modeling.train import train_vae
from renalprog.plots import plot_training_history
from renalprog.utils import set_seed, configure_logging

# Configure
configure_logging()
set_seed(42)

# Load data
train_expr = pd.read_csv("data/interim/split/train_expression.tsv", sep="\t", index_col=0)
test_expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)

# Train VAE (config=None builds a default VAEConfig sized to the input data)
model, history = train_vae(
    X_train=train_expr,
    X_test=test_expr,
    save_dir=Path("models/my_vae"),
)

# Plot results
plot_training_history(
    history,
    output_path=Path("reports/figures/training_history.png")
)

# Use the trained model for inference
import torch
device = next(model.parameters()).device
model.eval()
with torch.no_grad():
    reconstruction, mu, log_var, z = model(
        torch.FloatTensor(test_expr.values).to(device)
    )

print(f"Best validation loss: {min(history['val_loss']):.4f}")
print(f"Final beta: {history['beta_schedule'][-1]:.3f}")

Checkpointing API

For manual checkpoint management:

checkpointing

Model checkpointing utilities for saving and loading training state.

Classes

ModelCheckpointer

ModelCheckpointer(
    save_dir: Path,
    monitor: str = "val_loss",
    mode: str = "min",
    save_freq: int = 0,
    keep_last_n: int = 3,
)

Handles saving and loading model checkpoints during training.

Features:

  • Save best model based on validation metric
  • Save checkpoints every N epochs
  • Save final model after training
  • Save training history and configuration
  • Resume training from checkpoint

Attributes:

  • save_dir: Directory to save checkpoints
  • monitor: Metric to monitor ('loss', 'val_loss', etc.)
  • mode: 'min' for loss, 'max' for accuracy
  • save_freq: Save checkpoint every N epochs (0 = only best)
  • keep_last_n: Keep only last N checkpoints (0 = keep all)

Initialize checkpointer.

Args:

  • save_dir: Directory to save checkpoints
  • monitor: Metric name to monitor
  • mode: 'min' to minimize metric, 'max' to maximize
  • save_freq: Save every N epochs (0 = only save best)
  • keep_last_n: Keep only N most recent checkpoints (0 = all)

Source code in renalprog/modeling/checkpointing.py
def __init__(
    self,
    save_dir: Path,
    monitor: str = 'val_loss',
    mode: str = 'min',
    save_freq: int = 0,
    keep_last_n: int = 3,
):
    """Initialize checkpointer.

    Args:
        save_dir: Directory to save checkpoints
        monitor: Metric name to monitor
        mode: 'min' to minimize metric, 'max' to maximize
        save_freq: Save every N epochs (0 = only save best)
        keep_last_n: Keep only N most recent checkpoints (0 = all)
    """
    self.save_dir = Path(save_dir)
    self.save_dir.mkdir(parents=True, exist_ok=True)

    self.monitor = monitor
    self.mode = mode
    self.save_freq = save_freq
    self.keep_last_n = keep_last_n

    # Track best metric
    self.best_metric = float('inf') if mode == 'min' else float('-inf')
    self.best_epoch = 0

    # Track saved checkpoints for cleanup
    self.checkpoint_history = []

    logger.info(f"ModelCheckpointer initialized: {save_dir}")
    logger.info(f"Monitoring: {monitor} ({mode})")
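
A hedged construction example; the import path follows the source location noted above, and the save_dir is illustrative.

from pathlib import Path
from renalprog.modeling.checkpointing import ModelCheckpointer

# Keep the 3 most recent periodic checkpoints, save every 10 epochs,
# and track the best model by validation loss.
checkpointer = ModelCheckpointer(
    save_dir=Path("models/demo_vae"),
    monitor="val_loss",
    mode="min",
    save_freq=10,
    keep_last_n=3,
)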

Functions

get_best_checkpoint_path

get_best_checkpoint_path() -> Optional[Path]

Get path to best model checkpoint.

Returns: Path to best model, or None if not saved yet

Source code in renalprog/modeling/checkpointing.py
def get_best_checkpoint_path(self) -> Optional[Path]:
    """Get path to best model checkpoint.

    Returns:
        Path to best model, or None if not saved yet
    """
    best_path = self.save_dir / 'best_model.pth'
    return best_path if best_path.exists() else None

get_final_checkpoint_path

get_final_checkpoint_path() -> Optional[Path]

Get path to final model checkpoint.

Returns: Path to final model, or None if not saved yet

Source code in renalprog/modeling/checkpointing.py
def get_final_checkpoint_path(self) -> Optional[Path]:
    """Get path to final model checkpoint.

    Returns:
        Path to final model, or None if not saved yet
    """
    final_path = self.save_dir / 'final_model.pth'
    return final_path if final_path.exists() else None

load_checkpoint

load_checkpoint(
    checkpoint_path: Path,
    model: Module,
    optimizer: Optional[Optimizer] = None,
    device: str = "cpu",
) -> Dict[str, Any]

Load a checkpoint and restore model state.

Args:

  • checkpoint_path: Path to checkpoint file
  • model: Model to load state into
  • optimizer: Optional optimizer to restore state
  • device: Device to map checkpoint to

Returns: Dictionary with checkpoint information (epoch, metrics, config)

Source code in renalprog/modeling/checkpointing.py
def load_checkpoint(
    self,
    checkpoint_path: Path,
    model: nn.Module,
    optimizer: Optional[Optimizer] = None,
    device: str = 'cpu',
) -> Dict[str, Any]:
    """Load a checkpoint and restore model state.

    Args:
        checkpoint_path: Path to checkpoint file
        model: Model to load state into
        optimizer: Optional optimizer to restore state
        device: Device to map checkpoint to

    Returns:
        Dictionary with checkpoint information (epoch, metrics, config)
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Restore model state
    model.load_state_dict(checkpoint['model_state_dict'])
    logger.info(f"Loaded model state from epoch {checkpoint['epoch']}")

    # Restore optimizer state if provided
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        logger.info("Loaded optimizer state")

    # Return checkpoint info
    return {
        'epoch': checkpoint['epoch'],
        'metrics': checkpoint.get('metrics', {}),
        'config': checkpoint.get('config', {}),
        'best_metric': checkpoint.get('best_metric', self.best_metric),
    }
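
A hedged sketch of restoring the best checkpoint for inference; checkpointer is the ModelCheckpointer constructed above, and model must be a VAE instance with the same architecture as the saved checkpoint.

best_path = checkpointer.get_best_checkpoint_path()
if best_path is not None:
    info = checkpointer.load_checkpoint(best_path, model, device="cpu")
    print(f"Restored epoch {info['epoch']} with metrics {info['metrics']}")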

save_checkpoint

save_checkpoint(
    epoch: int,
    model: Module,
    optimizer: Optimizer,
    metrics: Dict[str, float],
    config: Any,
    is_best: bool = False,
    is_final: bool = False,
) -> None

Save a training checkpoint.

Args:

  • epoch: Current epoch number
  • model: PyTorch model to save
  • optimizer: Optimizer state to save
  • metrics: Dictionary of current metrics
  • config: Training configuration object
  • is_best: Whether this is the best model so far
  • is_final: Whether this is the final model

Source code in renalprog/modeling/checkpointing.py
def save_checkpoint(
    self,
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    metrics: Dict[str, float],
    config: Any,
    is_best: bool = False,
    is_final: bool = False,
) -> None:
    """Save a training checkpoint.

    Args:
        epoch: Current epoch number
        model: PyTorch model to save
        optimizer: Optimizer state to save
        metrics: Dictionary of current metrics
        config: Training configuration object
        is_best: Whether this is the best model so far
        is_final: Whether this is the final model
    """
    # Create checkpoint dictionary
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
        'config': self._config_to_dict(config),
        'best_metric': self.best_metric,
        'monitor': self.monitor,
    }

    # Determine checkpoint filename
    if is_final:
        filename = 'final_model.pth'
    elif is_best:
        filename = 'best_model.pth'
    else:
        filename = f'checkpoint_epoch_{epoch:04d}.pth'

    checkpoint_path = self.save_dir / filename

    # Save checkpoint
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Saved checkpoint: {checkpoint_path}")

    # Track for cleanup
    if not is_best and not is_final:
        self.checkpoint_history.append(checkpoint_path)
        self._cleanup_old_checkpoints()

    # Save training history as JSON
    if is_best or is_final:
        self._save_history(metrics, epoch)

should_save_checkpoint

should_save_checkpoint(epoch: int) -> bool

Determine if checkpoint should be saved this epoch.

Args:

  • epoch: Current epoch number

Returns: True if checkpoint should be saved

Source code in renalprog/modeling/checkpointing.py
def should_save_checkpoint(self, epoch: int) -> bool:
    """Determine if checkpoint should be saved this epoch.

    Args:
        epoch: Current epoch number

    Returns:
        True if checkpoint should be saved
    """
    if self.save_freq == 0:
        return False
    return epoch % self.save_freq == 0

update_best

update_best(epoch: int, metric_value: float) -> bool

Check if current metric is the best and update if so.

Args:

  • epoch: Current epoch number
  • metric_value: Current metric value

Returns: True if this is a new best, False otherwise

Source code in renalprog/modeling/checkpointing.py
def update_best(self, epoch: int, metric_value: float) -> bool:
    """Check if current metric is the best and update if so.

    Args:
        epoch: Current epoch number
        metric_value: Current metric value

    Returns:
        True if this is a new best, False otherwise
    """
    is_better = (
        (self.mode == 'min' and metric_value < self.best_metric) or
        (self.mode == 'max' and metric_value > self.best_metric)
    )

    if is_better:
        self.best_metric = metric_value
        self.best_epoch = epoch
        logger.info(
            f"New best {self.monitor}: {metric_value:.6f} "
            f"at epoch {epoch}"
        )
        return True
    return False
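
Putting the methods together, a hedged sketch of manual checkpoint management in a custom loop; train_epoch, evaluate_model, the dataloaders, device and config mirror the objects used by train_vae above and are assumed to already exist, as is the checkpointer constructed earlier.

for epoch in range(config.EPOCHS):
    train_metrics = train_epoch(model, train_loader, optimizer, device, config, beta=1.0)
    val_metrics = evaluate_model(model, test_loader, device, config, beta=1.0)
    metrics = {"train_loss": train_metrics["loss"], "val_loss": val_metrics["loss"]}

    # Save the best model whenever the monitored metric improves.
    if checkpointer.update_best(epoch, metrics["val_loss"]):
        checkpointer.save_checkpoint(epoch, model, optimizer, metrics, config, is_best=True)

    # Periodic checkpoints according to save_freq.
    if checkpointer.should_save_checkpoint(epoch):
        checkpointer.save_checkpoint(epoch, model, optimizer, metrics, config)

# Final model after training.
checkpointer.save_checkpoint(config.EPOCHS - 1, model, optimizer, metrics, config, is_final=True)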

Functions

load_model_config

load_model_config(config_path: Path) -> Dict[str, Any]

Load model configuration from JSON file.

Args:

  • config_path: Path to JSON config file

Returns: Dictionary with configuration

Source code in renalprog/modeling/checkpointing.py
def load_model_config(config_path: Path) -> Dict[str, Any]:
    """Load model configuration from JSON file.

    Args:
        config_path: Path to JSON config file

    Returns:
        Dictionary with configuration
    """
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")

    with open(config_path, 'r') as f:
        config = json.load(f)

    logger.info(f"Loaded config: {config_path}")
    return config

save_model_config

save_model_config(config: Any, save_path: Path) -> None

Save model configuration to JSON file.

Args:

  • config: Configuration object
  • save_path: Path to save JSON file

Source code in renalprog/modeling/checkpointing.py
def save_model_config(config: Any, save_path: Path) -> None:
    """Save model configuration to JSON file.

    Args:
        config: Configuration object
        save_path: Path to save JSON file
    """
    config_dict = {}
    if hasattr(config, '__dict__'):
        config_dict = {
            k: v for k, v in config.__dict__.items()
            if not k.startswith('_') and not callable(v)
        }

    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, 'w') as f:
        json.dump(config_dict, f, indent=2)

    logger.info(f"Saved config: {save_path}")
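
A hedged round-trip sketch; a SimpleNamespace stands in for the real VAEConfig (whose module path is not shown on this page), since any object with plain attributes works here.

from pathlib import Path
from types import SimpleNamespace
from renalprog.modeling.checkpointing import save_model_config, load_model_config

# Illustrative configuration object; VAEConfig would be used in practice.
config = SimpleNamespace(INPUT_DIM=500, MID_DIM=1024, LATENT_DIM=128, LEARNING_RATE=1e-3)

path = Path("models/demo_vae/config.json")
save_model_config(config, path)      # writes the non-private attributes as JSON
restored = load_model_config(path)   # returns a plain dict
print(restored["LATENT_DIM"])        # 128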

save_checkpoint

Save model checkpoint with metadata.

load_checkpoint

Load model from checkpoint.

See Also