Skip to content

Models API

The modeling module provides neural network architectures and training functions for variational autoencoders (VAEs).

Overview

This module includes:

  • VAE architectures (standard, conditional, simple)
  • Training and evaluation functions
  • Loss functions (reconstruction, KL divergence)
  • Checkpoint management
  • Post-processing networks

Model Architectures

VAE

Standard Variational Autoencoder with encoder-decoder architecture.

VAE

VAE(input_dim: int, mid_dim: int, features: int, output_layer=nn.ReLU)

Bases: Module

Variational Autoencoder (VAE).

Standard VAE implementation with encoder-decoder architecture and reparameterization trick for sampling from the latent space.

Args: input_dim: Dimension of input data (number of genes) mid_dim: Dimension of hidden layer features: Dimension of latent space output_layer: Output activation function (default: nn.ReLU)

Source code in renalprog/modeling/train.py
def __init__(self, input_dim: int, mid_dim: int, features: int,
             output_layer=nn.ReLU):
    """Build the encoder and decoder stacks of the VAE.

    Args:
        input_dim: Width of the input and of the reconstruction.
        mid_dim: Width of the single hidden layer on each side.
        features: Size of the latent space.
        output_layer: Activation class; an instance becomes the last
            decoder layer.
    """
    super().__init__()
    self.input_dim, self.mid_dim, self.features = input_dim, mid_dim, features
    # The class (not an instance) is stored; the decoder gets its own instance.
    self.output_layer = output_layer

    # The encoder emits 2 * features values per sample; forward() splits
    # them into the mean and the log-variance.
    encoder_layers = [
        nn.Linear(input_dim, mid_dim),
        nn.ReLU(),
        nn.Linear(mid_dim, 2 * features),
    ]
    self.encoder = nn.Sequential(*encoder_layers)

    # The decoder maps a latent vector back to input space.
    decoder_layers = [
        nn.Linear(features, mid_dim),
        nn.ReLU(),
        nn.Linear(mid_dim, input_dim),
        output_layer(),
    ]
    self.decoder = nn.Sequential(*decoder_layers)

Functions

forward

forward(
    x: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Forward pass through VAE.

Args: x: Input data (batch_size, input_dim)

Returns: Tuple of (reconstruction, mu, log_var, z)

Source code in renalprog/modeling/train.py
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
                                              torch.Tensor, torch.Tensor]:
    """Encode ``x``, draw a latent vector and decode it.

    Args:
        x: Input batch fed to the encoder, (batch_size, input_dim).

    Returns:
        Tuple of (reconstruction, mu, log_var, z).
    """
    # Regroup the encoder output as (batch, 2, features): slot 0 along
    # dim 1 holds the means, slot 1 the log-variances.
    mu, log_var = self.encoder(x).view(-1, 2, self.features).unbind(dim=1)

    latent = self.reparametrize(mu, log_var)
    return self.decoder(latent), mu, log_var, latent

reparametrize

reparametrize(mu: Tensor, log_var: Tensor) -> torch.Tensor

Reparameterization trick: sample from N(mu, var) using N(0,1).

Args: mu: Mean of the latent distribution log_var: Log variance of the latent distribution

Returns: Sampled latent vector

Source code in renalprog/modeling/train.py
def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Draw a latent sample with the reparameterization trick.

    Args:
        mu: Mean of the latent distribution.
        log_var: Log variance of the latent distribution.

    Returns:
        ``mu + eps * std`` with ``eps ~ N(0, 1)`` in training mode;
        ``mu`` itself in evaluation mode.
    """
    if not self.training:
        # Evaluation is deterministic: use the mean as the latent code.
        return mu

    std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
    noise = torch.randn_like(std)
    return mu + noise * std

Example Usage:

import torch
from renalprog.modeling.train import VAE

# Create VAE model (VAE takes no dropout argument)
model = VAE(
    input_dim=20000,  # Number of genes
    mid_dim=1024,     # Hidden layer size
    features=128,     # Latent dimension
)

# Forward pass
x = torch.randn(32, 20000)  # Batch of gene expression
reconstruction, mu, log_var, z = model(x)

CVAE

Conditional VAE that incorporates clinical covariates.

CVAE

CVAE(
    input_dim: int,
    mid_dim: int,
    features: int,
    num_classes: int,
    output_layer=nn.ReLU,
)

Bases: VAE

Conditional Variational Autoencoder.

VAE that conditions on additional information (e.g., clinical data).

Args: input_dim: Dimension of input data mid_dim: Dimension of hidden layer features: Dimension of latent space num_classes: Number of condition classes output_layer: Output activation function

Source code in renalprog/modeling/train.py
def __init__(self, input_dim: int, mid_dim: int, features: int,
             num_classes: int, output_layer=nn.ReLU):
    """Build encoder and decoder stacks that also take a condition vector.

    Args:
        input_dim: Width of the input and of the reconstruction.
        mid_dim: Width of the single hidden layer on each side.
        features: Size of the latent space.
        num_classes: Width of the condition vector concatenated to the
            encoder and decoder inputs.
        output_layer: Activation class applied to the decoder output in
            ``forward``.
    """
    # The parent constructor stores the dimensions and builds unconditional
    # encoder/decoder stacks; both are replaced below.
    super().__init__(input_dim, mid_dim, features, output_layer)
    self.num_classes = num_classes

    encoder_in = input_dim + num_classes
    decoder_in = features + num_classes

    # Encoder input is the data concatenated with the condition.
    encoder_layers = [
        nn.Linear(encoder_in, mid_dim),
        nn.ReLU(),
        nn.Linear(mid_dim, 2 * features),
    ]
    self.encoder = nn.Sequential(*encoder_layers)

    # Decoder input is the latent vector concatenated with the condition.
    # There is no trailing activation here; forward() applies it.
    decoder_layers = [
        nn.Linear(decoder_in, mid_dim),
        nn.ReLU(),
        nn.Linear(mid_dim, input_dim),
    ]
    self.decoder = nn.Sequential(*decoder_layers)

Functions

forward

forward(
    x: Tensor, condition: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Forward pass through CVAE.

Args: x: Input data condition: Conditioning information (one-hot encoded)

Returns: Tuple of (reconstruction, mu, log_var, z)

Source code in renalprog/modeling/train.py
def forward(
    self, x: torch.Tensor, condition: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the conditional VAE on ``x`` given ``condition``.

    Args:
        x: Input data.
        condition: Conditioning information (one-hot encoded), joined to
            ``x`` and to the latent vector along dim 1.

    Returns:
        Tuple of (reconstruction, mu, log_var, z).
    """
    encoder_input = torch.cat([x, condition], dim=1)

    # Regroup the encoder output as (batch, 2, features): slot 0 along
    # dim 1 holds the means, slot 1 the log-variances.
    mu, log_var = (
        self.encoder(encoder_input).view(-1, 2, self.features).unbind(dim=1)
    )

    latent = self.reparametrize(mu, log_var)

    decoder_input = torch.cat([latent, condition], dim=1)
    decoded = self.decoder(decoder_input)
    # output_layer is stored as a class, so a fresh activation instance is
    # created on every call.
    activation = self.output_layer()

    return activation(decoded), mu, log_var, latent

Example Usage:

import torch
from renalprog.modeling.train import CVAE

# Create conditional VAE
model = CVAE(
    input_dim=20000,
    mid_dim=1024,
    features=128,
    num_classes=2,  # e.g., one-hot encoded stage
)

# Forward pass with condition
x = torch.randn(32, 20000)
# One-hot condition of shape (batch_size, num_classes)
stage = torch.randint(0, 2, (32,))
condition = torch.nn.functional.one_hot(stage, num_classes=2).float()
reconstruction, mu, log_var, z = model(x, condition)

AE

Simplified autoencoder without variational component.

AE

AE(input_dim: int, mid_dim: int, features: int, output_layer=nn.ReLU)

Bases: Module

Standard Autoencoder (without variational inference).

Similar architecture to VAE but without reparameterization trick.

Args: input_dim: Dimension of input data mid_dim: Dimension of hidden layer features: Dimension of latent space output_layer: Output activation function

Source code in renalprog/modeling/train.py
def __init__(self, input_dim: int, mid_dim: int, features: int,
             output_layer=nn.ReLU):
    """Build the encoder and decoder stacks of the plain autoencoder.

    Args:
        input_dim: Width of the input and of the reconstruction.
        mid_dim: Width of the single hidden layer on each side.
        features: Size of the latent code.
        output_layer: Activation class; an instance becomes the last
            decoder layer.
    """
    super().__init__()
    self.input_dim, self.mid_dim, self.features = input_dim, mid_dim, features
    self.output_layer = output_layer

    # The encoder emits the latent code directly; there is no mean/variance split.
    encoder_layers = [
        nn.Linear(input_dim, mid_dim),
        nn.ReLU(),
        nn.Linear(mid_dim, features),
    ]
    self.encoder = nn.Sequential(*encoder_layers)

    decoder_layers = [
        nn.Linear(features, mid_dim),
        nn.ReLU(),
        nn.Linear(mid_dim, input_dim),
        output_layer(),
    ]
    self.decoder = nn.Sequential(*decoder_layers)

Functions

forward

forward(x: Tensor) -> Tuple[torch.Tensor, None, None, torch.Tensor]

Forward pass through AE.

Args: x: Input data

Returns: Tuple of (reconstruction, None, None, z). The two None values stand in for mu and log_var so the return shape matches VAE.

Source code in renalprog/modeling/train.py
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None, None, torch.Tensor]:
    """Encode ``x`` to a latent code and decode it back.

    Args:
        x: Input data.

    Returns:
        Tuple of (reconstruction, None, None, z). The two ``None`` entries
        occupy the mu / log_var positions of the VAE return value.
    """
    latent = self.encoder(x)
    return self.decoder(latent), None, None, latent

Loss Functions

vae_loss

Complete VAE loss combining reconstruction and KL divergence.

vae_loss

vae_loss(
    reconstruction: Tensor,
    x: Tensor,
    mu: Tensor,
    log_var: Tensor,
    beta: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Calculate VAE loss: reconstruction loss + KL divergence.

Args: reconstruction: Reconstructed output x: Original input mu: Mean of latent distribution log_var: Log variance of latent distribution beta: Weight for KL divergence term (beta-VAE)

Returns: Tuple of (total_loss, reconstruction_loss, kl_divergence)

Source code in renalprog/modeling/train.py
def vae_loss(reconstruction: torch.Tensor, x: torch.Tensor,
             mu: torch.Tensor, log_var: torch.Tensor,
             beta: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the beta-weighted VAE objective.

    Args:
        reconstruction: Reconstructed output.
        x: Original input.
        mu: Mean of the latent distribution.
        log_var: Log variance of the latent distribution.
        beta: Weight on the KL term (beta-VAE).

    Returns:
        Tuple of (total_loss, reconstruction_loss, kl_divergence). Both
        components are summed over every element, not averaged.
    """
    # Summed squared error between reconstruction and input.
    recon_term = nn.functional.mse_loss(reconstruction, x, reduction='sum')

    # Closed-form KL against N(0, I): -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    kl_elements = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * torch.sum(kl_elements)

    return recon_term + beta * kl_term, recon_term, kl_term

reconstruction_loss

MSE-based reconstruction loss.

reconstruction_loss

reconstruction_loss(
    reconstruction: Tensor, x: Tensor, reduction: str = "sum"
) -> torch.Tensor

Calculate reconstruction loss (MSE).

Args: reconstruction: Reconstructed output x: Original input reduction: Reduction method ('sum' or 'mean')

Returns: Reconstruction loss

Source code in renalprog/modeling/train.py
def reconstruction_loss(reconstruction: torch.Tensor, x: torch.Tensor,
                       reduction: str = 'sum') -> torch.Tensor:
    """Mean-squared-error reconstruction loss.

    Args:
        reconstruction: Reconstructed output.
        x: Original input.
        reduction: Passed straight to ``mse_loss`` ('sum' or 'mean').

    Returns:
        The reduced squared error between ``reconstruction`` and ``x``.
    """
    mse = nn.functional.mse_loss
    return mse(reconstruction, x, reduction=reduction)

kl_divergence

KL divergence between latent distribution and prior.

kl_divergence

kl_divergence(mu: Tensor, log_var: Tensor) -> torch.Tensor

Calculate KL divergence between approximate posterior and prior.

Args: mu: Mean of approximate posterior log_var: Log variance of approximate posterior

Returns: KL divergence

Source code in renalprog/modeling/train.py
def kl_divergence(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """KL divergence of the approximate posterior from the prior.

    Args:
        mu: Mean of the approximate posterior.
        log_var: Log variance of the approximate posterior.

    Returns:
        ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``, summed over
        every element.
    """
    per_element = 1 + log_var - mu.pow(2) - log_var.exp()
    return -0.5 * torch.sum(per_element)

Training Functions

train_vae

Main training function for VAE models.

train_vae

train_vae(
    X_train: ndarray,
    X_test: ndarray,
    y_train: Optional[ndarray] = None,
    y_test: Optional[ndarray] = None,
    config: Optional[VAEConfig] = None,
    save_dir: Optional[Path] = None,
    resume_from: Optional[Path] = None,
    force_cpu: bool = False,
) -> Tuple[nn.Module, Dict[str, list]]

Train a VAE model with full checkpointing support.

Args: X_train: Training data (samples × features) - numpy array or pandas DataFrame X_test: Test data (samples × features) - numpy array or pandas DataFrame y_train: Optional training labels for CVAE y_test: Optional test labels for CVAE config: Training configuration save_dir: Directory to save checkpoints resume_from: Optional checkpoint path to resume training force_cpu: Force CPU usage even if CUDA is available (for compatibility)

Returns: Tuple of (trained_model, training_history)

Source code in renalprog/modeling/train.py
def _combine_metrics(train_metrics: Dict[str, float],
                     val_metrics: Dict[str, float]) -> Dict[str, float]:
    """Merge per-split metrics into the flat dict handed to the checkpointer."""
    return {
        'train_loss': train_metrics['loss'],
        'val_loss': val_metrics['loss'],
        'train_recon': train_metrics['recon_loss'],
        'train_kl': train_metrics['kl_loss'],
        'val_recon': val_metrics['recon_loss'],
        'val_kl': val_metrics['kl_loss'],
    }


def train_vae(
    X_train: np.ndarray,
    X_test: np.ndarray,
    y_train: Optional[np.ndarray] = None,
    y_test: Optional[np.ndarray] = None,
    config: Optional[VAEConfig] = None,
    save_dir: Optional[Path] = None,
    resume_from: Optional[Path] = None,
    force_cpu: bool = False,
) -> Tuple[nn.Module, Dict[str, list]]:
    """Train a VAE model with full checkpointing support.

    Args:
        X_train: Training data (samples × features) - numpy array or pandas DataFrame
        X_test: Test data (samples × features) - numpy array or pandas DataFrame
        y_train: Optional training labels; forwarded to the dataloader only.
            A plain ``VAE`` is always instantiated here.
        y_test: Optional test labels; forwarded to the dataloader only.
        config: Training configuration. If None, a default ``VAEConfig`` is
            created and its ``INPUT_DIM`` is set from ``X_train``.
        save_dir: Directory to save checkpoints. Defaults to
            ``models/<YYYYMMDD>_VAE_KIRC``.
        resume_from: Optional checkpoint path to resume training
        force_cpu: Force CPU usage even if CUDA is available (for compatibility)

    Returns:
        Tuple of (trained_model, training_history). The history only covers
        epochs run in this call; it is not restored from ``resume_from``.
    """
    # Convert DataFrames to numpy arrays if needed
    if hasattr(X_train, 'values'):  # Check if it's a DataFrame
        X_train = X_train.values
    if hasattr(X_test, 'values'):  # Check if it's a DataFrame
        X_test = X_test.values
    if y_train is not None and hasattr(y_train, 'values'):
        y_train = y_train.values
    if y_test is not None and hasattr(y_test, 'values'):
        y_test = y_test.values

    if config is None:
        config = VAEConfig()
        config.INPUT_DIM = X_train.shape[1]

    set_seed(config.SEED)

    # Setup save directory
    if save_dir is None:
        timestamp = datetime.now().strftime('%Y%m%d')
        save_dir = Path(f"models/{timestamp}_VAE_KIRC")
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Save config
    save_model_config(config, save_dir / 'config.json')

    # Setup device
    device = get_device(force_cpu=force_cpu)
    logger.info(f"Using device: {device}")

    # Initialize model
    model = VAE(
        input_dim=config.INPUT_DIM,
        mid_dim=config.MID_DIM,
        features=config.LATENT_DIM,
    ).to(device)

    logger.info(f"Model: VAE(input_dim={config.INPUT_DIM}, mid_dim={config.MID_DIM}, latent_dim={config.LATENT_DIM})")
    logger.info(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    # Setup checkpointer
    checkpointer = ModelCheckpointer(
        save_dir=save_dir,
        monitor='val_loss',
        mode='min',
        save_freq=config.CHECKPOINT_FREQ,
        keep_last_n=3,
    )

    # Resume from checkpoint if provided
    start_epoch = 0
    if resume_from is not None:
        checkpoint_info = checkpointer.load_checkpoint(
            resume_from, model, optimizer, device=str(device)
        )
        start_epoch = checkpoint_info['epoch'] + 1
        logger.info(f"Resuming training from epoch {start_epoch}")

    # Create dataloaders
    train_loader = create_dataloader(X_train, y_train, config.BATCH_SIZE, shuffle=True)
    test_loader = create_dataloader(X_test, y_test, config.BATCH_SIZE, shuffle=False)

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_recon_loss': [],
        'train_kl_loss': [],
        'val_recon_loss': [],
        'val_kl_loss': [],
        'beta_schedule': [],  # Track beta values
    }

    # Setup beta annealing schedule
    if config.USE_BETA_ANNEALING:
        beta_schedule = frange_cycle_linear(
            start=config.BETA_START,
            stop=config.BETA,
            n_epoch=config.EPOCHS,
            n_cycle=config.BETA_CYCLES,
            ratio=config.BETA_RATIO
        )
        logger.info(
            f"Using cyclical beta annealing: "
            f"{config.BETA_START} -> {config.BETA} over {config.BETA_CYCLES} cycles"
        )
    else:
        # Constant beta
        beta_schedule = np.ones(config.EPOCHS) * config.BETA
        logger.info(f"Using constant beta: {config.BETA}")

    # Training loop
    logger.info(f"Starting training for {config.EPOCHS} epochs")

    # Stays None when the loop body never runs (start_epoch >= config.EPOCHS).
    current_metrics = None

    # Add epoch progress bar
    epoch_pbar = tqdm(range(start_epoch, config.EPOCHS), desc='Epochs', position=0)
    for epoch in epoch_pbar:
        # Get beta for this epoch from schedule
        current_beta = beta_schedule[epoch]

        # Train
        train_metrics = train_epoch(model, train_loader, optimizer, device, config, beta=current_beta)

        # Validate
        val_metrics = evaluate_model(model, test_loader, device, config, beta=current_beta)

        # Update history
        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_recon_loss'].append(train_metrics['recon_loss'])
        history['train_kl_loss'].append(train_metrics['kl_loss'])
        history['val_recon_loss'].append(val_metrics['recon_loss'])
        history['val_kl_loss'].append(val_metrics['kl_loss'])
        history['beta_schedule'].append(float(current_beta))

        # Update epoch progress bar
        epoch_pbar.set_postfix({
            'train_loss': f"{train_metrics['loss']:.4f}",
            'val_loss': f"{val_metrics['loss']:.4f}",
            'beta': f"{current_beta:.3f}"
        })

        # Log progress
        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                f"Epoch {epoch+1}/{config.EPOCHS} - "
                f"train_loss: {train_metrics['loss']:.4f}, "
                f"val_loss: {val_metrics['loss']:.4f}"
            )

        # Combine metrics for checkpointing
        current_metrics = _combine_metrics(train_metrics, val_metrics)

        # NOTE: periodic checkpointing is currently disabled; only the final
        # model is saved below.
        # if checkpointer.should_save_checkpoint(epoch):
        #     checkpointer.save_checkpoint(
        #         epoch, model, optimizer, current_metrics, config
        #     )

    if current_metrics is None:
        # No epoch ran, e.g. resuming from a checkpoint at or past
        # config.EPOCHS. Evaluate the model once so the final checkpoint
        # still gets metrics instead of failing on an unbound name. Both
        # splits are scored in eval mode here.
        logger.warning(
            f"No epochs to run (start_epoch={start_epoch}, EPOCHS={config.EPOCHS}); "
            "evaluating the current model to populate final checkpoint metrics"
        )
        final_beta = beta_schedule[-1] if len(beta_schedule) > 0 else config.BETA
        current_metrics = _combine_metrics(
            evaluate_model(model, train_loader, device, config, beta=final_beta),
            evaluate_model(model, test_loader, device, config, beta=final_beta),
        )

    # Save final model
    checkpointer.save_checkpoint(
        config.EPOCHS - 1, model, optimizer, current_metrics, config, is_final=True
    )

    logger.info("Training complete!")

    return model, history

Example Usage:

from renalprog.modeling.train import train_vae
from pathlib import Path
import pandas as pd

# Load training data
train_expr = pd.read_csv("data/interim/split/train_expression.tsv", sep="\t", index_col=0)
test_expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)

# Train VAE. With config=None a default VAEConfig is created and its INPUT_DIM
# is set from the training data. To change hyperparameters (MID_DIM,
# LATENT_DIM, EPOCHS, BATCH_SIZE, LEARNING_RATE, BETA, ...), pass a VAEConfig
# through the `config` argument.
model, history = train_vae(
    X_train=train_expr.values,
    X_test=test_expr.values,
    save_dir=Path("models/my_vae"),
)

print(f"Final validation loss: {history['val_loss'][-1]:.4f}")

train_epoch

Train the model for one epoch.

train_epoch

train_epoch(
    model: Module,
    dataloader: DataLoader,
    optimizer: Optimizer,
    device: str,
    config: VAEConfig,
    beta: Optional[float] = None,
) -> Dict[str, float]

Train model for one epoch.

Args: model: VAE model dataloader: Training DataLoader optimizer: Optimizer device: Device to use config: Training configuration beta: Beta value for this epoch (if None, uses config.BETA)

Returns: Dictionary with loss metrics

Source code in renalprog/modeling/train.py
def train_epoch(model: nn.Module, dataloader: torch.utils.data.DataLoader,
               optimizer: torch.optim.Optimizer, device: str,
               config: VAEConfig, beta: Optional[float] = None) -> Dict[str, float]:
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: VAE model; called as ``model(data)``.
        dataloader: Training DataLoader.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        config: Training configuration (only ``BETA`` is read here).
        beta: KL weight for this epoch; falls back to ``config.BETA``.

    Returns:
        Dict with 'loss', 'recon_loss' and 'kl_loss', each summed over the
        epoch and divided by ``len(dataloader.dataset)``.
    """
    kl_weight = config.BETA if beta is None else beta
    model.train()
    sums = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}

    progress = tqdm(dataloader, desc='Training', leave=False)
    for batch in progress:
        # The data tensor comes first whether or not labels are attached;
        # labels are not used.
        inputs = batch[0].to(device)
        batch_len = len(inputs)

        optimizer.zero_grad()
        reconstruction, mu, log_var, _ = model(inputs)
        loss, recon, kl = vae_loss(reconstruction, inputs, mu, log_var, kl_weight)
        loss.backward()
        optimizer.step()

        step = {
            'loss': loss.item(),
            'recon_loss': recon.item(),
            'kl_loss': kl.item(),
        }
        for key, value in step.items():
            sums[key] += value

        # The bar shows per-sample values for the current batch.
        progress.set_postfix({
            'loss': f"{step['loss'] / batch_len:.4f}",
            'recon': f"{step['recon_loss'] / batch_len:.4f}",
            'kl': f"{step['kl_loss'] / batch_len:.4f}"
        })

    n_samples = len(dataloader.dataset)
    return {key: total / n_samples for key, total in sums.items()}

evaluate_model

Evaluate model on validation/test data.

evaluate_model

evaluate_model(
    model: Module,
    dataloader: DataLoader,
    device: str,
    config: VAEConfig,
    beta: Optional[float] = None,
) -> Dict[str, float]

Evaluate model on validation/test set.

Args: model: VAE model dataloader: Validation DataLoader device: Device to use config: Training configuration beta: Beta value for this epoch (if None, uses config.BETA)

Returns: Dictionary with loss metrics

Source code in renalprog/modeling/train.py
def evaluate_model(model: nn.Module, dataloader: torch.utils.data.DataLoader,
                  device: str, config: VAEConfig, beta: Optional[float] = None) -> Dict[str, float]:
    """Score the model on ``dataloader`` without updating weights.

    Args:
        model: VAE model; called as ``model(data)`` in eval mode.
        dataloader: Validation DataLoader.
        device: Device the batches are moved to.
        config: Training configuration (only ``BETA`` is read here).
        beta: KL weight for this epoch; falls back to ``config.BETA``.

    Returns:
        Dict with 'loss', 'recon_loss' and 'kl_loss', each summed over the
        loader and divided by ``len(dataloader.dataset)``.
    """
    kl_weight = config.BETA if beta is None else beta
    model.eval()
    sums = {'loss': 0.0, 'recon_loss': 0.0, 'kl_loss': 0.0}

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation', leave=False)
        for batch in progress:
            # The data tensor comes first whether or not labels are
            # attached; labels are not used.
            inputs = batch[0].to(device)
            batch_len = len(inputs)

            reconstruction, mu, log_var, _ = model(inputs)
            loss, recon, kl = vae_loss(reconstruction, inputs, mu, log_var, kl_weight)

            step = {
                'loss': loss.item(),
                'recon_loss': recon.item(),
                'kl_loss': kl.item(),
            }
            for key, value in step.items():
                sums[key] += value

            # The bar shows per-sample values for the current batch.
            progress.set_postfix({
                'loss': f"{step['loss'] / batch_len:.4f}",
                'recon': f"{step['recon_loss'] / batch_len:.4f}",
                'kl': f"{step['kl_loss'] / batch_len:.4f}"
            })

    n_samples = len(dataloader.dataset)
    return {key: total / n_samples for key, total in sums.items()}

train_vae_with_postprocessing

Train a VAE, then train a post-processing network on its reconstructions.

train_vae_with_postprocessing

train_vae_with_postprocessing(
    X_train: ndarray,
    X_test: ndarray,
    vae_config: Optional[VAEConfig] = None,
    reconstruction_network_dims: Optional[List[int]] = None,
    reconstruction_epochs: int = 200,
    reconstruction_lr: float = 0.0001,
    batch_size_reconstruction: int = 8,
    save_dir: Optional[Path] = None,
    force_cpu: bool = False,
) -> Tuple[nn.Module, nn.Module, Dict[str, list], Dict[str, list]]

Train VAE followed by postprocessing network (full pipeline).

This implements the complete training pipeline as in train_vae.sh: 1. Train VAE on gene expression data 2. Get VAE reconstructions 3. Train NetworkReconstruction to adjust VAE output

Args: X_train: Training data (numpy array or pandas DataFrame) X_test: Test data (numpy array or pandas DataFrame) vae_config: VAE configuration reconstruction_network_dims: Architecture for reconstruction network If None, defaults to [input_dim, 4096, 1024, 4096, input_dim] reconstruction_epochs: Epochs for training reconstruction network reconstruction_lr: Learning rate for reconstruction network batch_size_reconstruction: Batch size for training reconstruction network save_dir: Directory to save models force_cpu: Force CPU usage

Returns: Tuple of (vae_model, reconstruction_network, vae_history, reconstruction_history)

Source code in renalprog/modeling/train.py
def train_vae_with_postprocessing(
    X_train: np.ndarray,
    X_test: np.ndarray,
    vae_config: Optional[VAEConfig] = None,
    reconstruction_network_dims: Optional[List[int]] = None,
    reconstruction_epochs: int = 200,
    reconstruction_lr: float = 1e-4,
    batch_size_reconstruction: int = 8,
    save_dir: Optional[Path] = None,
    force_cpu: bool = False,
) -> Tuple[nn.Module, nn.Module, Dict[str, list], Dict[str, list]]:
    """
    Train VAE followed by postprocessing network (full pipeline).

    This implements the complete training pipeline as in train_vae.sh:
    1. Train VAE on gene expression data
    2. Get VAE reconstructions
    3. Train NetworkReconstruction to adjust VAE output
    4. Save the network weights, its layer sizes, the loss curves and a plot

    Args:
        X_train: Training data (numpy array or pandas DataFrame)
        X_test: Test data (numpy array or pandas DataFrame)
        vae_config: VAE configuration
        reconstruction_network_dims: Architecture for reconstruction network
            If None, defaults to [input_dim, 4096, 1024, 4096, input_dim]
        reconstruction_epochs: Epochs for training reconstruction network
        reconstruction_lr: Learning rate for reconstruction network
        batch_size_reconstruction: Batch size for training reconstruction network
        save_dir: Directory to save models. Defaults to
            ``models/<YYYYMMDD>_VAE_with_reconstruction``.
        force_cpu: Force CPU usage

    Returns:
        Tuple of (vae_model, reconstruction_network, vae_history, reconstruction_history)
    """
    logger.info("Starting full VAE + postprocessing pipeline")

    # Convert DataFrames to numpy arrays if needed
    if hasattr(X_train, 'values'):  # Check if it's a DataFrame
        X_train = X_train.values
    if hasattr(X_test, 'values'):  # Check if it's a DataFrame
        X_test = X_test.values

    # Setup
    if save_dir is None:
        timestamp = datetime.now().strftime('%Y%m%d')
        save_dir = Path(f"models/{timestamp}_VAE_with_reconstruction")
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: Train VAE
    logger.info("Step 1: Training VAE")
    vae_model, vae_history = train_vae(
        X_train, X_test,
        config=vae_config,
        save_dir=save_dir / "vae",
        force_cpu=force_cpu
    )

    # Step 2: Get VAE reconstructions
    logger.info("Step 2: Getting VAE reconstructions")
    device = get_device(force_cpu=force_cpu)
    vae_model.eval()

    # CRITICAL: Normalize data before passing to VAE (same as during training)
    # The VAE was trained on normalized [0,1] data, so inference must use the same scale.
    # The scaler is fit on the training split and reused for the test split.
    logger.info("Normalizing data for VAE inference (same as training)")
    scaler = MinMaxScaler()
    X_train_normalized = scaler.fit_transform(X_train)
    X_test_normalized = scaler.transform(X_test)
    logger.info(f"Data normalized: min={X_train_normalized.min():.4f}, max={X_train_normalized.max():.4f}")

    with torch.no_grad():
        # Each split goes through the VAE in a single forward pass.
        X_train_tensor = torch.tensor(X_train_normalized, dtype=torch.float32).to(device)
        X_test_tensor = torch.tensor(X_test_normalized, dtype=torch.float32).to(device)

        train_recon_normalized, _, _, _ = vae_model(X_train_tensor)
        test_recon_normalized, _, _, _ = vae_model(X_test_tensor)

        # Denormalize VAE output to match original data scale
        train_recon = scaler.inverse_transform(train_recon_normalized.cpu().numpy())
        test_recon = scaler.inverse_transform(test_recon_normalized.cpu().numpy())

    # Convert to DataFrames; synthetic row labels let the reconstruction
    # trainer select the train and test rows by index.
    train_indices = [f"train_{i}" for i in range(len(X_train))]
    test_indices = [f"test_{i}" for i in range(len(X_test))]

    all_recon = np.vstack([train_recon, test_recon])
    all_original = np.vstack([X_train, X_test])
    all_indices = train_indices + test_indices

    df_reconstruction = pd.DataFrame(all_recon, index=all_indices)
    df_original = pd.DataFrame(all_original, index=all_indices)

    # Step 3: Train reconstruction network
    logger.info("Step 3: Training reconstruction network")
    input_dim = X_train.shape[1]

    if reconstruction_network_dims is None:
        reconstruction_network_dims = [input_dim, 4096, 1024, 4096, input_dim]

    network = NetworkReconstruction(reconstruction_network_dims)

    network, loss_train, loss_test = train_reconstruction_network(
        network=network,
        vae_reconstructions=df_reconstruction,
        original_data=df_original,
        train_indices=train_indices,
        test_indices=test_indices,
        epochs=reconstruction_epochs,
        lr=reconstruction_lr,
        batch_size=batch_size_reconstruction,
        device=str(device)
    )

    # Step 4: Save everything
    logger.info("Step 4: Saving models and results")

    # Save reconstruction network
    torch.save(network.state_dict(), save_dir / "reconstruction_network.pth")

    # Save network dimensions. Column names are derived from the number of
    # layer sizes, so a custom architecture of any length >= 2 can be written.
    # A five-entry list still yields in_dim, layer1_dim..layer3_dim, out_dim.
    n_dims = len(reconstruction_network_dims)
    dim_columns = (
        ['in_dim']
        + [f'layer{i}_dim' for i in range(1, n_dims - 1)]
        + ['out_dim']
    )
    pd.DataFrame([reconstruction_network_dims], columns=dim_columns
    ).to_csv(save_dir / "network_dims.csv", index=False)

    # Save losses
    pd.DataFrame({'train_loss': loss_train, 'test_loss': loss_test}
    ).to_csv(save_dir / "reconstruction_losses.csv", index=False)

    # Plot losses using Plotly
    from renalprog.plots import plot_reconstruction_losses
    plot_reconstruction_losses(
        loss_train, loss_test,
        save_path=save_dir / "reconstruction_losses"
    )

    reconstruction_history = {
        'train_loss': loss_train,
        'test_loss': loss_test
    }

    logger.info(f"Full pipeline complete! Models saved to {save_dir}")

    return vae_model, network, vae_history, reconstruction_history

Utility Functions

create_dataloader

Create PyTorch DataLoader from numpy arrays.

create_dataloader

create_dataloader(
    X: ndarray,
    y: Optional[ndarray] = None,
    batch_size: int = 32,
    shuffle: bool = True,
) -> torch.utils.data.DataLoader

Create DataLoader with MinMax normalization.

Args: X: Input data (samples x features) y: Optional labels batch_size: Batch size shuffle: Whether to shuffle data

Returns: DataLoader

Source code in renalprog/modeling/train.py
def create_dataloader(X: np.ndarray, y: Optional[np.ndarray] = None,
                     batch_size: int = 32, shuffle: bool = True) -> torch.utils.data.DataLoader:
    """Wrap arrays in a DataLoader after MinMax-scaling the features.

    A new ``MinMaxScaler`` is fitted on ``X`` on every call and is not
    returned, so the scaling parameters are not available to the caller.
    Labels, when given, are passed through unscaled.

    Args:
        X: Input data (samples x features)
        y: Optional labels; when provided each batch yields ``(X, y)``,
            otherwise each batch yields a one-element tuple ``(X,)``
        batch_size: Batch size
        shuffle: Whether to shuffle data

    Returns:
        DataLoader over float32 tensors
    """
    # Scale the features, then move straight to a float32 tensor
    features = torch.tensor(MinMaxScaler().fit_transform(X), dtype=torch.float32)

    # Collect the tensors that make up one sample; labels are optional
    tensors = [features]
    if y is not None:
        tensors.append(torch.tensor(y, dtype=torch.float32))

    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(*tensors),
        batch_size=batch_size,
        shuffle=shuffle,
    )

frange_cycle_linear

Generate cyclical annealing schedule for KL divergence.

frange_cycle_linear

frange_cycle_linear(
    start: float,
    stop: float,
    n_epoch: int,
    n_cycle: int = 4,
    ratio: float = 0.5,
) -> np.ndarray

Generate a linear cyclical schedule for beta hyperparameter.

This creates a cyclical annealing schedule where beta increases linearly from start to stop over a portion of each cycle (controlled by ratio), then stays constant at stop for the remainder of the cycle.

Args: start: Initial value of beta (typically 0.0) stop: Final/maximum value of beta (typically 1.0) n_epoch: Total number of epochs n_cycle: Number of cycles (default: 4) ratio: Ratio of cycle spent increasing beta (default: 0.5) - 0.5 means half cycle increasing, half constant - 1.0 means entire cycle increasing

Returns: Array of beta values for each epoch

Example: >>> # 3 cycles over 300 epochs, beta increases from 0 to 1 over first half of each cycle >>> beta_schedule = frange_cycle_linear(0.0, 1.0, 300, n_cycle=3, ratio=0.5) >>> # Epoch 0-50: beta increases 0.0 -> 1.0 >>> # Epoch 50-100: beta stays at 1.0 >>> # Epoch 100-150: beta increases 0.0 -> 1.0 >>> # Epoch 150-200: beta stays at 1.0 >>> # Epoch 200-250: beta increases 0.0 -> 1.0 >>> # Epoch 250-300: beta stays at 1.0

Source code in renalprog/modeling/train.py
def frange_cycle_linear(
    start: float,
    stop: float,
    n_epoch: int,
    n_cycle: int = 4,
    ratio: float = 0.5
) -> np.ndarray:
    """
    Build a cyclical, piecewise-linear schedule for the beta hyperparameter.

    The ``n_epoch`` epochs are divided into ``n_cycle`` cycles of equal
    (possibly fractional) length. At the start of each cycle beta is reset
    to ``start`` and raised by a fixed increment per epoch until it would
    exceed ``stop``; every epoch not touched by a ramp keeps the value
    ``stop``.

    Args:
        start: Initial value of beta (typically 0.0)
        stop: Final/maximum value of beta (typically 1.0)
        n_epoch: Total number of epochs
        n_cycle: Number of cycles (default: 4)
        ratio: Ratio of cycle spent increasing beta (default: 0.5)
               - 0.5 means half cycle increasing, half constant
               - 1.0 means entire cycle increasing

    Returns:
        Array of beta values for each epoch

    Example:
        >>> # 3 cycles over 300 epochs, beta increases from 0 to 1 over first half of each cycle
        >>> beta_schedule = frange_cycle_linear(0.0, 1.0, 300, n_cycle=3, ratio=0.5)
        >>> # Epoch 0-50: beta increases 0.0 -> 1.0
        >>> # Epoch 50-100: beta stays at 1.0
        >>> # Epoch 100-150: beta increases 0.0 -> 1.0
        >>> # Epoch 150-200: beta stays at 1.0
        >>> # Epoch 200-250: beta increases 0.0 -> 1.0
        >>> # Epoch 250-300: beta stays at 1.0
    """
    # Every epoch defaults to the plateau value; ramps overwrite it below
    schedule = np.full(n_epoch, stop, dtype=np.float64)

    cycle_len = n_epoch / n_cycle
    increment = (stop - start) / (cycle_len * ratio)  # per-epoch rise during a ramp

    for cycle in range(n_cycle):
        offset = cycle * cycle_len  # fractional epoch at which this cycle begins
        beta = start
        k = 0
        while beta <= stop:
            pos = int(k + offset)
            if pos >= n_epoch:
                # The ramp ran past the end of the schedule
                break
            schedule[pos] = beta
            beta += increment
            k += 1

    return schedule

Example Usage:

from renalprog.modeling.train import frange_cycle_linear

# Create annealing schedule: one beta value per epoch.
# The epoch count is passed as `n_epoch`, matching the function signature.
schedule = frange_cycle_linear(
    n_epoch=1000,
    start=0.0,
    stop=1.0,
    n_cycle=4,
    ratio=0.5
)

# Use in training loop (reconstruction_loss and kl_loss are placeholders
# for the values computed in each epoch)
for i, beta in enumerate(schedule):
    loss = reconstruction_loss + beta * kl_loss

Post-Processing Network

NetworkReconstruction

Neural network for refining VAE reconstructions.

NetworkReconstruction

NetworkReconstruction(layer_dims: List[int])

Bases: Module

Deep neural network to adjust VAE reconstruction.

This network is trained on top of VAE output to improve reconstruction quality by learning a mapping from VAE reconstruction to original data.

Args: layer_dims: List of layer dimensions [input_dim, hidden1, hidden2, ..., output_dim]

Source code in renalprog/modeling/train.py
def __init__(self, layer_dims: List[int]):
    """Assemble a stack of linear layers from consecutive dimension pairs.

    Args:
        layer_dims: List of layer dimensions
            [input_dim, hidden1, hidden2, ..., output_dim]. A ReLU follows
            every linear layer except the last one.
    """
    super().__init__()
    final_idx = len(layer_dims) - 2  # index of the last linear layer
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
        modules.append(nn.Linear(n_in, n_out))
        if idx != final_idx:
            # Hidden layers get a non-linearity; the output layer stays linear
            modules.append(nn.ReLU())
    self.network = nn.Sequential(*modules)

Functions

forward

forward(x: Tensor) -> torch.Tensor

Forward pass through network.

Source code in renalprog/modeling/train.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the layer stack held in ``self.network`` to ``x`` and return the result."""
    refined = self.network(x)
    return refined

train_reconstruction_network

Train post-processing network.

train_reconstruction_network

train_reconstruction_network(
    network: Module,
    vae_reconstructions: DataFrame,
    original_data: DataFrame,
    train_indices: List,
    test_indices: List,
    epochs: int = 200,
    lr: float = 0.0001,
    batch_size: int = 32,
    device: str = "cpu",
) -> Tuple[nn.Module, List[float], List[float]]

Train reconstruction network to adjust VAE output.

Args: network: NetworkReconstruction model vae_reconstructions: DataFrame with VAE reconstructions (samples x genes) original_data: DataFrame with original gene expression (samples x genes) train_indices: List of training sample indices test_indices: List of test sample indices epochs: Number of training epochs lr: Learning rate batch_size: Batch size device: Device to use

Returns: Tuple of (trained_network, train_losses, test_losses)

Source code in renalprog/modeling/train.py
def train_reconstruction_network(
    network: nn.Module,
    vae_reconstructions: pd.DataFrame,
    original_data: pd.DataFrame,
    train_indices: List,
    test_indices: List,
    epochs: int = 200,
    lr: float = 1e-4,
    batch_size: int = 32,
    device: str = 'cpu',
) -> Tuple[nn.Module, List[float], List[float]]:
    """
    Fit a network that maps VAE reconstructions onto the original data.

    The network is optimised with Adam against an MSE loss. After each
    training epoch it is evaluated on the test split without gradient
    tracking. The loss recorded per epoch is the sum of the batch losses
    divided by the number of batches.

    Args:
        network: NetworkReconstruction model
        vae_reconstructions: DataFrame with VAE reconstructions (samples x genes)
        original_data: DataFrame with original gene expression (samples x genes)
        train_indices: Index labels of the training samples (rows are
            selected with ``.loc``)
        test_indices: Index labels of the test samples (rows are selected
            with ``.loc``)
        epochs: Number of training epochs
        lr: Learning rate
        batch_size: Batch size
        device: Device to use

    Returns:
        Tuple of (trained_network, train_losses, test_losses)
    """
    def _paired_dataset(indices):
        # Network input (VAE output) first, regression target (original) second
        inputs = torch.tensor(vae_reconstructions.loc[indices].values, dtype=torch.float32)
        targets = torch.tensor(original_data.loc[indices].values, dtype=torch.float32)
        return torch.utils.data.TensorDataset(inputs, targets)

    network = network.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(network.parameters(), lr=lr)

    # Datasets first, then loaders; only the training loader is shuffled
    train_dataset = _paired_dataset(train_indices)
    test_dataset = _paired_dataset(test_indices)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False
    )

    loss_train, loss_test = [], []

    logger.info(f"Training reconstruction network for {epochs} epochs")

    # Outer bar tracks epochs; inner bars (position=1) track batches
    epoch_pbar = tqdm(range(epochs), desc='Reconstruction Network Training', position=0)
    for epoch in epoch_pbar:
        # ---- training pass ----
        network.train()
        epoch_total = 0.0
        train_pbar = tqdm(train_loader, desc='Train', leave=False, position=1)
        for inputs, targets in train_pbar:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(network(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            batch_value = batch_loss.item()
            epoch_total += batch_value
            train_pbar.set_postfix({'batch_loss': f'{batch_value:.6f}'})

        train_loss = epoch_total / len(train_loader)
        loss_train.append(train_loss)

        # ---- evaluation pass ----
        network.eval()
        epoch_total = 0.0
        with torch.no_grad():
            val_pbar = tqdm(test_loader, desc='Val', leave=False, position=1)
            for inputs, targets in val_pbar:
                inputs = inputs.to(device)
                targets = targets.to(device)

                batch_value = criterion(network(inputs), targets).item()
                epoch_total += batch_value
                val_pbar.set_postfix({'batch_loss': f'{batch_value:.6f}'})

        test_loss = epoch_total / len(test_loader)
        loss_test.append(test_loss)

        # Show the latest epoch-level metrics on the outer bar
        epoch_pbar.set_postfix({
            'train_loss': f'{train_loss:.6f}',
            'test_loss': f'{test_loss:.6f}'
        })

        # Write a log line every 20th epoch
        if (epoch + 1) % 20 == 0:
            logger.info(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Test Loss: {test_loss:.6f}")

    logger.info("Reconstruction network training complete")
    return network, loss_train, loss_test

See Also