Models API
The modeling module provides neural network architectures and training functions for variational autoencoders (VAEs).
Overview
This module includes:
- VAE architectures (standard, conditional, simple)
- Training and evaluation functions
- Loss functions (reconstruction, KL divergence)
- Checkpoint management
- Post-processing networks
Model Architectures
VAE
Standard Variational Autoencoder with encoder-decoder architecture.
VAE
VAE(input_dim: int, mid_dim: int, features: int, output_layer=nn.ReLU)
Bases: Module
Variational Autoencoder (VAE).
Standard VAE implementation with encoder-decoder architecture and reparameterization trick for sampling from the latent space.
Args:
- input_dim: Dimension of input data (number of genes)
- mid_dim: Dimension of hidden layer
- features: Dimension of latent space
- output_layer: Output activation function (default: nn.ReLU)
Source code in renalprog/modeling/train.py
def __init__(self, input_dim: int, mid_dim: int, features: int,
output_layer=nn.ReLU):
super().__init__()
self.input_dim = input_dim
self.mid_dim = mid_dim
self.features = features
self.output_layer = output_layer
# Encoder: input -> mid_dim -> (mu, logvar)
self.encoder = nn.Sequential(
nn.Linear(in_features=input_dim, out_features=mid_dim),
nn.ReLU(),
nn.Linear(in_features=mid_dim, out_features=features * 2)
)
# Decoder: latent -> mid_dim -> reconstruction
self.decoder = nn.Sequential(
nn.Linear(in_features=features, out_features=mid_dim),
nn.ReLU(),
nn.Linear(in_features=mid_dim, out_features=input_dim),
output_layer()
)
Functions
forward
forward(
x: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Forward pass through VAE.
Args: x: Input data (batch_size, input_dim)
Returns: Tuple of (reconstruction, mu, log_var, z)
Source code in renalprog/modeling/train.py
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
"""Forward pass through VAE.
Args:
x: Input data (batch_size, input_dim)
Returns:
Tuple of (reconstruction, mu, log_var, z)
"""
# Encode
encoded = self.encoder(x)
mu_logvar = encoded.view(-1, 2, self.features)
mu = mu_logvar[:, 0, :]
log_var = mu_logvar[:, 1, :]
# Sample from latent space
z = self.reparametrize(mu, log_var)
# Decode
reconstruction = self.decoder(z)
return reconstruction, mu, log_var, z
reparametrize
reparametrize(mu: Tensor, log_var: Tensor) -> torch.Tensor
Reparameterization trick: sample from N(mu, var) using N(0,1).
Args:
- mu: Mean of the latent distribution
- log_var: Log variance of the latent distribution
Returns: Sampled latent vector
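In symbols, the training-time sample is z = mu + eps * exp(0.5 * log_var) with eps ~ N(0, I); in evaluation mode the mean mu is returned directly, so latent embeddings are deterministic.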
Source code in renalprog/modeling/train.py
def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
"""Reparameterization trick: sample from N(mu, var) using N(0,1).
Args:
mu: Mean of the latent distribution
log_var: Log variance of the latent distribution
Returns:
Sampled latent vector
"""
if self.training:
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
else:
# During evaluation, return mean directly
return mu
Example Usage:
import torch
from renalprog.modeling.train import VAE
# Create VAE model
model = VAE(
input_dim=20000, # Number of genes
mid_dim=1024, # Hidden layer size
features=128, # Latent dimension
)
# Forward pass
x = torch.randn(32, 20000) # Batch of gene expression
reconstruction, mu, log_var, z = model(x)
CVAE
Conditional VAE that incorporates clinical covariates.
CVAE
CVAE(
input_dim: int,
mid_dim: int,
features: int,
num_classes: int,
output_layer=nn.ReLU,
)
Bases: VAE
Conditional Variational Autoencoder.
VAE that conditions on additional information (e.g., clinical data).
Args:
- input_dim: Dimension of input data
- mid_dim: Dimension of hidden layer
- features: Dimension of latent space
- num_classes: Number of condition classes
- output_layer: Output activation function
Source code in renalprog/modeling/train.py
def __init__(self, input_dim: int, mid_dim: int, features: int,
num_classes: int, output_layer=nn.ReLU):
super().__init__(input_dim, mid_dim, features, output_layer)
self.num_classes = num_classes
# Modified encoder: accepts input + condition
self.encoder = nn.Sequential(
nn.Linear(in_features=input_dim + num_classes, out_features=mid_dim),
nn.ReLU(),
nn.Linear(in_features=mid_dim, out_features=features * 2)
)
# Modified decoder: accepts latent + condition
self.decoder = nn.Sequential(
nn.Linear(in_features=features + num_classes, out_features=mid_dim),
nn.ReLU(),
nn.Linear(in_features=mid_dim, out_features=input_dim),
)
Functions
forward
forward(
x: Tensor, condition: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Forward pass through CVAE.
Args:
- x: Input data
- condition: Conditioning information (one-hot encoded)
Returns: Tuple of (reconstruction, mu, log_var, z)
Source code in renalprog/modeling/train.py
def forward(self, x: torch.Tensor, condition: torch.Tensor) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor]:
"""Forward pass through CVAE.
Args:
x: Input data
condition: Conditioning information (one-hot encoded)
Returns:
Tuple of (reconstruction, mu, log_var, z)
"""
# Concatenate input with condition
x_cond = torch.cat([x, condition], dim=1)
# Encode
encoded = self.encoder(x_cond)
mu_logvar = encoded.view(-1, 2, self.features)
mu = mu_logvar[:, 0, :]
log_var = mu_logvar[:, 1, :]
# Sample
z = self.reparametrize(mu, log_var)
# Concatenate latent with condition
z_cond = torch.cat([z, condition], dim=1)
# Decode
reconstruction = self.decoder(z_cond)
reconstruction = self.output_layer()(reconstruction)
return reconstruction, mu, log_var, z
Example Usage:
import torch
import torch.nn.functional as F
from renalprog.modeling.train import CVAE
# Create conditional VAE
model = CVAE(
    input_dim=20000,
    mid_dim=1024,
    features=128,
    num_classes=2,  # number of condition classes (e.g., tumor stages)
)
# Forward pass with a one-hot encoded condition
x = torch.randn(32, 20000)
labels = torch.randint(0, 2, (32,))
condition = F.one_hot(labels, num_classes=2).float()
reconstruction, mu, log_var, z = model(x, condition)
AE
Simplified autoencoder without variational component.
AE
AE(input_dim: int, mid_dim: int, features: int, output_layer=nn.ReLU)
Bases: Module
Standard Autoencoder (without variational inference).
Similar architecture to VAE but without reparameterization trick.
Args:
- input_dim: Dimension of input data
- mid_dim: Dimension of hidden layer
- features: Dimension of latent space
- output_layer: Output activation function
Source code in renalprog/modeling/train.py
def __init__(self, input_dim: int, mid_dim: int, features: int,
output_layer=nn.ReLU):
super().__init__()
self.input_dim = input_dim
self.mid_dim = mid_dim
self.features = features
self.output_layer = output_layer
self.encoder = nn.Sequential(
nn.Linear(in_features=input_dim, out_features=mid_dim),
nn.ReLU(),
nn.Linear(in_features=mid_dim, out_features=features)
)
self.decoder = nn.Sequential(
nn.Linear(in_features=features, out_features=mid_dim),
nn.ReLU(),
nn.Linear(in_features=mid_dim, out_features=input_dim),
output_layer()
)
Functions
forward
forward(x: Tensor) -> Tuple[torch.Tensor, None, None, torch.Tensor]
Forward pass through AE.
Args: x: Input data
Returns: Tuple of (reconstruction, None, None, z); the None values stand in for mu and log_var to keep the interface consistent with VAE
Source code in renalprog/modeling/train.py
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None, None, torch.Tensor]:
"""Forward pass through AE.
Args:
x: Input data
Returns:
Tuple of (reconstruction, None, None, z)
None values for mu and logvar to maintain consistency with VAE
"""
z = self.encoder(x)
reconstruction = self.decoder(z)
return reconstruction, None, None, z
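Example Usage (an illustrative sketch; the 20000-gene input dimension is assumed):
import torch
from renalprog.modeling.train import AE
# Create a plain autoencoder (same constructor arguments as VAE)
model = AE(
    input_dim=20000,   # Number of genes
    mid_dim=1024,      # Hidden layer size
    features=128,      # Latent dimension
)
# Forward pass
x = torch.randn(32, 20000)
reconstruction, mu, log_var, z = model(x)
# mu and log_var are returned as None placeholders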
Loss Functions
vae_loss
Complete VAE loss combining reconstruction and KL divergence.
vae_loss
vae_loss(
reconstruction: Tensor,
x: Tensor,
mu: Tensor,
log_var: Tensor,
beta: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Calculate VAE loss: reconstruction loss + KL divergence.
Args:
- reconstruction: Reconstructed output
- x: Original input
- mu: Mean of latent distribution
- log_var: Log variance of latent distribution
- beta: Weight for KL divergence term (beta-VAE)
Returns: Tuple of (total_loss, reconstruction_loss, kl_divergence)
Source code in renalprog/modeling/train.py
def vae_loss(reconstruction: torch.Tensor, x: torch.Tensor,
mu: torch.Tensor, log_var: torch.Tensor,
beta: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate VAE loss: reconstruction loss + KL divergence.
Args:
reconstruction: Reconstructed output
x: Original input
mu: Mean of latent distribution
log_var: Log variance of latent distribution
beta: Weight for KL divergence term (beta-VAE)
Returns:
Tuple of (total_loss, reconstruction_loss, kl_divergence)
"""
# Reconstruction loss (MSE)
recon_loss = nn.functional.mse_loss(reconstruction, x, reduction='sum')
# KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# Total loss
total_loss = recon_loss + beta * kl_div
return total_loss, recon_loss, kl_div
reconstruction_loss
MSE-based reconstruction loss.
reconstruction_loss
reconstruction_loss(
reconstruction: Tensor, x: Tensor, reduction: str = "sum"
) -> torch.Tensor
Calculate reconstruction loss (MSE).
Args:
- reconstruction: Reconstructed output
- x: Original input
- reduction: Reduction method ('sum' or 'mean')
Returns: Reconstruction loss
Source code in renalprog/modeling/train.py
def reconstruction_loss(reconstruction: torch.Tensor, x: torch.Tensor,
reduction: str = 'sum') -> torch.Tensor:
"""Calculate reconstruction loss (MSE).
Args:
reconstruction: Reconstructed output
x: Original input
reduction: Reduction method ('sum' or 'mean')
Returns:
Reconstruction loss
"""
return nn.functional.mse_loss(reconstruction, x, reduction=reduction)
kl_divergence
KL divergence between latent distribution and prior.
kl_divergence
kl_divergence(mu: Tensor, log_var: Tensor) -> torch.Tensor
Calculate KL divergence between approximate posterior and prior.
Args:
- mu: Mean of approximate posterior
- log_var: Log variance of approximate posterior
Returns: KL divergence
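In closed form, as computed below: KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), where the sum runs over latent dimensions and batch elements.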
Source code in renalprog/modeling/train.py
def kl_divergence(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
"""Calculate KL divergence between approximate posterior and prior.
Args:
mu: Mean of approximate posterior
log_var: Log variance of approximate posterior
Returns:
KL divergence
"""
return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
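Example Usage (a minimal sketch; dimensions and the beta value are illustrative):
import torch
from renalprog.modeling.train import VAE, vae_loss, reconstruction_loss, kl_divergence
model = VAE(input_dim=20000, mid_dim=1024, features=128)
x = torch.rand(32, 20000)  # batch assumed scaled to [0, 1]
reconstruction, mu, log_var, z = model(x)
# Combined beta-VAE objective
total_loss, recon, kl = vae_loss(reconstruction, x, mu, log_var, beta=0.5)
# Or compute the terms separately
recon_only = reconstruction_loss(reconstruction, x, reduction='sum')
kl_only = kl_divergence(mu, log_var)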
Training Functions
train_vae
Main training function for VAE models.
train_vae
train_vae(
X_train: ndarray,
X_test: ndarray,
y_train: Optional[ndarray] = None,
y_test: Optional[ndarray] = None,
config: Optional[VAEConfig] = None,
save_dir: Optional[Path] = None,
resume_from: Optional[Path] = None,
force_cpu: bool = False,
) -> Tuple[nn.Module, Dict[str, list]]
Train a VAE model with full checkpointing support.
Args:
- X_train: Training data (samples × features) - numpy array or pandas DataFrame
- X_test: Test data (samples × features) - numpy array or pandas DataFrame
- y_train: Optional training labels for CVAE
- y_test: Optional test labels for CVAE
- config: Training configuration
- save_dir: Directory to save checkpoints
- resume_from: Optional checkpoint path to resume training
- force_cpu: Force CPU usage even if CUDA is available (for compatibility)
Returns: Tuple of (trained_model, training_history)
Source code in renalprog/modeling/train.py
def train_vae(
X_train: np.ndarray,
X_test: np.ndarray,
y_train: Optional[np.ndarray] = None,
y_test: Optional[np.ndarray] = None,
config: Optional[VAEConfig] = None,
save_dir: Optional[Path] = None,
resume_from: Optional[Path] = None,
force_cpu: bool = False,
) -> Tuple[nn.Module, Dict[str, list]]:
"""Train a VAE model with full checkpointing support.
Args:
X_train: Training data (samples × features) - numpy array or pandas DataFrame
X_test: Test data (samples × features) - numpy array or pandas DataFrame
y_train: Optional training labels for CVAE
y_test: Optional test labels for CVAE
config: Training configuration
save_dir: Directory to save checkpoints
resume_from: Optional checkpoint path to resume training
force_cpu: Force CPU usage even if CUDA is available (for compatibility)
Returns:
Tuple of (trained_model, training_history)
"""
# Convert DataFrames to numpy arrays if needed
if hasattr(X_train, 'values'): # Check if it's a DataFrame
X_train = X_train.values
if hasattr(X_test, 'values'): # Check if it's a DataFrame
X_test = X_test.values
if y_train is not None and hasattr(y_train, 'values'):
y_train = y_train.values
if y_test is not None and hasattr(y_test, 'values'):
y_test = y_test.values
if config is None:
config = VAEConfig()
config.INPUT_DIM = X_train.shape[1]
set_seed(config.SEED)
# Setup save directory
if save_dir is None:
timestamp = datetime.now().strftime('%Y%m%d')
save_dir = Path(f"models/{timestamp}_VAE_KIRC")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
# Save config
save_model_config(config, save_dir / 'config.json')
# Setup device
device = get_device(force_cpu=force_cpu)
logger.info(f"Using device: {device}")
# Initialize model
model = VAE(
input_dim=config.INPUT_DIM,
mid_dim=config.MID_DIM,
features=config.LATENT_DIM,
).to(device)
logger.info(f"Model: VAE(input_dim={config.INPUT_DIM}, mid_dim={config.MID_DIM}, latent_dim={config.LATENT_DIM})")
logger.info(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
# Setup optimizer
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
# Setup checkpointer
checkpointer = ModelCheckpointer(
save_dir=save_dir,
monitor='val_loss',
mode='min',
save_freq=config.CHECKPOINT_FREQ,
keep_last_n=3,
)
# Resume from checkpoint if provided
start_epoch = 0
if resume_from is not None:
checkpoint_info = checkpointer.load_checkpoint(
resume_from, model, optimizer, device=str(device)
)
start_epoch = checkpoint_info['epoch'] + 1
logger.info(f"Resuming training from epoch {start_epoch}")
# Create dataloaders
train_loader = create_dataloader(X_train, y_train, config.BATCH_SIZE, shuffle=True)
test_loader = create_dataloader(X_test, y_test, config.BATCH_SIZE, shuffle=False)
# Training history
history = {
'train_loss': [],
'val_loss': [],
'train_recon_loss': [],
'train_kl_loss': [],
'val_recon_loss': [],
'val_kl_loss': [],
'beta_schedule': [], # Track beta values
}
# Setup beta annealing schedule
if config.USE_BETA_ANNEALING:
beta_schedule = frange_cycle_linear(
start=config.BETA_START,
stop=config.BETA,
n_epoch=config.EPOCHS,
n_cycle=config.BETA_CYCLES,
ratio=config.BETA_RATIO
)
logger.info(
f"Using cyclical beta annealing: "
f"{config.BETA_START} -> {config.BETA} over {config.BETA_CYCLES} cycles"
)
else:
# Constant beta
beta_schedule = np.ones(config.EPOCHS) * config.BETA
logger.info(f"Using constant beta: {config.BETA}")
# Training loop
logger.info(f"Starting training for {config.EPOCHS} epochs")
# Add epoch progress bar
epoch_pbar = tqdm(range(start_epoch, config.EPOCHS), desc='Epochs', position=0)
for epoch in epoch_pbar:
# Get beta for this epoch from schedule
current_beta = beta_schedule[epoch]
# Train
train_metrics = train_epoch(model, train_loader, optimizer, device, config, beta=current_beta)
# Validate
val_metrics = evaluate_model(model, test_loader, device, config, beta=current_beta)
# Update history
history['train_loss'].append(train_metrics['loss'])
history['val_loss'].append(val_metrics['loss'])
history['train_recon_loss'].append(train_metrics['recon_loss'])
history['train_kl_loss'].append(train_metrics['kl_loss'])
history['val_recon_loss'].append(val_metrics['recon_loss'])
history['val_kl_loss'].append(val_metrics['kl_loss'])
history['beta_schedule'].append(float(current_beta))
# Update epoch progress bar
epoch_pbar.set_postfix({
'train_loss': f"{train_metrics['loss']:.4f}",
'val_loss': f"{val_metrics['loss']:.4f}",
'beta': f"{current_beta:.3f}"
})
# Log progress
if (epoch + 1) % 10 == 0 or epoch == 0:
logger.info(
f"Epoch {epoch+1}/{config.EPOCHS} - "
f"train_loss: {train_metrics['loss']:.4f}, "
f"val_loss: {val_metrics['loss']:.4f}"
)
# Combine metrics for checkpointing
current_metrics = {
'train_loss': train_metrics['loss'],
'val_loss': val_metrics['loss'],
'train_recon': train_metrics['recon_loss'],
'train_kl': train_metrics['kl_loss'],
'val_recon': val_metrics['recon_loss'],
'val_kl': val_metrics['kl_loss'],
}
# # Save periodic checkpoint
# if checkpointer.should_save_checkpoint(epoch):
# checkpointer.save_checkpoint(
# epoch, model, optimizer, current_metrics, config
# )
# Save final model
checkpointer.save_checkpoint(
config.EPOCHS - 1, model, optimizer, current_metrics, config, is_final=True
)
logger.info("Training complete!")
return model, history
Example Usage:
from pathlib import Path
import pandas as pd
from renalprog.modeling.train import train_vae
# Load training data
train_expr = pd.read_csv("data/interim/split/train_expression.tsv", sep="\t", index_col=0)
test_expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)
# Train VAE (config=None uses the default VAEConfig; the input dimension is inferred from the data)
model, history = train_vae(
    X_train=train_expr,
    X_test=test_expr,
    save_dir=Path("models/my_vae"),
)
print(f"Final validation loss: {history['val_loss'][-1]:.4f}")
train_epoch
Train the model for one epoch.
train_epoch
train_epoch(
model: Module,
dataloader: DataLoader,
optimizer: Optimizer,
device: str,
config: VAEConfig,
beta: Optional[float] = None,
) -> Dict[str, float]
Train model for one epoch.
Args:
- model: VAE model
- dataloader: Training DataLoader
- optimizer: Optimizer
- device: Device to use
- config: Training configuration
- beta: Beta value for this epoch (if None, uses config.BETA)
Returns: Dictionary with loss metrics
Source code in renalprog/modeling/train.py
def train_epoch(model: nn.Module, dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer, device: str,
config: VAEConfig, beta: Optional[float] = None) -> Dict[str, float]:
"""Train model for one epoch.
Args:
model: VAE model
dataloader: Training DataLoader
optimizer: Optimizer
device: Device to use
config: Training configuration
beta: Beta value for this epoch (if None, uses config.BETA)
Returns:
Dictionary with loss metrics
"""
if beta is None:
beta = config.BETA
model.train()
total_loss = 0.0
total_recon = 0.0
total_kl = 0.0
# Add progress bar
pbar = tqdm(dataloader, desc='Training', leave=False)
for batch in pbar:
if len(batch) == 2:
data, _ = batch
else:
data = batch[0]
data = data.to(device)
# Forward pass
optimizer.zero_grad()
reconstruction, mu, log_var, z = model(data)
# Calculate loss (use beta parameter instead of config.BETA)
loss, recon, kl = vae_loss(reconstruction, data, mu, log_var, beta)
# Backward pass
loss.backward()
optimizer.step()
# Accumulate losses
total_loss += loss.item()
total_recon += recon.item()
total_kl += kl.item()
# Update progress bar
pbar.set_postfix({
'loss': f'{loss.item() / len(data):.4f}',
'recon': f'{recon.item() / len(data):.4f}',
'kl': f'{kl.item() / len(data):.4f}'
})
# Average losses
n_samples = len(dataloader.dataset)
metrics = {
'loss': total_loss / n_samples,
'recon_loss': total_recon / n_samples,
'kl_loss': total_kl / n_samples,
}
return metrics
evaluate_model
Evaluate model on validation/test data.
evaluate_model
evaluate_model(
model: Module,
dataloader: DataLoader,
device: str,
config: VAEConfig,
beta: Optional[float] = None,
) -> Dict[str, float]
Evaluate model on validation/test set.
Args:
- model: VAE model
- dataloader: Validation DataLoader
- device: Device to use
- config: Training configuration
- beta: Beta value for this epoch (if None, uses config.BETA)
Returns: Dictionary with loss metrics
Source code in renalprog/modeling/train.py
def evaluate_model(model: nn.Module, dataloader: torch.utils.data.DataLoader,
device: str, config: VAEConfig, beta: Optional[float] = None) -> Dict[str, float]:
"""Evaluate model on validation/test set.
Args:
model: VAE model
dataloader: Validation DataLoader
device: Device to use
config: Training configuration
beta: Beta value for this epoch (if None, uses config.BETA)
Returns:
Dictionary with loss metrics
"""
if beta is None:
beta = config.BETA
model.eval()
total_loss = 0.0
total_recon = 0.0
total_kl = 0.0
with torch.no_grad():
# Add progress bar
pbar = tqdm(dataloader, desc='Validation', leave=False)
for batch in pbar:
if len(batch) == 2:
data, _ = batch
else:
data = batch[0]
data = data.to(device)
# Forward pass
reconstruction, mu, log_var, z = model(data)
# Calculate loss (use beta parameter instead of config.BETA)
loss, recon, kl = vae_loss(reconstruction, data, mu, log_var, beta)
# Accumulate losses
total_loss += loss.item()
total_recon += recon.item()
total_kl += kl.item()
# Update progress bar
pbar.set_postfix({
'loss': f'{loss.item() / len(data):.4f}',
'recon': f'{recon.item() / len(data):.4f}',
'kl': f'{kl.item() / len(data):.4f}'
})
# Average losses
n_samples = len(dataloader.dataset)
metrics = {
'loss': total_loss / n_samples,
'recon_loss': total_recon / n_samples,
'kl_loss': total_kl / n_samples,
}
return metrics
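Example Usage (a sketch of driving these two functions in a custom loop; it assumes VAEConfig is importable from the same module and that X_train and X_test are already-loaded NumPy arrays):
import torch.optim as optim
from renalprog.modeling.train import (
    VAE, VAEConfig, create_dataloader, train_epoch, evaluate_model,  # VAEConfig location is assumed
)
config = VAEConfig()
config.INPUT_DIM = X_train.shape[1]
device = "cpu"  # or a CUDA device string if available
model = VAE(input_dim=config.INPUT_DIM, mid_dim=config.MID_DIM, features=config.LATENT_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
train_loader = create_dataloader(X_train, batch_size=config.BATCH_SIZE, shuffle=True)
test_loader = create_dataloader(X_test, batch_size=config.BATCH_SIZE, shuffle=False)
for epoch in range(config.EPOCHS):
    train_metrics = train_epoch(model, train_loader, optimizer, device, config)
    val_metrics = evaluate_model(model, test_loader, device, config)
    print(f"epoch {epoch}: train {train_metrics['loss']:.4f}, val {val_metrics['loss']:.4f}")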
train_vae_with_postprocessing
Train VAE and post-processing network together.
train_vae_with_postprocessing
train_vae_with_postprocessing(
X_train: ndarray,
X_test: ndarray,
vae_config: Optional[VAEConfig] = None,
reconstruction_network_dims: Optional[List[int]] = None,
reconstruction_epochs: int = 200,
reconstruction_lr: float = 0.0001,
batch_size_reconstruction: int = 8,
save_dir: Optional[Path] = None,
force_cpu: bool = False,
) -> Tuple[nn.Module, nn.Module, Dict[str, list], Dict[str, list]]
Train VAE followed by postprocessing network (full pipeline).
This implements the complete training pipeline as in train_vae.sh:
1. Train VAE on gene expression data
2. Get VAE reconstructions
3. Train NetworkReconstruction to adjust VAE output
Args:
- X_train: Training data (numpy array or pandas DataFrame)
- X_test: Test data (numpy array or pandas DataFrame)
- vae_config: VAE configuration
- reconstruction_network_dims: Architecture for reconstruction network; if None, defaults to [input_dim, 4096, 1024, 4096, input_dim]
- reconstruction_epochs: Epochs for training reconstruction network
- reconstruction_lr: Learning rate for reconstruction network
- batch_size_reconstruction: Batch size for training the reconstruction network
- save_dir: Directory to save models
- force_cpu: Force CPU usage
Returns: Tuple of (vae_model, reconstruction_network, vae_history, reconstruction_history)
Source code in renalprog/modeling/train.py
def train_vae_with_postprocessing(
X_train: np.ndarray,
X_test: np.ndarray,
vae_config: Optional[VAEConfig] = None,
reconstruction_network_dims: Optional[List[int]] = None,
reconstruction_epochs: int = 200,
reconstruction_lr: float = 1e-4,
batch_size_reconstruction:int = 8,
save_dir: Optional[Path] = None,
force_cpu: bool = False,
) -> Tuple[nn.Module, nn.Module, Dict[str, list], Dict[str, list]]:
"""
Train VAE followed by postprocessing network (full pipeline).
This implements the complete training pipeline as in train_vae.sh:
1. Train VAE on gene expression data
2. Get VAE reconstructions
3. Train NetworkReconstruction to adjust VAE output
Args:
X_train: Training data (numpy array or pandas DataFrame)
X_test: Test data (numpy array or pandas DataFrame)
vae_config: VAE configuration
reconstruction_network_dims: Architecture for reconstruction network
If None, defaults to [input_dim, 4096, 1024, 4096, input_dim]
reconstruction_epochs: Epochs for training reconstruction network
reconstruction_lr: Learning rate for reconstruction network
save_dir: Directory to save models
force_cpu: Force CPU usage
Returns:
Tuple of (vae_model, reconstruction_network, vae_history, reconstruction_history)
"""
logger.info("Starting full VAE + postprocessing pipeline")
# Convert DataFrames to numpy arrays if needed
if hasattr(X_train, 'values'): # Check if it's a DataFrame
X_train = X_train.values
if hasattr(X_test, 'values'): # Check if it's a DataFrame
X_test = X_test.values
# Setup
if save_dir is None:
timestamp = datetime.now().strftime('%Y%m%d')
save_dir = Path(f"models/{timestamp}_VAE_with_reconstruction")
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
# Step 1: Train VAE
logger.info("Step 1: Training VAE")
vae_model, vae_history = train_vae(
X_train, X_test,
config=vae_config,
save_dir=save_dir / "vae",
force_cpu=force_cpu
)
# Step 2: Get VAE reconstructions
logger.info("Step 2: Getting VAE reconstructions")
device = get_device(force_cpu=force_cpu)
vae_model.eval()
# CRITICAL: Normalize data before passing to VAE (same as during training)
# The VAE was trained on normalized [0,1] data, so inference must use the same scale
logger.info("Normalizing data for VAE inference (same as training)")
scaler = MinMaxScaler()
X_train_normalized = scaler.fit_transform(X_train)
X_test_normalized = scaler.transform(X_test)
logger.info(f"Data normalized: min={X_train_normalized.min():.4f}, max={X_train_normalized.max():.4f}")
with torch.no_grad():
X_train_tensor = torch.tensor(X_train_normalized, dtype=torch.float32).to(device)
X_test_tensor = torch.tensor(X_test_normalized, dtype=torch.float32).to(device)
train_recon_normalized, _, _, _ = vae_model(X_train_tensor)
test_recon_normalized, _, _, _ = vae_model(X_test_tensor)
# Denormalize VAE output to match original data scale
train_recon = scaler.inverse_transform(train_recon_normalized.cpu().numpy())
test_recon = scaler.inverse_transform(test_recon_normalized.cpu().numpy())
# Convert to DataFrames
train_indices = [f"train_{i}" for i in range(len(X_train))]
test_indices = [f"test_{i}" for i in range(len(X_test))]
all_recon = np.vstack([train_recon, test_recon])
all_original = np.vstack([X_train, X_test])
all_indices = train_indices + test_indices
df_reconstruction = pd.DataFrame(all_recon, index=all_indices)
df_original = pd.DataFrame(all_original, index=all_indices)
# Step 3: Train reconstruction network
logger.info("Step 3: Training reconstruction network")
input_dim = X_train.shape[1]
if reconstruction_network_dims is None:
reconstruction_network_dims = [input_dim, 4096, 1024, 4096, input_dim]
network = NetworkReconstruction(reconstruction_network_dims)
network, loss_train, loss_test = train_reconstruction_network(
network=network,
vae_reconstructions=df_reconstruction,
original_data=df_original,
train_indices=train_indices,
test_indices=test_indices,
epochs=reconstruction_epochs,
lr=reconstruction_lr,
batch_size=batch_size_reconstruction,
device=str(device)
)
# Step 4: Save everything
logger.info("Step 4: Saving models and results")
# Save reconstruction network
torch.save(network.state_dict(), save_dir / "reconstruction_network.pth")
# Save network dimensions
pd.DataFrame([reconstruction_network_dims],
columns=['in_dim', 'layer1_dim', 'layer2_dim', 'layer3_dim', 'out_dim']
).to_csv(save_dir / "network_dims.csv", index=False)
# Save losses
pd.DataFrame({'train_loss': loss_train, 'test_loss': loss_test}
).to_csv(save_dir / "reconstruction_losses.csv", index=False)
# Plot losses using Plotly
from renalprog.plots import plot_reconstruction_losses
plot_reconstruction_losses(
loss_train, loss_test,
save_path=save_dir / "reconstruction_losses"
)
reconstruction_history = {
'train_loss': loss_train,
'test_loss': loss_test
}
logger.info(f"Full pipeline complete! Models saved to {save_dir}")
return vae_model, network, vae_history, reconstruction_history
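Example Usage (a sketch using the same example data paths as above):
from pathlib import Path
import pandas as pd
from renalprog.modeling.train import train_vae_with_postprocessing
train_expr = pd.read_csv("data/interim/split/train_expression.tsv", sep="\t", index_col=0)
test_expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)
vae_model, recon_network, vae_history, recon_history = train_vae_with_postprocessing(
    X_train=train_expr,
    X_test=test_expr,
    reconstruction_epochs=200,
    reconstruction_lr=1e-4,
    save_dir=Path("models/vae_with_postprocessing"),
)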
Utility Functions
create_dataloader
Create PyTorch DataLoader from numpy arrays.
create_dataloader
create_dataloader(
X: ndarray,
y: Optional[ndarray] = None,
batch_size: int = 32,
shuffle: bool = True,
) -> torch.utils.data.DataLoader
Create DataLoader with MinMax normalization.
Args:
- X: Input data (samples x features)
- y: Optional labels
- batch_size: Batch size
- shuffle: Whether to shuffle data
Returns: DataLoader
Source code in renalprog/modeling/train.py
def create_dataloader(X: np.ndarray, y: Optional[np.ndarray] = None,
batch_size: int = 32, shuffle: bool = True) -> torch.utils.data.DataLoader:
"""Create DataLoader with MinMax normalization.
Args:
X: Input data (samples x features)
y: Optional labels
batch_size: Batch size
shuffle: Whether to shuffle data
Returns:
DataLoader
"""
# Normalize with MinMaxScaler
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)
# Convert to tensors
X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
if y is not None:
y_tensor = torch.tensor(y, dtype=torch.float32)
dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)
else:
dataset = torch.utils.data.TensorDataset(X_tensor)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle
)
return dataloader
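Example Usage (random data, purely illustrative):
import numpy as np
from renalprog.modeling.train import create_dataloader
X = np.random.rand(100, 20000)  # 100 samples x 20000 genes
loader = create_dataloader(X, batch_size=32, shuffle=True)
batch = next(iter(loader))[0]   # without labels, each batch is a 1-element tuple
print(batch.shape)              # torch.Size([32, 20000])
Note that the MinMaxScaler is fit on the array passed in, so calling this separately for train and test data normalizes each split independently.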
frange_cycle_linear
Generate cyclical annealing schedule for KL divergence.
frange_cycle_linear
frange_cycle_linear(
start: float,
stop: float,
n_epoch: int,
n_cycle: int = 4,
ratio: float = 0.5,
) -> np.ndarray
Generate a linear cyclical schedule for beta hyperparameter.
This creates a cyclical annealing schedule where beta increases linearly from start to stop over a portion of each cycle (controlled by ratio), then stays constant at stop for the remainder of the cycle.
Args:
- start: Initial value of beta (typically 0.0)
- stop: Final/maximum value of beta (typically 1.0)
- n_epoch: Total number of epochs
- n_cycle: Number of cycles (default: 4)
- ratio: Ratio of each cycle spent increasing beta (default: 0.5); 0.5 means half the cycle increasing and half constant, 1.0 means the entire cycle increasing
Returns: Array of beta values for each epoch
Example:
>>> # 3 cycles over 300 epochs, beta increases from 0 to 1 over first half of each cycle
>>> beta_schedule = frange_cycle_linear(0.0, 1.0, 300, n_cycle=3, ratio=0.5)
>>> # Epoch 0-50: beta increases 0.0 -> 1.0
>>> # Epoch 50-100: beta stays at 1.0
>>> # Epoch 100-150: beta increases 0.0 -> 1.0
>>> # Epoch 150-200: beta stays at 1.0
>>> # Epoch 200-250: beta increases 0.0 -> 1.0
>>> # Epoch 250-300: beta stays at 1.0
Source code in renalprog/modeling/train.py
def frange_cycle_linear(
start: float,
stop: float,
n_epoch: int,
n_cycle: int = 4,
ratio: float = 0.5
) -> np.ndarray:
"""
Generate a linear cyclical schedule for beta hyperparameter.
This creates a cyclical annealing schedule where beta increases linearly
from start to stop over a portion of each cycle (controlled by ratio),
then stays constant at stop for the remainder of the cycle.
Args:
start: Initial value of beta (typically 0.0)
stop: Final/maximum value of beta (typically 1.0)
n_epoch: Total number of epochs
n_cycle: Number of cycles (default: 4)
ratio: Ratio of cycle spent increasing beta (default: 0.5)
- 0.5 means half cycle increasing, half constant
- 1.0 means entire cycle increasing
Returns:
Array of beta values for each epoch
Example:
>>> # 3 cycles over 300 epochs, beta increases from 0 to 1 over first half of each cycle
>>> beta_schedule = frange_cycle_linear(0.0, 1.0, 300, n_cycle=3, ratio=0.5)
>>> # Epoch 0-50: beta increases 0.0 -> 1.0
>>> # Epoch 50-100: beta stays at 1.0
>>> # Epoch 100-150: beta increases 0.0 -> 1.0
>>> # Epoch 150-200: beta stays at 1.0
>>> # Epoch 200-250: beta increases 0.0 -> 1.0
>>> # Epoch 250-300: beta stays at 1.0
"""
L = np.ones(n_epoch) * stop # Initialize all to stop value
period = n_epoch / n_cycle
step = (stop - start) / (period * ratio) # Linear schedule
for c in range(n_cycle):
v, i = start, 0
while v <= stop and (int(i + c * period) < n_epoch):
L[int(i + c * period)] = v
v += step
i += 1
return L
Example Usage:
from renalprog.modeling.train import frange_cycle_linear
# Create annealing schedule (one beta value per epoch)
schedule = frange_cycle_linear(
    start=0.0,
    stop=1.0,
    n_epoch=1000,
    n_cycle=4,
    ratio=0.5
)
# Use in training loop
for epoch, beta in enumerate(schedule):
    total_loss = recon_loss + beta * kl_loss  # weight the KL term by the current beta
Post-Processing Network
NetworkReconstruction
Neural network for refining VAE reconstructions.
NetworkReconstruction
NetworkReconstruction(layer_dims: List[int])
Bases: Module
Deep neural network to adjust VAE reconstruction.
This network is trained on top of VAE output to improve reconstruction quality by learning a mapping from VAE reconstruction to original data.
Args: layer_dims: List of layer dimensions [input_dim, hidden1, hidden2, ..., output_dim]
Source code in renalprog/modeling/train.py
def __init__(self, layer_dims: List[int]):
super().__init__()
layers = []
for i in range(len(layer_dims) - 1):
layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1]))
if i < len(layer_dims) - 2: # Don't add ReLU after last layer
layers.append(nn.ReLU())
self.network = nn.Sequential(*layers)
Functions
forward
forward(x: Tensor) -> torch.Tensor
Forward pass through network.
Source code in renalprog/modeling/train.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through network."""
return self.network(x)
train_reconstruction_network
Train post-processing network.
train_reconstruction_network
train_reconstruction_network(
network: Module,
vae_reconstructions: DataFrame,
original_data: DataFrame,
train_indices: List,
test_indices: List,
epochs: int = 200,
lr: float = 0.0001,
batch_size: int = 32,
device: str = "cpu",
) -> Tuple[nn.Module, List[float], List[float]]
Train reconstruction network to adjust VAE output.
Args:
- network: NetworkReconstruction model
- vae_reconstructions: DataFrame with VAE reconstructions (samples x genes)
- original_data: DataFrame with original gene expression (samples x genes)
- train_indices: List of training sample indices
- test_indices: List of test sample indices
- epochs: Number of training epochs
- lr: Learning rate
- batch_size: Batch size
- device: Device to use
Returns: Tuple of (trained_network, train_losses, test_losses)
Source code in renalprog/modeling/train.py
def train_reconstruction_network(
network: nn.Module,
vae_reconstructions: pd.DataFrame,
original_data: pd.DataFrame,
train_indices: List,
test_indices: List,
epochs: int = 200,
lr: float = 1e-4,
batch_size: int = 32,
device: str = 'cpu',
) -> Tuple[nn.Module, List[float], List[float]]:
"""
Train reconstruction network to adjust VAE output.
Args:
network: NetworkReconstruction model
vae_reconstructions: DataFrame with VAE reconstructions (samples x genes)
original_data: DataFrame with original gene expression (samples x genes)
train_indices: List of training sample indices
test_indices: List of test sample indices
epochs: Number of training epochs
lr: Learning rate
batch_size: Batch size
device: Device to use
Returns:
Tuple of (trained_network, train_losses, test_losses)
"""
network = network.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(network.parameters(), lr=lr)
# Create dataloaders
train_dataset = torch.utils.data.TensorDataset(
torch.tensor(vae_reconstructions.loc[train_indices].values, dtype=torch.float32),
torch.tensor(original_data.loc[train_indices].values, dtype=torch.float32)
)
test_dataset = torch.utils.data.TensorDataset(
torch.tensor(vae_reconstructions.loc[test_indices].values, dtype=torch.float32),
torch.tensor(original_data.loc[test_indices].values, dtype=torch.float32)
)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False
)
loss_train = []
loss_test = []
logger.info(f"Training reconstruction network for {epochs} epochs")
# Add epoch progress bar
epoch_pbar = tqdm(range(epochs), desc='Reconstruction Network Training', position=0)
for epoch in epoch_pbar:
# Training
network.train()
running_loss = 0.0
# Add batch progress bar
train_pbar = tqdm(train_loader, desc='Train', leave=False, position=1)
for vae_recon, original in train_pbar:
vae_recon = vae_recon.to(device)
original = original.to(device)
optimizer.zero_grad()
output = network(vae_recon)
loss = criterion(output, original)
loss.backward()
optimizer.step()
running_loss += loss.item()
# Update batch progress bar
train_pbar.set_postfix({'batch_loss': f'{loss.item():.6f}'})
train_loss = running_loss / len(train_loader)
loss_train.append(train_loss)
# Validation
network.eval()
running_loss = 0.0
with torch.no_grad():
# Add validation batch progress bar
val_pbar = tqdm(test_loader, desc='Val', leave=False, position=1)
for vae_recon, original in val_pbar:
vae_recon = vae_recon.to(device)
original = original.to(device)
output = network(vae_recon)
loss = criterion(output, original)
running_loss += loss.item()
# Update validation progress bar
val_pbar.set_postfix({'batch_loss': f'{loss.item():.6f}'})
test_loss = running_loss / len(test_loader)
loss_test.append(test_loss)
# Update epoch progress bar with current metrics
epoch_pbar.set_postfix({
'train_loss': f'{train_loss:.6f}',
'test_loss': f'{test_loss:.6f}'
})
if (epoch + 1) % 20 == 0:
logger.info(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.6f}, Test Loss: {test_loss:.6f}")
logger.info("Reconstruction network training complete")
return network, loss_train, loss_test
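Example Usage (a sketch; df_reconstruction and df_original are assumed to be DataFrames of VAE reconstructions and original expression sharing the same sample index, with train_indices/test_indices listing their rows):
import torch
from renalprog.modeling.train import NetworkReconstruction, train_reconstruction_network
input_dim = df_original.shape[1]
network = NetworkReconstruction([input_dim, 4096, 1024, 4096, input_dim])  # default architecture used by the pipeline
network, train_losses, test_losses = train_reconstruction_network(
    network=network,
    vae_reconstructions=df_reconstruction,
    original_data=df_original,
    train_indices=train_indices,
    test_indices=test_indices,
    epochs=200,
    lr=1e-4,
    batch_size=32,
    device="cpu",
)
torch.save(network.state_dict(), "reconstruction_network.pth")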
See Also