Prediction API
Functions for applying trained VAE models to generate latent representations and trajectories.
Overview
This module provides:
- Applying trained VAE models to encode data into latent space
- Generating disease progression trajectories
- Evaluating reconstruction quality
- Analyzing patient connectivity
- Interpolating between latent representations
Core Prediction Functions
apply_vae
Apply trained VAE model to encode gene expression data into latent space.
apply_vae
apply_vae(
model: Module, data: DataFrame, device: str = "cpu"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
Apply VAE to data to get reconstruction and latent representation.
Args:
    model: Trained VAE model
    data: Input data (samples x genes)
    device: Device to run inference on
Returns:
    Tuple of (reconstruction, mu, logvar, z)
Source code in renalprog/modeling/predict.py
def apply_vae(
model: torch.nn.Module,
data: pd.DataFrame,
device: str = "cpu"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Apply VAE to data to get reconstruction and latent representation.
Args:
model: Trained VAE model
data: Input data (samples x genes)
device: Device to run inference on
Returns:
Tuple of (reconstruction, mu, logvar, z)
"""
model.eval()
model = model.to(device)
# Convert to tensor
if isinstance(data, pd.DataFrame):
data_tensor = torch.tensor(data.values, dtype=torch.float32).to(device)
else:
data_tensor = torch.tensor(data, dtype=torch.float32).to(device)
with torch.no_grad():
reconstruction, mu, logvar, z = model(data_tensor)
# Convert back to numpy
reconstruction = reconstruction.cpu().numpy()
mu = mu.cpu().numpy()
logvar = logvar.cpu().numpy() if logvar is not None else None
z = z.cpu().numpy()
return reconstruction, mu, logvar, z
Example Usage:
import torch
import pandas as pd
from renalprog.modeling.train import VAE
from renalprog.modeling.predict import apply_vae
# Load model
model = VAE(input_dim=20000, mid_dim=1024, features=128)
model.load_state_dict(torch.load("models/my_vae/best_model.pt"))
# Load test data
test_expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)
# Apply VAE (returns a tuple of numpy arrays, not a dict)
reconstruction, mu, logvar, z = apply_vae(
    model=model,
    data=test_expr,
    device="cuda"
)
print(f"Latent space shape: {z.shape}")
print(f"Reconstruction shape: {reconstruction.shape}")
Trajectory Generation
generate_trajectories
Generate disease progression trajectories by interpolating in latent space.
generate_trajectories
generate_trajectories(
model: Module,
source_samples: DataFrame,
target_samples: DataFrame,
n_steps: int = 50,
method: str = "linear",
output_dir: Optional[Path] = None,
parallel: bool = False,
n_workers: Optional[int] = None,
) -> Dict[str, pd.DataFrame]
Generate synthetic trajectories between source and target patient samples.
This function creates interpolated gene expression profiles in the latent space between pairs of patients at different cancer stages.
Args:
    model: Trained VAE model
    source_samples: Source patient samples (early stage)
    target_samples: Target patient samples (late stage)
    n_steps: Number of interpolation steps
    method: Interpolation method ("linear" or "spherical")
    output_dir: Optional directory to save trajectories
    parallel: Whether to use parallel processing
    n_workers: Number of parallel workers (None = use all CPUs)
Returns:
    Dictionary mapping patient pairs to trajectory DataFrames
Source code in renalprog/modeling/predict.py
def generate_trajectories(
model: torch.nn.Module,
source_samples: pd.DataFrame,
target_samples: pd.DataFrame,
n_steps: int = 50,
method: str = "linear",
output_dir: Optional[Path] = None,
parallel: bool = False,
n_workers: Optional[int] = None
) -> Dict[str, pd.DataFrame]:
"""
Generate synthetic trajectories between source and target patient samples.
This function creates interpolated gene expression profiles in the latent
space between pairs of patients at different cancer stages.
Args:
model: Trained VAE model
source_samples: Source patient samples (early stage)
target_samples: Target patient samples (late stage)
n_steps: Number of interpolation steps
method: Interpolation method ("linear" or "spherical")
output_dir: Optional directory to save trajectories
parallel: Whether to use parallel processing
n_workers: Number of parallel workers (None = use all CPUs)
Returns:
Dictionary mapping patient pairs to trajectory DataFrames
"""
logger.info(f"Generating trajectories with {n_steps} steps using {method} interpolation")
# TODO: Implement trajectory generation
# Migrate from src_deseq_and_gsea_NCSR/synthetic_data_generation.py
raise NotImplementedError(
"generate_trajectories() needs implementation from "
"src_deseq_and_gsea_NCSR/synthetic_data_generation.py and "
"src/data/fun_interpol.py"
)
Example Usage:
from renalprog.modeling.predict import generate_trajectories
# Generate trajectories from early- to late-stage samples
# Note: this function currently raises NotImplementedError (see source above)
trajectories = generate_trajectories(
    model=model,
    source_samples=early_stage_samples,
    target_samples=late_stage_samples,
    n_steps=50,
    method="spherical"
)
# trajectories maps patient pairs to trajectory DataFrames
print(f"Generated {len(trajectories)} trajectories")
create_patient_connections
Create optimal patient pairings for trajectory generation.
create_patient_connections
create_patient_connections(
data: DataFrame,
clinical: Series,
method: str = "random",
transition_type: str = "early_to_late",
n_connections: Optional[int] = None,
seed: int = 2023,
) -> pd.DataFrame
Create connections between patients for trajectory generation.
Args:
    data: Gene expression data
    clinical: Clinical stage information
    method: Method for creating connections ("random", "nearest", "all")
    transition_type: Type of transition ("early_to_late", "early_to_early", etc.)
    n_connections: Number of connections to create (None = all possible)
    seed: Random seed
Returns:
    DataFrame with columns: source, target, transition
Source code in renalprog/modeling/predict.py
def create_patient_connections(
data: pd.DataFrame,
clinical: pd.Series,
method: str = "random",
transition_type: str = "early_to_late",
n_connections: Optional[int] = None,
seed: int = 2023
) -> pd.DataFrame:
"""
Create connections between patients for trajectory generation.
Args:
data: Gene expression data
clinical: Clinical stage information
method: Method for creating connections ("random", "nearest", "all")
transition_type: Type of transition ("early_to_late", "early_to_early", etc.)
n_connections: Number of connections to create (None = all possible)
seed: Random seed
Returns:
DataFrame with columns: source, target, transition
"""
logger.info(f"Creating patient connections: {transition_type} using {method} method")
# TODO: Implement connection logic
# Migrate from notebooks/4_1_trajectories.ipynb
raise NotImplementedError(
"create_patient_connections() needs implementation from "
"notebooks/4_1_trajectories.ipynb"
)
Example Usage:
from renalprog.modeling.predict import create_patient_connections
# Create patient pairings for trajectory generation
# Note: this function currently raises NotImplementedError (see source above)
connections = create_patient_connections(
    data=test_expr,
    clinical=clinical["stage"],
    method="random",
    transition_type="early_to_late",
    seed=2023
)
print(f"Created {len(connections)} patient pairs")
interpolate_latent_linear
Linear interpolation between latent representations.
interpolate_latent_linear
interpolate_latent_linear(
z_source: ndarray, z_target: ndarray, n_steps: int = 50
) -> np.ndarray
Linear interpolation in latent space.
Args:
    z_source: Source latent vector
    z_target: Target latent vector
    n_steps: Number of interpolation steps
Returns:
    Array of interpolated latent vectors (n_steps x latent_dim)
Source code in renalprog/modeling/predict.py
def interpolate_latent_linear(
z_source: np.ndarray,
z_target: np.ndarray,
n_steps: int = 50
) -> np.ndarray:
"""
Linear interpolation in latent space.
Args:
z_source: Source latent vector
z_target: Target latent vector
n_steps: Number of interpolation steps
Returns:
Array of interpolated latent vectors (n_steps x latent_dim)
"""
alphas = np.linspace(0, 1, n_steps)
interpolated = np.array([
(1 - alpha) * z_source + alpha * z_target
for alpha in alphas
])
return interpolated
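Example Usage (a minimal sketch; the random latent vectors below are placeholders for real encodings):
from renalprog.modeling.predict import interpolate_latent_linear
import numpy as np
# Single source and target latent vectors (e.g., 128 latent dimensions)
z_start = np.random.randn(128)
z_end = np.random.randn(128)
trajectory = interpolate_latent_linear(z_start, z_end, n_steps=50)
# Shape: (50, 128) - one interpolated vector per step
print(trajectory.shape)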
interpolate_latent_spherical
Spherical (SLERP) interpolation between latent representations.
interpolate_latent_spherical
interpolate_latent_spherical(
z_source: ndarray, z_target: ndarray, n_steps: int = 50
) -> np.ndarray
Spherical (SLERP) interpolation in latent space.
Args:
    z_source: Source latent vector
    z_target: Target latent vector
    n_steps: Number of interpolation steps
Returns:
    Array of interpolated latent vectors (n_steps x latent_dim)
Source code in renalprog/modeling/predict.py
def interpolate_latent_spherical(
z_source: np.ndarray,
z_target: np.ndarray,
n_steps: int = 50
) -> np.ndarray:
"""
Spherical (SLERP) interpolation in latent space.
Args:
z_source: Source latent vector
z_target: Target latent vector
n_steps: Number of interpolation steps
Returns:
Array of interpolated latent vectors (n_steps x latent_dim)
"""
# Normalize vectors
z_source_norm = z_source / np.linalg.norm(z_source)
z_target_norm = z_target / np.linalg.norm(z_target)
# Calculate angle between vectors
omega = np.arccos(np.clip(np.dot(z_source_norm, z_target_norm), -1.0, 1.0))
if omega < 1e-8:
# Vectors are nearly identical, use linear interpolation
return interpolate_latent_linear(z_source, z_target, n_steps)
# SLERP formula
alphas = np.linspace(0, 1, n_steps)
interpolated = np.array([
(np.sin((1 - alpha) * omega) / np.sin(omega)) * z_source +
(np.sin(alpha * omega) / np.sin(omega)) * z_target
for alpha in alphas
])
return interpolated
Example:
from renalprog.modeling.predict import interpolate_latent_spherical
import numpy as np
# The function interpolates a single pair of 1D latent vectors
z_start = np.random.randn(128)  # source latent vector (128 latent dims)
z_end = np.random.randn(128)    # target latent vector
# Spherical interpolation (better for normalized spaces)
trajectory = interpolate_latent_spherical(z_start, z_end, n_steps=50)
# Shape: (50, 128) - one interpolated vector per step
Reconstruction Evaluation
evaluate_reconstruction
Comprehensive evaluation of VAE reconstruction quality.
evaluate_reconstruction
evaluate_reconstruction(
real_data: DataFrame,
synthetic_data: DataFrame,
save_path_data: Path,
save_path_figures: Optional[Path] = None,
metadata_path: Path = None,
) -> Tuple[pd.Series, pd.Series]
Comprehensive evaluation of reconstruction quality using SDMetrics.
This function orchestrates a complete quality assessment of synthetic/reconstructed data by computing both diagnostic and quality metrics. It's the main entry point for evaluating VAE reconstructions or VAE+RecNet outputs.
The evaluation includes:
1. Boundary Adherence: Do synthetic values stay within real data bounds?
2. Distribution Similarity: Do synthetic distributions match real distributions?
3. Quality Report: Overall assessment of column shapes and correlations
Args:
    real_data: Real gene expression data (samples × genes)
    synthetic_data: Synthetic/reconstructed gene expression data (samples × genes)
    save_path_data: Directory path to save all metric results (CSV, PKL)
    save_path_figures: Optional directory path to save visualization plots
    metadata_path: Path to CSV file used to extract metadata structure (typically the test set CSV file)
Returns:
    Tuple of (boundary_adherence_series, ks_complement_series):
    - boundary_adherence_series: Series with boundary adherence scores per gene
    - ks_complement_series: Series with KS Complement scores per gene
Workflow:
    1. Extract metadata from test CSV file
    2. Compute diagnostic metrics (boundary adherence)
    3. Compute quality metrics (KS complement + quality report)
    4. Save all results and visualizations
Output Files:
    In save_path_data/:
    - boundary_adherence_per_gene.csv: Per-gene boundary scores
    - ks_complement_per_gene.csv: Per-gene distribution similarity
    - quality_report.pkl: Full SDMetrics quality report object
In save_path_figures/ (if provided):
- boundary_adherence_per_gene.{html,png,pdf,svg}
- ks_complement_per_gene.{html,png,pdf,svg}
Interpretation:
    - Higher scores are better for both metrics (range: 0.0 to 1.0)
    - Boundary Adherence: Checks if synthetic data stays in valid ranges
    - KS Complement: Checks if distributions match (more stringent)
    - Good reconstruction: BA > 0.95, KS > 0.85
    - Excellent reconstruction: BA > 0.99, KS > 0.90
Example:
    >>> ba_scores, ks_scores = evaluate_reconstruction(
    ...     real_data=X_test,
    ...     synthetic_data=vae_reconstruction,
    ...     save_path_data="results/vae_eval/",
    ...     save_path_figures="figures/vae_eval/",
    ...     metadata_path="data/X_test.csv"
    ... )
    >>> print(f"Mean BA: {ba_scores.mean():.4f}, Mean KS: {ks_scores.mean():.4f}")
    Mean BA: 0.9823, Mean KS: 0.8756
Note:
    - Ensure real_data and synthetic_data have identical column names and order
    - The metadata_path CSV should have the same structure as real_data
    - This function is used by scripts/pipeline_steps/3_check_reconstruction.py
    - Both DataFrames should have samples as rows and genes as columns
Raises:
    ValueError: If data shapes don't match or columns don't align
    FileNotFoundError: If metadata_path doesn't exist
Source code in renalprog/modeling/predict.py
def evaluate_reconstruction(
real_data: pd.DataFrame,
synthetic_data: pd.DataFrame,
save_path_data: Path,
save_path_figures: Optional[Path] = None,
metadata_path: Path = None,
) -> Tuple[pd.Series, pd.Series]:
"""
Comprehensive evaluation of reconstruction quality using SDMetrics.
This function orchestrates a complete quality assessment of synthetic/reconstructed
data by computing both diagnostic and quality metrics. It's the main entry point
for evaluating VAE reconstructions or VAE+RecNet outputs.
The evaluation includes:
1. Boundary Adherence: Do synthetic values stay within real data bounds?
2. Distribution Similarity: Do synthetic distributions match real distributions?
3. Quality Report: Overall assessment of column shapes and correlations
Args:
real_data: Real gene expression data (samples × genes)
synthetic_data: Synthetic/reconstructed gene expression data (samples × genes)
save_path_data: Directory path to save all metric results (CSV, PKL)
save_path_figures: Optional directory path to save visualization plots
metadata_path: Path to CSV file used to extract metadata structure
(typically the test set CSV file)
Returns:
Tuple of (boundary_adherence_series, ks_complement_series):
- boundary_adherence_series: Series with boundary adherence scores per gene
- ks_complement_series: Series with KS Complement scores per gene
Workflow:
1. Extract metadata from test CSV file
2. Compute diagnostic metrics (boundary adherence)
3. Compute quality metrics (KS complement + quality report)
4. Save all results and visualizations
Output Files:
In save_path_data/:
- boundary_adherence_per_gene.csv: Per-gene boundary scores
- ks_complement_per_gene.csv: Per-gene distribution similarity
- quality_report.pkl: Full SDMetrics quality report object
In save_path_figures/ (if provided):
- boundary_adherence_per_gene.{html,png,pdf,svg}
- ks_complement_per_gene.{html,png,pdf,svg}
Interpretation:
- Higher scores are better for both metrics (range: 0.0 to 1.0)
- Boundary Adherence: Checks if synthetic data stays in valid ranges
- KS Complement: Checks if distributions match (more stringent)
- Good reconstruction: BA > 0.95, KS > 0.85
- Excellent reconstruction: BA > 0.99, KS > 0.90
Example:
>>> ba_scores, ks_scores = evaluate_reconstruction(
... real_data=X_test,
... synthetic_data=vae_reconstruction,
... save_path_data="results/vae_eval/",
... save_path_figures="figures/vae_eval/",
... metadata_path="data/X_test.csv"
... )
>>> print(f"Mean BA: {ba_scores.mean():.4f}, Mean KS: {ks_scores.mean():.4f}")
Mean BA: 0.9823, Mean KS: 0.8756
Note:
- Ensure real_data and synthetic_data have identical column names and order
- The metadata_path CSV should have the same structure as real_data
- This function is used by scripts/pipeline_steps/3_check_reconstruction.py
- Both DataFrames should have samples as rows and genes as columns
Raises:
ValueError: If data shapes don't match or columns don't align
FileNotFoundError: If metadata_path doesn't exist
"""
logger.info("=" * 80)
logger.info("RECONSTRUCTION QUALITY EVALUATION")
logger.info("=" * 80)
logger.info(f"Real data shape: {real_data.shape}")
logger.info(f"Synthetic data shape: {synthetic_data.shape}")
logger.info(f"Saving results to: {save_path_data}")
if save_path_figures:
logger.info(f"Saving figures to: {save_path_figures}")
# Validate inputs
if real_data.shape != synthetic_data.shape:
raise ValueError(
f"Data shape mismatch: real {real_data.shape} vs synthetic {synthetic_data.shape}"
)
if not all(real_data.columns == synthetic_data.columns):
raise ValueError("Column names must match between real and synthetic data")
# Step 1: Extract metadata in SDMetrics format
logger.info("\n[1/3] Extracting metadata...")
metadata_sd = get_metadata(metadata_path)
logger.info(f"Metadata extracted for {len(metadata_sd['columns'])} genes")
# Step 2: Compute diagnostic metrics
logger.info("\n[2/3] Computing diagnostic metrics...")
df_ba = diagnostic_metrics(
real_data=real_data,
synthetic_data=synthetic_data,
save_path_data=save_path_data,
save_path_figures=save_path_figures
)
# Step 3: Compute quality metrics
logger.info("\n[3/3] Computing quality metrics...")
df_ks = quality_metrics(
real_data=real_data,
synthetic_data=synthetic_data,
metadata=metadata_sd,
save_path_data=save_path_data,
save_path_figures=save_path_figures
)
# Final summary
logger.info("\n" + "=" * 80)
logger.info("EVALUATION COMPLETE - SUMMARY")
logger.info("=" * 80)
logger.info(f"Boundary Adherence - Mean: {df_ba.mean():.4f}, Median: {df_ba.median():.4f}")
logger.info(f"KS Complement - Mean: {df_ks.mean():.4f}, Median: {df_ks.median():.4f}")
# Quality assessment
ba_quality = "Excellent" if df_ba.mean() > 0.99 else "Good" if df_ba.mean() > 0.95 else "Fair" if df_ba.mean() > 0.90 else "Poor"
ks_quality = "Excellent" if df_ks.mean() > 0.90 else "Good" if df_ks.mean() > 0.85 else "Fair" if df_ks.mean() > 0.75 else "Poor"
logger.info(f"Overall Assessment - Boundary: {ba_quality}, Distribution: {ks_quality}")
logger.info(f"Results saved to: {save_path_data}")
logger.info("=" * 80)
return df_ba, df_ks
Example Usage:
from pathlib import Path
from renalprog.modeling.predict import evaluate_reconstruction
# Evaluate reconstruction quality (returns two per-gene score Series)
ba_scores, ks_scores = evaluate_reconstruction(
    real_data=test_expr,
    synthetic_data=reconstructed_expr,
    save_path_data=Path("reports/reconstruction_eval"),
    save_path_figures=Path("reports/figures/reconstruction_eval"),
    metadata_path=Path("data/X_test.csv")
)
print(f"Mean Boundary Adherence: {ba_scores.mean():.4f}")
print(f"Mean KS Complement: {ks_scores.mean():.4f}")
Quality Metrics
diagnostic_metrics
Calculate diagnostic metrics for model evaluation.
diagnostic_metrics
diagnostic_metrics(
real_data: DataFrame,
synthetic_data: DataFrame,
save_path_data: Path,
save_path_figures: Optional[Path] = None,
) -> pd.Series
Calculate diagnostic metrics to assess synthetic data quality.
This function computes the Boundary Adherence metric for each gene, which measures whether synthetic values respect the min/max boundaries of real data. This is a critical diagnostic to detect mode collapse or distribution shift.
Args:
    real_data: Real gene expression data (samples × genes)
    synthetic_data: Synthetic/reconstructed gene expression data (samples × genes)
    save_path_data: Directory path to save metric CSV results
    save_path_figures: Optional directory path to save visualization plots
Returns:
    Series with boundary adherence scores per gene (index=gene, values=scores).
    Scores range from 0.0 (worst) to 1.0 (best).
Metric Details:
    Boundary Adherence per Gene:
    - 1.0 (best): All synthetic values are within [min, max] of real data
    - 0.0 (worst): No synthetic values fall within real data boundaries
    - Values between 0-1 indicate partial adherence
Saves:
    - {save_path_data}/boundary_adherence_per_gene.csv: Per-gene scores
    - {save_path_figures}/boundary_adherence_per_gene.html: Interactive histogram
    - {save_path_figures}/boundary_adherence_per_gene.{png,pdf,svg}: Static plots
Example:
    >>> ba_scores = diagnostic_metrics(X_real, X_synthetic, "results/", "figures/")
    >>> print(f"Mean adherence: {ba_scores.mean():.4f}")
    Mean adherence: 0.9823
Source code in renalprog/modeling/predict.py
def diagnostic_metrics(
real_data: pd.DataFrame,
synthetic_data: pd.DataFrame,
save_path_data: Path,
save_path_figures: Optional[Path] = None
) -> pd.Series:
"""
Calculate diagnostic metrics to assess synthetic data quality.
This function computes the Boundary Adherence metric for each gene, which
measures whether synthetic values respect the min/max boundaries of real data.
This is a critical diagnostic to detect mode collapse or distribution shift.
Args:
real_data: Real gene expression data (samples × genes)
synthetic_data: Synthetic/reconstructed gene expression data (samples × genes)
save_path_data: Directory path to save metric CSV results
save_path_figures: Optional directory path to save visualization plots
Returns:
Series with boundary adherence scores per gene (index=gene, values=scores).
Scores range from 0.0 (worst) to 1.0 (best).
Metric Details:
Boundary Adherence per Gene:
- 1.0 (best): All synthetic values are within [min, max] of real data
- 0.0 (worst): No synthetic values fall within real data boundaries
- Values between 0-1 indicate partial adherence
Saves:
- {save_path_data}/boundary_adherence_per_gene.csv: Per-gene scores
- {save_path_figures}/boundary_adherence_per_gene.html: Interactive histogram
- {save_path_figures}/boundary_adherence_per_gene.{png,pdf,svg}: Static plots
Example:
>>> ba_scores = diagnostic_metrics(X_real, X_synthetic, "results/", "figures/")
>>> print(f"Mean adherence: {ba_scores.mean():.4f}")
Mean adherence: 0.9823
"""
    # Local imports mirroring quality_metrics() so the snippet is self-contained
    import os
    from tqdm import tqdm
    import plotly.express as px
    from sdmetrics.single_column import BoundaryAdherence
    logger.info("=" * 60)
logger.info("DIAGNOSTIC METRICS: Boundary Adherence")
logger.info("=" * 60)
logger.info(f"Evaluating {real_data.shape[1]} genes across {real_data.shape[0]} samples")
# Calculate boundary adherence for each gene
# This measures what percentage of synthetic values fall within the
# [min, max] range observed in the real data
ba_dict = {}
for gene_i in tqdm(real_data.columns, desc='Computing Boundary Adherence'):
# Compute metric: % of synthetic values within [min, max] of real values
ba_i = BoundaryAdherence.compute(
real_data=real_data[gene_i],
synthetic_data=synthetic_data[gene_i]
)
ba_dict[gene_i] = ba_i
# Convert to Series for easy analysis
df_ba = pd.Series(ba_dict, name='boundary_adherence')
# Save results
output_csv = os.path.join(save_path_data, 'boundary_adherence_per_gene.csv')
df_ba.to_csv(output_csv)
logger.info(f"Saved results to: {output_csv}")
# Log summary statistics
logger.info(f"Mean Boundary Adherence: {df_ba.mean():.4f}")
logger.info(f"Median Boundary Adherence: {df_ba.median():.4f}")
logger.info(f"Min Boundary Adherence: {df_ba.min():.4f}")
logger.info(f"Max Boundary Adherence: {df_ba.max():.4f}")
logger.info(f"Genes with perfect adherence (1.0): {(df_ba == 1.0).sum()}/{len(df_ba)} ({100*(df_ba == 1.0).sum()/len(df_ba):.1f}%)")
# Generate visualizations if output directory provided
if save_path_figures is not None:
logger.info("Generating visualizations...")
# Create interactive histogram
fig = px.histogram(
df_ba,
x='boundary_adherence',
nbins=50,
title='Distribution of Boundary Adherence Scores per Gene',
labels={'boundary_adherence': 'Boundary Adherence Score'},
template='plotly_white'
)
fig.update_layout(
xaxis_title='Boundary Adherence Score',
yaxis_title='Number of Genes',
showlegend=False
)
# Save in multiple formats
html_path = os.path.join(save_path_figures, 'boundary_adherence_per_gene.html')
fig.write_html(html_path)
logger.info(f" Saved interactive plot: {html_path}")
for format_ext in ['png', 'pdf', 'svg']:
img_path = os.path.join(save_path_figures, f'boundary_adherence_per_gene.{format_ext}')
fig.write_image(img_path, scale=2)
logger.info(f" Saved {format_ext.upper()} plot: {img_path}")
logger.info("Diagnostic metrics calculation complete")
logger.info("=" * 60)
return df_ba
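Example Usage (a minimal sketch; X_real and X_synthetic are assumed to be aligned samples × genes DataFrames):
from pathlib import Path
from renalprog.modeling.predict import diagnostic_metrics
ba_scores = diagnostic_metrics(
    real_data=X_real,
    synthetic_data=X_synthetic,
    save_path_data=Path("results/vae_eval"),
    save_path_figures=Path("figures/vae_eval")
)
# Flag genes whose synthetic values often fall outside the real range
print(ba_scores[ba_scores < 0.9].sort_values().head())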
quality_metrics
Calculate quality metrics for generated trajectories.
quality_metrics
quality_metrics(
real_data: DataFrame,
synthetic_data: DataFrame,
metadata: Dict[str, Any],
save_path_data: Path,
save_path_figures: Optional[Path] = None,
) -> pd.Series
Calculate quality metrics to assess synthetic data fidelity.
This function computes two key metrics:
1. Quality Report: Overall assessment of column shapes and pair-wise trends
2. KS Complement: Per-gene similarity of marginal distributions
These metrics evaluate how well the synthetic data captures the statistical properties of the real data, beyond just staying within boundaries.
Args:
    real_data: Real gene expression data (samples × genes)
    synthetic_data: Synthetic/reconstructed gene expression data (samples × genes)
    metadata: Metadata dictionary from get_metadata() for SDMetrics
    save_path_data: Directory path to save metric results
    save_path_figures: Optional directory path to save visualization plots
Returns:
    Series with KS Complement scores per gene (index=gene, values=scores).
    Scores range from 0.0 (worst) to 1.0 (best).
Metrics Details:
    Column Shapes (in Quality Report):
    - Measures overall distribution similarity per column
    - Higher scores indicate better shape matching
Column Pair Trends (in Quality Report):
- Measures correlation and relationship preservation
- Higher scores indicate better trend matching
KS Complement (per gene):
- 1.0 (best): Real and synthetic distributions are identical
- 0.0 (worst): Distributions are maximally different
- Based on Kolmogorov-Smirnov test
Saves:
    - {save_path_data}/quality_report.pkl: Full SDMetrics quality report
    - {save_path_data}/ks_complement_per_gene.csv: Per-gene KS scores
    - {save_path_figures}/ks_complement_per_gene.html: Interactive histogram
    - {save_path_figures}/ks_complement_per_gene.{png,pdf,svg}: Static plots
Example:
    >>> ks_scores = quality_metrics(X_real, X_synth, metadata, "results/", "figs/")
    >>> print(f"Mean KS Complement: {ks_scores.mean():.4f}")
    Mean KS Complement: 0.8756
Source code in renalprog/modeling/predict.py
def quality_metrics(
real_data: pd.DataFrame,
synthetic_data: pd.DataFrame,
metadata: Dict[str, Any],
save_path_data: Path,
save_path_figures: Optional[Path] = None
) -> pd.Series:
"""
Calculate quality metrics to assess synthetic data fidelity.
This function computes two key metrics:
1. Quality Report: Overall assessment of column shapes and pair-wise trends
2. KS Complement: Per-gene similarity of marginal distributions
These metrics evaluate how well the synthetic data captures the statistical
properties of the real data, beyond just staying within boundaries.
Args:
real_data: Real gene expression data (samples × genes)
synthetic_data: Synthetic/reconstructed gene expression data (samples × genes)
metadata: Metadata dictionary from get_metadata() for SDMetrics
save_path_data: Directory path to save metric results
save_path_figures: Optional directory path to save visualization plots
Returns:
Series with KS Complement scores per gene (index=gene, values=scores).
Scores range from 0.0 (worst) to 1.0 (best).
Metrics Details:
Column Shapes (in Quality Report):
- Measures overall distribution similarity per column
- Higher scores indicate better shape matching
Column Pair Trends (in Quality Report):
- Measures correlation and relationship preservation
- Higher scores indicate better trend matching
KS Complement (per gene):
- 1.0 (best): Real and synthetic distributions are identical
- 0.0 (worst): Distributions are maximally different
- Based on Kolmogorov-Smirnov test
Saves:
- {save_path_data}/quality_report.pkl: Full SDMetrics quality report
- {save_path_data}/ks_complement_per_gene.csv: Per-gene KS scores
- {save_path_figures}/ks_complement_per_gene.html: Interactive histogram
- {save_path_figures}/ks_complement_per_gene.{png,pdf,svg}: Static plots
Example:
>>> ks_scores = quality_metrics(X_real, X_synth, metadata, "results/", "figs/")
>>> print(f"Mean KS Complement: {ks_scores.mean():.4f}")
Mean KS Complement: 0.8756
"""
import os
from tqdm import tqdm
import plotly.express as px
from sdmetrics.reports.single_table import QualityReport
from sdmetrics.single_column import KSComplement
logger.info("=" * 60)
logger.info("QUALITY METRICS: Distribution Similarity")
logger.info("=" * 60)
# Generate comprehensive quality report
# This evaluates:
# 1. Column Shapes: How well distributions match per gene
# 2. Column Pair Trends: How well correlations are preserved
logger.info("Generating SDMetrics Quality Report...")
q_report = QualityReport()
q_report.generate(real_data, synthetic_data, metadata)
# Save quality report object for later analysis
report_path = os.path.join(save_path_data, 'quality_report.pkl')
q_report.save(report_path)
logger.info(f"Saved quality report to: {report_path}")
# Get overall quality score from the report
overall_score = q_report.get_score()
logger.info(f"Overall Quality Score: {overall_score:.4f}")
# Calculate KS Complement for each gene
# This measures similarity of marginal distributions (1D histograms)
# KS Complement = 1 - KS statistic, where KS statistic measures max difference between CDFs
logger.info(f"Computing KS Complement for {real_data.shape[1]} genes...")
ks_dict = {}
for gene_i in tqdm(real_data.columns, desc='Computing KS Complement'):
# KS Complement measures how similar the empirical cumulative distribution functions are
# Higher values mean the distributions are more similar
ks_i = KSComplement.compute(
real_data=real_data[gene_i],
synthetic_data=synthetic_data[gene_i]
)
ks_dict[gene_i] = ks_i
# Convert to Series for analysis
df_ks = pd.Series(ks_dict, name='ks_complement')
# Save results
output_csv = os.path.join(save_path_data, 'ks_complement_per_gene.csv')
df_ks.to_csv(output_csv)
logger.info(f"Saved results to: {output_csv}")
# Log summary statistics
logger.info(f"Mean KS Complement: {df_ks.mean():.4f}")
logger.info(f"Median KS Complement: {df_ks.median():.4f}")
logger.info(f"Min KS Complement: {df_ks.min():.4f}")
logger.info(f"Max KS Complement: {df_ks.max():.4f}")
logger.info(f"Genes with KS > 0.9: {(df_ks > 0.9).sum()}/{len(df_ks)} ({100*(df_ks > 0.9).sum()/len(df_ks):.1f}%)")
# Generate visualizations if output directory provided
if save_path_figures is not None:
logger.info("Generating visualizations...")
# Create interactive histogram
fig = px.histogram(
df_ks,
x='ks_complement',
nbins=50,
title='Distribution of KS Complement Scores per Gene',
labels={'ks_complement': 'KS Complement Score'},
template='plotly_white'
)
fig.update_layout(
xaxis_title='KS Complement Score (Distribution Similarity)',
yaxis_title='Number of Genes',
showlegend=False
)
# Add reference line at 0.9 (high quality threshold)
fig.add_vline(
x=0.9,
line_dash="dash",
line_color="red",
annotation_text="High Quality (0.9)"
)
# Save in multiple formats
html_path = os.path.join(save_path_figures, 'ks_complement_per_gene.html')
fig.write_html(html_path)
logger.info(f" Saved interactive plot: {html_path}")
for format_ext in ['png', 'pdf', 'svg']:
img_path = os.path.join(save_path_figures, f'ks_complement_per_gene.{format_ext}')
fig.write_image(img_path, scale=2)
logger.info(f" Saved {format_ext.upper()} plot: {img_path}")
logger.info("Quality metrics calculation complete")
logger.info("=" * 60)
return df_ks
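Example Usage (a minimal sketch; the metadata dictionary comes from get_metadata() on the test-set CSV, and X_real/X_synthetic are aligned DataFrames):
from pathlib import Path
from renalprog.modeling.predict import get_metadata, quality_metrics
metadata = get_metadata("data/X_test.csv")
ks_scores = quality_metrics(
    real_data=X_real,
    synthetic_data=X_synthetic,
    metadata=metadata,
    save_path_data=Path("results/vae_eval"),
    save_path_figures=Path("figures/vae_eval")
)
print(f"Mean KS Complement: {ks_scores.mean():.4f}")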
Trajectory Classification
classify_trajectories
Classify disease progression trajectories as progressing vs. non-progressing.
classify_trajectories
classify_trajectories(
classifier,
trajectory_data: Dict[str, DataFrame],
gene_subset: Optional[List[str]] = None,
) -> pd.DataFrame
Apply stage classifier to synthetic trajectories.
Args:
    classifier: Trained classifier model
    trajectory_data: Dictionary of patient pair to trajectory DataFrames
    gene_subset: Optional subset of genes to use for classification
Returns:
    DataFrame with classification results for each trajectory point
Source code in renalprog/modeling/predict.py
def classify_trajectories(
classifier,
trajectory_data: Dict[str, pd.DataFrame],
gene_subset: Optional[List[str]] = None
) -> pd.DataFrame:
"""
Apply stage classifier to synthetic trajectories.
Args:
classifier: Trained classifier model
trajectory_data: Dictionary of patient pair to trajectory DataFrames
gene_subset: Optional subset of genes to use for classification
Returns:
DataFrame with classification results for each trajectory point
"""
logger.info("Classifying trajectory points")
# TODO: Implement trajectory classification
# Migrate from notebooks/kirc_classification_trajectory.ipynb
raise NotImplementedError(
"classify_trajectories() needs implementation from "
"notebooks/kirc_classification_trajectory.ipynb"
)
Example Usage:
from renalprog.modeling.predict import classify_trajectories
# Classify each point along the synthetic trajectories
# Note: this function currently raises NotImplementedError (see source above)
results = classify_trajectories(
    classifier=trained_stage_classifier,
    trajectory_data=trajectories,
    gene_subset=selected_genes
)
print(results.head())
Network Analysis
build_trajectory_network
Build network graph of patient trajectories.
build_trajectory_network
build_trajectory_network(
patient_links: DataFrame,
) -> Tuple[Dict[str, List[str]], List[List[str]]]
Build trajectory network and find all complete disease progression paths.
Constructs a directed graph from patient links and identifies all possible complete trajectories from root nodes (earliest stage patients not appearing as targets) to leaf nodes (latest stage patients not appearing as sources).
Args:
    patient_links: DataFrame with 'source' and 'target' columns from linking functions
Returns:
    Tuple of:
    - network: Dict mapping each source patient to list of target patients
    - trajectories: List of complete trajectories, where each trajectory is a list of patient IDs ordered from earliest to latest stage
Network Structure:
    - Adjacency list representation: {source: [target1, target2, ...]}
    - Directed edges from earlier to later stages
    - Allows multiple outgoing edges (one patient → multiple next-stage patients)
Trajectory Discovery:
    - Uses depth-first search from root nodes
    - Root nodes: Patients in 'source' but not in 'target' (stage I or early)
    - Leaf nodes: Patients in 'target' but not in 'source' (stage IV or late)
    - Each trajectory represents a complete disease progression path
Example:
    >>> network, trajectories = build_trajectory_network(patient_links)
    >>> print(f"Network has {len(network)} nodes")
    Network has 500 nodes
    >>> print(f"Found {len(trajectories)} complete trajectories")
    Found 234 complete trajectories
    >>> print(f"Example trajectory: {trajectories[0]}")
    Example trajectory: ['PAT001', 'PAT045', 'PAT123', 'PAT289']
Trajectory Characteristics:
    - Length varies based on how many stages the path spans
    - Typical lengths: 2-4 patients for I→II→III→IV progressions
    - Length 2 for early→late progressions
    - Patients can appear in multiple trajectories
Note:
    - Cycles are prevented during trajectory search
    - All paths from root to leaf are enumerated
    - Trajectories respect chronological disease progression
Source code in renalprog/modeling/predict.py
def build_trajectory_network(
patient_links: pd.DataFrame
) -> Tuple[Dict[str, List[str]], List[List[str]]]:
"""
Build trajectory network and find all complete disease progression paths.
Constructs a directed graph from patient links and identifies all possible
complete trajectories from root nodes (earliest stage patients not appearing
as targets) to leaf nodes (latest stage patients not appearing as sources).
Args:
patient_links: DataFrame with 'source' and 'target' columns from linking functions
Returns:
Tuple of:
- network: Dict mapping each source patient to list of target patients
- trajectories: List of complete trajectories, where each trajectory is a
list of patient IDs ordered from earliest to latest stage
Network Structure:
- Adjacency list representation: {source: [target1, target2, ...]}
- Directed edges from earlier to later stages
- Allows multiple outgoing edges (one patient → multiple next-stage patients)
Trajectory Discovery:
- Uses depth-first search from root nodes
- Root nodes: Patients in 'source' but not in 'target' (stage I or early)
- Leaf nodes: Patients in 'target' but not in 'source' (stage IV or late)
- Each trajectory represents a complete disease progression path
Example:
>>> network, trajectories = build_trajectory_network(patient_links)
>>> print(f"Network has {len(network)} nodes")
>>> print(f"Found {len(trajectories)} complete trajectories")
>>> print(f"Example trajectory: {trajectories[0]}")
Network has 500 nodes
Found 234 complete trajectories
Example trajectory: ['PAT001', 'PAT045', 'PAT123', 'PAT289']
Trajectory Characteristics:
- Length varies based on how many stages the path spans
- Typical lengths: 2-4 patients for I→II→III→IV progressions
- Length 2 for early→late progressions
- Patients can appear in multiple trajectories
Note:
- Cycles are prevented during trajectory search
- All paths from root to leaf are enumerated
- Trajectories respect chronological disease progression
"""
logger.info("Building trajectory network from patient links")
sources = patient_links['source']
targets = patient_links['target']
# Build network adjacency list
network = {}
for source, target in zip(sources, targets):
if source not in network:
network[source] = []
network[source].append(target)
logger.info(f"Network built: {len(network)} source nodes")
# Find root nodes (patients who are sources but never targets)
unique_sources = set(sources) - set(targets)
logger.info(f"Found {len(unique_sources)} root nodes (earliest stage patients)")
# Recursively find all trajectories from each root
def find_trajectories(start_node: str, visited: Optional[List[str]] = None) -> List[List[str]]:
"""Depth-first search to find all paths from start_node to leaf nodes."""
if visited is None:
visited = []
visited.append(start_node)
# If node has no outgoing edges, this is a leaf node - return path
if start_node not in network:
return [visited]
# Recursively explore all targets
trajectories = []
for target in network[start_node]:
if target not in visited: # Avoid cycles
new_visited = visited.copy()
trajectories.extend(find_trajectories(target, new_visited))
return trajectories
# Find all trajectories starting from each root
all_trajectories = []
if len(unique_sources) == 0:
# No clear root nodes - this happens with early→late transitions where
# patients can be both sources and targets. In this case, each source→target
# pair is already a complete 2-patient trajectory.
logger.info("No root nodes found (typical for early→late transitions).")
logger.info("Using each source→target pair as a complete trajectory.")
for source, target in zip(sources, targets):
all_trajectories.append([source, target])
else:
# Standard case: multi-stage progressions (I→II→III→IV)
for source in unique_sources:
all_trajectories.extend(find_trajectories(source))
logger.info(f"Discovered {len(all_trajectories)} complete disease progression trajectories")
# Log trajectory length statistics only if we have trajectories
if len(all_trajectories) > 0:
traj_lengths = [len(t) for t in all_trajectories]
logger.info(f"Trajectory lengths - Min: {min(traj_lengths)}, Max: {max(traj_lengths)}, "
f"Mean: {np.mean(traj_lengths):.1f}")
else:
logger.warning("No trajectories found!")
return network, all_trajectories
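Example Usage (a minimal sketch with a hypothetical three-patient chain):
import pandas as pd
from renalprog.modeling.predict import build_trajectory_network
# Two links forming a single progression chain PAT001 -> PAT045 -> PAT123
patient_links = pd.DataFrame({
    "source": ["PAT001", "PAT045"],
    "target": ["PAT045", "PAT123"],
})
network, trajectories = build_trajectory_network(patient_links)
# network == {"PAT001": ["PAT045"], "PAT045": ["PAT123"]}
# trajectories == [["PAT001", "PAT045", "PAT123"]]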
link_patients_closest
Link patients using closest neighbor strategy.
link_patients_closest
link_patients_closest(
transitions_df: DataFrame,
start_with_first_stage: bool = True,
early_late: bool = False,
closest: bool = True,
) -> pd.DataFrame
Link patients by selecting closest (or farthest) matches across stages.
For each patient at a source stage, this function identifies the closest (or farthest) patient at the target stage, considering metadata constraints (gender, race). This creates one-to-one patient linkages that form the basis for trajectory construction.
Args:
    transitions_df: DataFrame from calculate_all_possible_transitions() containing all possible patient pairs with distances
    start_with_first_stage: If True, build forward trajectories (early→late). If False, build backward trajectories (late→early)
    early_late: If True, uses early/late groupings. If False, uses I-IV stages
    closest: If True, connect closest patients. If False, connect farthest patients
Returns:
    DataFrame with selected patient links, containing one row per source patient with their optimal target patient match. Includes all columns from transitions_df.
Selection Strategy:
    - Forward (start_with_first_stage=True): For each source, find optimal target
    - Backward (start_with_first_stage=False): For each target, find optimal source
    - Closest (closest=True): Minimum distance match
    - Farthest (closest=False): Maximum distance match
Metadata Stratification:
    Links are selected independently within each combination of:
    - Gender (MALE, FEMALE)
    - Race (ASIAN, BLACK OR AFRICAN AMERICAN, WHITE)
    This ensures demographic consistency in trajectories.
Example:
    >>> links = link_patients_closest(
    ...     transitions_df=all_transitions,
    ...     start_with_first_stage=True,
    ...     closest=True
    ... )
    >>> print(f"Created {len(links)} patient links")
    Created 234 patient links
Note:
    - Processes transitions in order for forward: I→II→III→IV
    - Processes in reverse for backward: IV→III→II→I
    - Each patient appears at most once as a source in the result
Source code in renalprog/modeling/predict.py
def link_patients_closest(
transitions_df: pd.DataFrame,
start_with_first_stage: bool = True,
early_late: bool = False,
closest: bool = True
) -> pd.DataFrame:
"""
Link patients by selecting closest (or farthest) matches across stages.
For each patient at a source stage, this function identifies the closest
(or farthest) patient at the target stage, considering metadata constraints
(gender, race). This creates one-to-one patient linkages that form the basis
for trajectory construction.
Args:
transitions_df: DataFrame from calculate_all_possible_transitions()
containing all possible patient pairs with distances
start_with_first_stage: If True, build forward trajectories (early→late)
If False, build backward trajectories (late→early)
early_late: If True, uses early/late groupings. If False, uses I-IV stages
closest: If True, connect closest patients. If False, connect farthest patients
Returns:
DataFrame with selected patient links, containing one row per source patient
with their optimal target patient match. Includes all columns from transitions_df.
Selection Strategy:
- Forward (start_with_first_stage=True): For each source, find optimal target
- Backward (start_with_first_stage=False): For each target, find optimal source
- Closest (closest=True): Minimum distance match
- Farthest (closest=False): Maximum distance match
Metadata Stratification:
Links are selected independently within each combination of:
- Gender (MALE, FEMALE)
- Race (ASIAN, BLACK OR AFRICAN AMERICAN, WHITE)
This ensures demographic consistency in trajectories.
Example:
>>> links = link_patients_closest(
... transitions_df=all_transitions,
... start_with_first_stage=True,
... closest=True
... )
>>> print(f"Created {len(links)} patient links")
Created 234 patient links
Note:
- Processes transitions in order for forward: I→II→III→IV
- Processes in reverse for backward: IV→III→II→I
- Each patient appears at most once as a source in the result
"""
logger.info("Linking patients by closest/farthest matches")
logger.info(f"Direction: {'Forward' if start_with_first_stage else 'Backward'}")
logger.info(f"Strategy: {'Closest' if closest else 'Farthest'}")
# Define transition order based on direction
if start_with_first_stage and not early_late:
transitions_possible = ['1_to_2', '2_to_3', '3_to_4']
elif not start_with_first_stage and not early_late:
transitions_possible = ['3_to_4', '2_to_3', '1_to_2']
elif early_late:
transitions_possible = ['early_to_late']
# 0 for closest (smallest distance), -1 for farthest (largest distance)
idx = 0 if closest else -1
# Find closest/farthest patient for each source patient
closest_list = []
for transition_i in transitions_possible:
transition_df_i = transitions_df[transitions_df['transition'] == transition_i]
logger.info(f"Processing transition {transition_i}: {len(transition_df_i)} pairs")
# Iterate through all metadata combinations
for gender_i in ['FEMALE', 'MALE']:
df_gender_i = transition_df_i.query(f"source_gender == '{gender_i}'")
for race_i in ['ASIAN', 'BLACK OR AFRICAN AMERICAN', 'WHITE']:
df_race_i = df_gender_i.query(f"source_race == '{race_i}'")
if df_race_i.empty:
continue
# Get unique patients to link
unique_sources = df_race_i['source'].unique()
unique_targets = df_race_i['target'].unique()
use_uniques = unique_sources if start_with_first_stage else unique_targets
use_column = 'source' if start_with_first_stage else 'target'
# Find closest/farthest match for each patient
for pat_i in use_uniques:
pat_matches = df_race_i[df_race_i[use_column] == pat_i]
if len(pat_matches) > 0:
# Sort by distance and select first (closest) or last (farthest)
best_match = pat_matches.sort_values('distance').iloc[idx]
closest_list.append(best_match)
# Convert to DataFrame
closest_df = pd.DataFrame(closest_list)
closest_df.reset_index(drop=True, inplace=True)
logger.info(f"Created {len(closest_df)} patient links")
return closest_df
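Example Usage (a minimal sketch; the transitions DataFrame comes from calculate_all_possible_transitions(), documented below, and expression_data/clinical_metadata are placeholders):
from renalprog.modeling.predict import (
    calculate_all_possible_transitions,
    link_patients_closest,
)
all_transitions = calculate_all_possible_transitions(
    data=expression_data,               # patients as columns
    metadata_selection=clinical_metadata,
    distance="wasserstein",
    early_late=True,
)
links = link_patients_closest(
    transitions_df=all_transitions,
    start_with_first_stage=True,
    early_late=True,
    closest=True,
)
print(f"Created {len(links)} patient links")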
link_patients_random
Link patients randomly (for control/comparison).
link_patients_random
link_patients_random(
results_df: DataFrame,
start_with_first_stage: bool = True,
link_next: int = 5,
transitions_possible: Optional[List[str]] = None,
) -> pd.DataFrame
Link patients to multiple random targets at the next stage.
Instead of linking each patient to only their closest match, this function randomly samples multiple patients at the next stage to link to each source patient. This creates a one-to-many mapping useful for generating multiple trajectory samples.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| results_df | DataFrame | DataFrame with possible sources and targets, their metadata, and distance. | required |
| start_with_first_stage | bool | If True, initiate trajectories with first stage as sources. If False, initiate trajectories with last stage as sources. | True |
| link_next | int | Number of patients at next stage to randomly link to each patient of current stage. | 5 |
| transitions_possible | list | List of transitions to process (e.g., ['1_to_2', '2_to_3']). If None, defaults to ['early_to_late']. | None |
Returns:
| Type | Description |
| --- | --- |
| DataFrame | DataFrame with randomly sampled patient links for each transition. Contains multiple rows per source patient (up to link_next). |
Notes
- Random sampling is primarily performed for WHITE race patients due to sample size
- If fewer than link_next targets are available, all available targets are selected
- Patients from other races are included with all their possible connections
- Empty DataFrame is returned if no WHITE patients are found
Source code in renalprog/modeling/predict.py
def link_patients_random(
results_df: pd.DataFrame,
start_with_first_stage: bool = True,
link_next: int = 5,
transitions_possible: Optional[List[str]] = None
) -> pd.DataFrame:
"""
Link patients to multiple random targets at the next stage.
Instead of linking each patient to only their closest match, this function randomly
samples multiple patients at the next stage to link to each source patient. This
creates a one-to-many mapping useful for generating multiple trajectory samples.
Parameters
----------
results_df : pd.DataFrame
DataFrame with possible sources and targets, their metadata, and distance.
start_with_first_stage : bool, default=True
If True, initiate trajectories with first stage as sources.
If False, initiate trajectories with last stage as sources.
link_next : int, default=5
Number of patients at next stage to randomly link to each patient of current stage.
transitions_possible : list, optional
List of transitions to process (e.g., ['1_to_2', '2_to_3']).
If None, defaults to ['early_to_late'].
Returns
-------
pd.DataFrame
DataFrame with randomly sampled patient links for each transition.
Contains multiple rows per source patient (up to link_next).
Notes
-----
- Random sampling is primarily performed for WHITE race patients due to sample size
- If fewer than link_next targets are available, all available targets are selected
- Patients from other races are included with all their possible connections
- Empty DataFrame is returned if no WHITE patients are found
"""
# Set default transitions if not provided
if transitions_possible is None:
transitions_possible = ['early_to_late']
# Get unique genders and races
unique_genders = results_df['source_gender'].unique().tolist()
# Get unique races
unique_races = results_df['source_race'].unique().tolist()
if 'WHITE' in unique_races:
unique_races.remove('WHITE')
    # Randomly sample target patients for each transition
    samples = []
for transition_i in transitions_possible:
transition_df_i = results_df[results_df['transition'] == transition_i]
for gender_i in unique_genders:
df_samples_i = transition_df_i.query(
f"source_gender == '{gender_i}' & source_race == 'WHITE'") # we can only do this for the whites since these are the only ones with enough samples
if df_samples_i.empty:
print(f"Warning: No WHITE patients found for gender {gender_i} in transition {transition_i}")
continue
unique_sources_i = np.unique(df_samples_i['source']).tolist()
unique_targets_i = np.unique(df_samples_i['target']).tolist()
use_uniques = unique_sources_i if start_with_first_stage else unique_targets_i
use_source_target = 'source' if start_with_first_stage else 'target'
for pat_i in use_uniques:
sample_i = df_samples_i.loc[df_samples_i[use_source_target] == pat_i]
if len(sample_i) >= link_next:
sample_i = sample_i.sample(
link_next) # Sample a number of patients at next stage to link to each patient of current stage
else:
sample_i = sample_i.sample(len(sample_i)) # Sample all available patients if less than link_next
samples.append(sample_i)
# Check if samples list is empty
if not samples:
print("Warning: No samples found for WHITE race. Returning empty DataFrame.")
return pd.DataFrame(columns=results_df.columns)
# Turn samples into dataframe:
samples_df = pd.concat(samples)
# Add the rest of the races
if unique_races:
samples_df = pd.concat(
[
samples_df,
results_df[results_df['source_race'].isin(unique_races)]
]
)
samples_df.reset_index(drop=True, inplace=True)
return samples_df
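Example Usage (a minimal sketch; all_transitions comes from calculate_all_possible_transitions() as in the previous example):
from renalprog.modeling.predict import link_patients_random
links = link_patients_random(
    results_df=all_transitions,
    start_with_first_stage=True,
    link_next=5,
    transitions_possible=["early_to_late"],
)
print(f"Sampled {len(links)} patient links")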
Dynamic Analysis
dynamic_enrichment_analysis
Perform pathway enrichment along trajectory timepoints.
dynamic_enrichment_analysis
dynamic_enrichment_analysis(
trajectory_dir: Path,
pathways_file: Path,
output_dir: Path,
cancer_type: str = "kirc",
) -> pd.DataFrame
Perform dynamic enrichment analysis on synthetic trajectories.
This orchestrates:
1. DESeq2 analysis on each trajectory point
2. GSEA on differential expression results
3. Aggregation of enrichment across trajectories
Args:
    trajectory_dir: Directory containing trajectory CSV files
    pathways_file: Path to pathway GMT file
    output_dir: Directory to save results
    cancer_type: Cancer type identifier
Returns:
    DataFrame with aggregated enrichment results
Source code in renalprog/modeling/predict.py
def dynamic_enrichment_analysis(
trajectory_dir: Path,
pathways_file: Path,
output_dir: Path,
cancer_type: str = "kirc"
) -> pd.DataFrame:
"""
Perform dynamic enrichment analysis on synthetic trajectories.
This orchestrates:
1. DESeq2 analysis on each trajectory point
2. GSEA on differential expression results
3. Aggregation of enrichment across trajectories
Args:
trajectory_dir: Directory containing trajectory CSV files
pathways_file: Path to pathway GMT file
output_dir: Directory to save results
cancer_type: Cancer type identifier
Returns:
DataFrame with aggregated enrichment results
"""
logger.info(f"Running dynamic enrichment analysis for {cancer_type}")
# TODO: Implement orchestration
# Migrate from src_deseq_and_gsea_NCSR/full_bash.sh and related scripts
raise NotImplementedError(
"dynamic_enrichment_analysis() needs implementation. "
"Migrate orchestration from src_deseq_and_gsea_NCSR/full_bash.sh, "
"py_deseq.py, and trajectory_analysis.py"
)
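Intended Usage (a sketch of the documented interface with placeholder paths; the function currently raises NotImplementedError):
from pathlib import Path
from renalprog.modeling.predict import dynamic_enrichment_analysis
enrichment = dynamic_enrichment_analysis(
    trajectory_dir=Path("data/processed/trajectories"),
    pathways_file=Path("data/external/pathways.gmt"),
    output_dir=Path("reports/enrichment"),
    cancer_type="kirc",
)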
calculate_all_possible_transitions
Calculate all possible patient-to-patient transitions.
calculate_all_possible_transitions
calculate_all_possible_transitions(
data: DataFrame,
metadata_selection: DataFrame,
distance: str = "wasserstein",
early_late: bool = False,
negative_trajectory: bool = False,
) -> pd.DataFrame
Calculate all possible patient-to-patient transitions for KIRC cancer.
This function computes pairwise distances between all patients at consecutive (or same) cancer stages, considering metadata constraints. Only patients with matching gender and race are considered as potential trajectory pairs.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | DataFrame | Gene expression data with patients as columns. | required |
| metadata_selection | DataFrame | Clinical metadata with columns: histological_type, race, gender, stage. | required |
| distance | {'wasserstein', 'euclidean'} | Distance metric to use for calculating patient-to-patient distances. | 'wasserstein' |
| early_late | bool | If True, uses early/late stage groupings. If False, uses I-IV stages. | False |
| negative_trajectory | bool | If True, generates same-stage transitions (negative controls). If False, generates progression transitions (positive trajectories). | False |
Returns:
| Type | Description |
| --- | --- |
| DataFrame | DataFrame containing all possible transitions with columns: source, target (patient IDs); source_gender, target_gender; source_race, target_race; transition (stage transition label, e.g., '1_to_2', 'early_to_late'); distance (calculated distance between patients). Sorted by gender, race, transition, and distance. |
Raises:
| Type | Description |
| --- | --- |
| ValueError | If distance metric is not 'wasserstein' or 'euclidean'. |
Notes
- For positive trajectories: links I→II, II→III, III→IV or early→late
- For negative trajectories: links I→I, II→II, III→III, IV→IV or early→early, late→late
- Only patients with identical gender and race are paired
Source code in renalprog/modeling/predict.py
def calculate_all_possible_transitions(
data: pd.DataFrame,
metadata_selection: pd.DataFrame,
distance: str = 'wasserstein',
early_late: bool = False,
negative_trajectory: bool = False
) -> pd.DataFrame:
"""
Calculate all possible patient-to-patient transitions for KIRC cancer.
This function computes pairwise distances between all patients at consecutive
(or same) cancer stages, considering metadata constraints. Only patients with
matching gender and race are considered as potential trajectory pairs.
Parameters
----------
data : pd.DataFrame
Gene expression data with patients as columns.
metadata_selection : pd.DataFrame
Clinical metadata with columns: histological_type, race, gender, stage.
distance : {'wasserstein', 'euclidean'}, default='wasserstein'
Distance metric to use for calculating patient-to-patient distances.
early_late : bool, default=False
If True, uses early/late stage groupings. If False, uses I-IV stages.
negative_trajectory : bool, default=False
If True, generates same-stage transitions (negative controls).
If False, generates progression transitions (positive trajectories).
Returns
-------
pd.DataFrame
DataFrame containing all possible transitions with columns:
- source, target: Patient IDs
- source_gender, target_gender: Gender
- source_race, target_race: Race
- transition: Stage transition label (e.g., '1_to_2', 'early_to_late')
- distance: Calculated distance between patients
Sorted by gender, race, transition, and distance.
Raises
------
ValueError
If distance metric is not 'wasserstein' or 'euclidean'.
Notes
-----
- For positive trajectories: links I→II, II→III, III→IV or early→late
- For negative trajectories: links I→I, II→II, III→III, IV→IV or early→early, late→late
- Only patients with identical gender and race are paired
"""
# Select distance function
if distance == 'wasserstein':
from scipy.stats import wasserstein_distance
distance_fun = wasserstein_distance
elif distance == 'euclidean':
from scipy.spatial.distance import euclidean
distance_fun = euclidean
else:
raise ValueError('Distance function not implemented. Use either "wasserstein" or "euclidean".')
# Define stage transitions based on parameters
if early_late and not negative_trajectory:
possible_transitions = ['early_to_late']
stage_pairs = [['early', 'late']]
elif early_late and negative_trajectory:
possible_transitions = ['early_to_early', 'late_to_late']
stage_pairs = [['early', 'early'], ['late', 'late']]
elif not early_late and not negative_trajectory:
possible_transitions = ['1_to_2', '2_to_3', '3_to_4']
stage_pairs = [['I', 'II'], ['II', 'III'], ['III', 'IV']]
elif not early_late and negative_trajectory:
possible_transitions = ['1_to_1', '2_to_2', '3_to_3', '4_to_4']
stage_pairs = [['I', 'I'], ['II', 'II'], ['III', 'III'], ['IV', 'IV']]
# Calculate all possible transitions
results = []
for i_tr, transition in enumerate(possible_transitions):
source_target_stage = stage_pairs[i_tr]
# Iterate through all patient pairs at specified stages
for pat_i in metadata_selection.index[metadata_selection['stage'] == source_target_stage[0]]:
for pat_ii in metadata_selection.index[metadata_selection['stage'] == source_target_stage[1]]:
# Extract metadata for both patients
source_gender = metadata_selection.at[pat_i, 'gender']
target_gender = metadata_selection.at[pat_ii, 'gender']
source_race = metadata_selection.at[pat_i, 'race']
target_race = metadata_selection.at[pat_ii, 'race']
# Skip if metadata doesn't match (gender and race must match)
if not (source_race == target_race and source_gender == target_gender):
continue
# Store transition information
results_i = {
'source': pat_i,
'target': pat_ii,
'source_gender': source_gender,
'target_gender': target_gender,
'source_race': source_race,
'target_race': target_race,
'transition': transition,
'distance': distance_fun(data[pat_i], data[pat_ii]),
}
results.append(results_i)
# Convert to DataFrame and sort
results_df = pd.DataFrame(results)
results_df.sort_values(
['source_gender', 'target_gender', 'source_race', 'target_race',
'transition', 'distance'],
inplace=True,
ignore_index=True
)
return results_df
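Example Usage (a minimal sketch; expression_data has patients as columns and clinical_metadata is indexed by patient ID, both placeholders):
from renalprog.modeling.predict import calculate_all_possible_transitions
transitions = calculate_all_possible_transitions(
    data=expression_data,
    metadata_selection=clinical_metadata,
    distance="wasserstein",
    early_late=True,
    negative_trajectory=False,
)
print(f"Found {len(transitions)} candidate patient pairs")
print(transitions[["source", "target", "transition", "distance"]].head())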
get_metadata
Extract metadata from a test dataset CSV in SDMetrics format.
get_metadata
get_metadata(test_path: Path) -> Dict[str, Any]
Extract metadata from test dataset in SDMetrics format.
This function loads a CSV file and extracts its column metadata using SDMetrics' automatic detection. The metadata describes the structure and data types of the dataset, which is required for SDMetrics quality evaluation.
Args:
    test_path: Path to the CSV file containing the test dataset. Can be a string or Path object.
Returns:
    Dictionary containing metadata with column names and data types.
    Format: {'columns': {col_name: {'sdtype': type}}}
Note:
    - The CSV is loaded with index_col=0 to avoid treating the index as a feature
    - Both real and synthetic data must share the same metadata structure
    - This ensures SDMetrics can properly validate and compare the datasets
Example:
    >>> metadata = get_metadata("data/X_test.csv")
    >>> print(metadata['columns'].keys())
    dict_keys(['gene1', 'gene2', ...])
Source code in renalprog/modeling/predict.py
def get_metadata(
"""
Extract metadata from test dataset in SDMetrics format.
This function loads a CSV file and extracts its column metadata using SDMetrics'
automatic detection. The metadata describes the structure and data types of
the dataset, which is required for SDMetrics quality evaluation.
Args:
test_path: Path to the CSV file containing the test dataset.
Can be a string or Path object.
Returns:
Dictionary containing metadata with column names and data types.
Format: {'columns': {col_name: {'sdtype': type}}}
Note:
- The CSV is loaded with index_col=0 to avoid treating the index as a feature
- Both real and synthetic data must share the same metadata structure
- This ensures SDMetrics can properly validate and compare the datasets
Example:
>>> metadata = get_metadata("data/X_test.csv")
>>> print(metadata['columns'].keys())
dict_keys(['gene1', 'gene2', ...])
"""
from pathlib import Path as pathlib_Path
logger.info("Extracting metadata for SDMetrics evaluation")
logger.info(f"Loading data from: {test_path}")
# Convert to Path object for consistent handling across platforms
test_path = pathlib_Path(test_path)
# Load data using SDMetrics CSV handler
# CRITICAL: index_col=0 prevents the index from being treated as a feature column
# This would cause metadata mismatch errors if the index is included
connector = CSVHandler()
real_data = connector.read(
folder_name=str(test_path.parent),
file_names=[test_path.name],
read_csv_parameters={
'index_col': 0, # Use first column as index, not as feature
'parse_dates': False, # Don't parse dates (all numeric gene expression)
'encoding': 'latin-1' # Standard encoding for CSV files
}
)
# Auto-detect metadata from the loaded data
metadata = Metadata.detect_from_dataframes(data=real_data)
    # Extract table-specific metadata (removes wrapper structure)
    # The table key matches the filename without extension (e.g., 'X_test')
    metadata_use = metadata.to_dict()['tables'][test_path.stem]
logger.info(f"Extracted metadata for {len(metadata_use['columns'])} genes")
return metadata_use
Complete Example
import torch
import pandas as pd
from pathlib import Path
from renalprog.modeling.train import VAE
from renalprog.modeling.predict import (
    apply_vae,
    create_patient_connections,
    generate_trajectories,
    evaluate_reconstruction
)
from renalprog.plots import plot_latent_space
# Load model and data
model = VAE(input_dim=20000, mid_dim=1024, features=128)
model.load_state_dict(torch.load("models/my_vae/best_model.pt"))
train_expr = pd.read_csv("data/interim/split/train_expression.tsv", sep="\t", index_col=0)
test_expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)
clinical = pd.read_csv("data/interim/split/test_clinical.tsv", sep="\t", index_col=0)
# Encode data (apply_vae returns numpy arrays: reconstruction, mu, logvar, z)
train_recon, train_mu, train_logvar, train_z = apply_vae(model, train_expr, device="cuda")
test_recon, test_mu, test_logvar, test_z = apply_vae(model, test_expr, device="cuda")
# Visualize latent space
plot_latent_space(
    latent=test_z,
    labels=clinical["stage"],
    output_path=Path("reports/figures/latent_space.png")
)
# Create patient connections (currently raises NotImplementedError)
connections = create_patient_connections(
    data=test_expr,
    clinical=clinical["stage"],
    method="random",
    transition_type="early_to_late"
)
# Generate trajectories (currently raises NotImplementedError)
early_mask = clinical["stage"] == "early"
late_mask = clinical["stage"] == "late"
trajectories = generate_trajectories(
    model=model,
    source_samples=test_expr[early_mask],
    target_samples=test_expr[late_mask],
    n_steps=50,
    method="spherical"
)
# Evaluate reconstruction quality
reconstructed = pd.DataFrame(test_recon, index=test_expr.index, columns=test_expr.columns)
ba_scores, ks_scores = evaluate_reconstruction(
    real_data=test_expr,
    synthetic_data=reconstructed,
    save_path_data=Path("reports/reconstruction"),
    save_path_figures=Path("reports/figures/reconstruction"),
    metadata_path=Path("data/X_test.csv")  # test-set CSV used for metadata extraction
)
print(f"Generated {len(trajectories)} trajectories")
print(f"Mean Boundary Adherence: {ba_scores.mean():.4f}")
print(f"Mean KS Complement: {ks_scores.mean():.4f}")
See Also