Plots API
Visualization functions for gene expression analysis, model training, and results presentation.
Overview
The plots module provides publication-quality visualization for:
- Training history and loss curves
- Latent space representations
- Gene expression heatmaps
- Trajectories and pathways
- Confusion matrices
- Enrichment results
Core Plotting Functions
save_plot
Utility function for saving plots with consistent formatting.
```python
save_plot(
    fig: Figure,
    save_path: Union[str, Path],
    formats: List[str] = ["html", "png", "pdf", "svg"],
    width: int = DEFAULT_WIDTH,
    height: int = DEFAULT_HEIGHT,
) -> None
```
Save a Plotly figure in multiple formats.

Args:
- fig: Plotly figure object.
- save_path: Base path for saving (without extension).
- formats: List of formats to save, from ['html', 'png', 'pdf', 'svg'].
- width: Width in pixels for static formats.
- height: Height in pixels for static formats.
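Example:
```python
from pathlib import Path

import plotly.graph_objects as go

from renalprog.plots import save_plot

# save_plot expects a Plotly figure, not a Matplotlib one
fig = go.Figure(go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))
fig.update_layout(title="My Plot")

# save_path is a base path without extension; one file is written per format.
# Static export (png, pdf, svg) in Plotly typically requires the kaleido package.
save_plot(
    fig=fig,
    save_path=Path("reports/figures/my_plot"),
    formats=["html", "png"],
)
```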
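Because save_path is a base path and formats lists the outputs, each call produces one file per format. The sketch below illustrates that naming scheme with a hypothetical helper, expected_output_paths, which is not part of renalprog; it assumes each format name becomes the file suffix of the base path.

```python
from pathlib import Path
from typing import List, Sequence, Union

# Formats listed as save_plot's documented default
DEFAULT_FORMATS = ("html", "png", "pdf", "svg")


def expected_output_paths(
    save_path: Union[str, Path], formats: Sequence[str] = DEFAULT_FORMATS
) -> List[Path]:
    """Hypothetical helper: list the files a multi-format save would write.

    Assumes each format name is appended as the suffix of a base path
    that carries no extension of its own.
    """
    base = Path(save_path)
    if base.suffix:
        raise ValueError(f"save_path should not include an extension: {base}")
    return [base.with_suffix(f".{fmt}") for fmt in formats]


for path in expected_output_paths("reports/figures/my_plot", ["html", "png"]):
    print(path)
```

Rejecting a path that already ends in an extension catches the most likely mistake, passing something like "my_plot.png" and ending up with doubled suffixes.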
Training Visualization
plot_training_history
Visualize VAE training progress.
```python
plot_training_history(
    history: Dict[str, List[float]],
    save_path: Optional[Path] = None,
    title: str = "Training History",
    log_scale: bool = False,
) -> go.Figure
```
Plot training and validation losses over epochs.
Args:
- history: Dictionary with 'train_loss' and 'val_loss' keys.
- save_path: Optional path to save the figure.
- title: Plot title.
- log_scale: Whether to use a log scale for the y-axis.
Returns: Plotly Figure object
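Example:
```python
from pathlib import Path

from renalprog.modeling.train import train_vae
from renalprog.plots import plot_training_history

# After training (arguments elided)
history, model, checkpoints = train_vae(...)

# Plot training curves; history holds per-epoch 'train_loss' and 'val_loss' lists
plot_training_history(
    history=history,
    save_path=Path("reports/figures/training_history"),  # assumed base path, no extension
    title="VAE Training Progress",
)
```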
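plot_training_history expects history to map 'train_loss' and 'val_loss' to per-epoch values. A check like the hypothetical check_history below (not a renalprog function) can catch mismatched lengths before plotting, and it flags non-positive losses when log_scale=True, since a logarithmic axis cannot display zero or negative values.

```python
from typing import Dict, List


def check_history(history: Dict[str, List[float]], log_scale: bool = False) -> int:
    """Hypothetical helper: validate a training history and return its epoch count.

    Assumes one loss value per epoch under each key.
    """
    for key in ("train_loss", "val_loss"):
        if key not in history:
            raise KeyError(f"history is missing '{key}'")
    n_epochs = len(history["train_loss"])
    if len(history["val_loss"]) != n_epochs:
        raise ValueError("train_loss and val_loss must have the same length")
    if log_scale:
        # list() guards against NumPy arrays, where + would add elementwise
        values = list(history["train_loss"]) + list(history["val_loss"])
        if any(v <= 0 for v in values):
            raise ValueError("log_scale=True requires strictly positive losses")
    return n_epochs


history = {"train_loss": [1.2, 0.8, 0.6], "val_loss": [1.3, 0.9, 0.7]}
print(check_history(history, log_scale=True))  # 3
```

VAE losses that include a Gaussian likelihood term can become negative, which is when this check matters most.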
plot_reconstruction_losses
Compare training and test losses of the reconstruction network.
```python
plot_reconstruction_losses(
    loss_train: List[float],
    loss_test: List[float],
    save_path: Optional[Path] = None,
    title: str = "Reconstruction Network Losses",
) -> go.Figure
```
Plot training and test losses for reconstruction network.
Args:
- loss_train: List of training losses.
- loss_test: List of test losses.
- save_path: Optional path to save the figure.
- title: Plot title.
Returns: Plotly Figure object
plot_umap_plotly
Interactive UMAP visualization.
```python
plot_umap_plotly(
    data,
    clinical,
    colors_dict,
    shapes_dict=None,
    n_components=2,
    save_fig=False,
    save_as=None,
    seed=None,
    title="UMAP",
    show=True,
    marker_size=8,
)
```
Plot a UMAP embedding of the data with Plotly, coloring samples by group.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | Features as rows and samples as columns (same as in plot_umap). | required |
| clinical | Series | Category per sample (index must match data.columns). | required |
| colors_dict | dict | Mapping {group_name: color_hex_or_name}. | required |
| shapes_dict | dict or None | Mapping {group_name: shape}. | None |
| n_components | int | Number of UMAP dimensions, 2 or 3. | 2 |
| save_fig | bool | If True, save HTML/PNG/PDF/SVG. | False |
| save_as | str or None | Base path (without extension) for saving. | None |
| seed | int or None | Random state for UMAP. | None |
| title | str | Plot title. | 'UMAP' |
| show | bool | If True, display the plot. | True |
| marker_size | int | Size of the scatter markers. | 8 |
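Example:
```python
from renalprog.plots import plot_umap_plotly

# latent: DataFrame with latent features as rows and samples as columns
# clinical: DataFrame indexed by sample ID with a 'stage' column
plot_umap_plotly(
    data=latent,
    clinical=clinical["stage"],
    colors_dict={"early": "#1f77b4", "late": "#d62728"},
    title="Interactive Latent Space",
    seed=42,
    # Save as HTML/PNG/PDF/SVG under this base path (no extension)
    save_fig=True,
    save_as="reports/figures/latent_space_interactive",
    show=False,
)
```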
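The parameter table asks for data with features as rows and samples as columns, and for clinical to share the same sample IDs, whereas UMAP estimators such as umap-learn's UMAP take samples as rows. The hypothetical helper align_umap_inputs below sketches the alignment step such a function needs: transpose data, reorder clinical to match, and map each group to its color. It is an assumption about the preprocessing, not renalprog's code.

```python
import pandas as pd


def align_umap_inputs(data: pd.DataFrame, clinical: pd.Series, colors_dict: dict):
    """Hypothetical helper: return (X, groups, colors) in matching sample order.

    data has features as rows and samples as columns; X has samples as rows.
    """
    missing = set(data.columns) - set(clinical.index)
    if missing:
        raise ValueError(f"Samples without a clinical label: {sorted(missing)}")
    X = data.T  # samples as rows, features as columns
    groups = clinical.loc[X.index]  # reorder labels to follow X
    unknown = set(groups) - set(colors_dict)
    if unknown:
        raise ValueError(f"Groups without a color: {sorted(unknown)}")
    colors = groups.map(colors_dict)  # one color per sample
    return X, groups, colors


data = pd.DataFrame(
    [[0.1, 0.5, 0.9], [1.0, 0.2, 0.4]],
    index=["GENE_A", "GENE_B"],
    columns=["s1", "s2", "s3"],
)
clinical = pd.Series({"s3": "late", "s1": "early", "s2": "early"})
X, groups, colors = align_umap_inputs(
    data, clinical, {"early": "#1f77b4", "late": "#d62728"}
)
print(X.shape, list(colors))
```

Reordering with .loc makes the result independent of the order in which clinical labels were loaded.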
Trajectory Visualization
plot_trajectory
Visualize disease progression trajectory.
```python
plot_trajectory(
    trajectory: ndarray,
    gene_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
    title: str = "Gene Expression Trajectory",
    n_genes_to_show: int = 20,
) -> go.Figure
```
Plot gene expression changes along a trajectory.
Args:
- trajectory: Array of shape (n_timepoints, n_genes).
- gene_names: Optional list of gene names.
- save_path: Optional path to save the figure.
- title: Plot title.
- n_genes_to_show: Number of top varying genes to display.
Returns: Plotly Figure object
Example:
```python
from pathlib import Path

from renalprog.plots import plot_trajectory

# trajectories[0] has shape (n_steps, n_genes);
# selected_genes holds one name per gene column
plot_trajectory(
    trajectory=trajectories[0],
    gene_names=selected_genes,
    save_path=Path("reports/figures/trajectory_001"),
    title="Disease Progression Trajectory",
    n_genes_to_show=20,
)
```
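plot_trajectory displays only the n_genes_to_show "top varying" genes, but the docstring does not define how genes are ranked. The sketch below assumes variance across timepoints as the criterion. top_varying_genes is a hypothetical NumPy helper, not renalprog's implementation.

```python
from typing import List, Optional, Tuple

import numpy as np


def top_varying_genes(
    trajectory: np.ndarray,
    gene_names: Optional[List[str]] = None,
    n_genes_to_show: int = 20,
) -> Tuple[np.ndarray, List[str]]:
    """Hypothetical helper: keep the genes that vary most along a trajectory.

    trajectory has shape (n_timepoints, n_genes). Ranking by variance
    over timepoints is an assumption made for this sketch.
    """
    n_genes = trajectory.shape[1]
    if gene_names is None:
        gene_names = [f"gene_{i}" for i in range(n_genes)]
    variances = trajectory.var(axis=0)  # one variance per gene column
    # argsort is ascending, so reverse it and keep the first n_genes_to_show
    order = np.argsort(variances)[::-1][:n_genes_to_show]
    return trajectory[:, order], [gene_names[i] for i in order]


traj = np.array([[0.0, 1.0, 5.0], [0.0, 2.0, 1.0], [0.0, 3.0, 9.0]])
subset, names = top_varying_genes(traj, ["A", "B", "C"], n_genes_to_show=2)
print(names)  # ['C', 'B']
```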
PCA Visualization
plot_pca_variance
Visualize PCA variance explained.
```python
plot_pca_variance(
    explained_variance_ratio: ndarray,
    save_path: Optional[Path] = None,
    title: str = "PCA Explained Variance",
    n_components: int = 20,
) -> go.Figure
```
Plot explained variance ratio from PCA.
Args:
- explained_variance_ratio: Array of explained variance ratios.
- save_path: Optional path to save the figure.
- title: Plot title.
- n_components: Number of components to show.
Returns: Plotly Figure object
Example:
```python
from pathlib import Path

from sklearn.decomposition import PCA

from renalprog.plots import plot_pca_variance

# expression_data: samples as rows, genes as columns
pca = PCA(n_components=50)
pca.fit(expression_data)

# Pass the ratio array, not the PCA object
plot_pca_variance(
    explained_variance_ratio=pca.explained_variance_ratio_,
    save_path=Path("reports/figures/pca_variance"),
    n_components=20,
)
```
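The explained_variance_ratio_ array passed to plot_pca_variance is often read alongside its cumulative sum, which shows how many leading components are needed to reach a variance target. The sketch below computes that with NumPy; the helper name n_components_for and the 90% target are illustrative choices, not part of renalprog.

```python
import numpy as np


def n_components_for(explained_variance_ratio: np.ndarray, target: float = 0.9) -> int:
    """Smallest number of leading components whose ratios sum to at least target."""
    cumulative = np.cumsum(explained_variance_ratio)
    if cumulative[-1] < target:
        raise ValueError("target exceeds the total explained variance")
    # searchsorted (side='left') returns the first index with cumulative >= target
    return int(np.searchsorted(cumulative, target) + 1)


ratios = np.array([0.5, 0.25, 0.125, 0.0625])
print(n_components_for(ratios, 0.9))  # 4
```

With a PCA fitted on only 50 components, the ratios may sum to less than 1, so a high target can legitimately raise the error above.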
Complete Visualization Workflow
```python
from pathlib import Path

import pandas as pd
import torch

from renalprog.modeling.predict import apply_vae, generate_trajectories
from renalprog.modeling.train import train_vae
from renalprog.plots import plot_training_history, plot_trajectory, plot_umap_plotly

# Create output directory
output_dir = Path("reports/figures")
output_dir.mkdir(parents=True, exist_ok=True)

# Use the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load data (samples as rows, genes as columns)
train_expr = pd.read_csv("data/interim/split/train_expression.tsv", sep="\t", index_col=0)
test_expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)
clinical = pd.read_csv("data/interim/split/test_clinical.tsv", sep="\t", index_col=0)
clinical = clinical.loc[test_expr.index]  # align clinical rows with expression rows

# 2. Train model and plot history
history, model, checkpoints = train_vae(
    train_data=train_expr.values,
    val_data=test_expr.values,
    input_dim=train_expr.shape[1],
    mid_dim=1024,
    features=128,
    output_dir=Path("models/my_vae"),
    n_epochs=100,
)
# Save paths below omit the extension, assuming they are treated as base paths
plot_training_history(
    history=history,
    save_path=output_dir / "training_history",
)

# 3. Encode to latent space and visualize
results = apply_vae(model, test_expr.values, device=device)
# plot_umap_plotly expects features as rows and samples as columns
latent = pd.DataFrame(results["latent"], index=test_expr.index).T
plot_umap_plotly(
    data=latent,
    clinical=clinical["stage"],
    colors_dict={"early": "#1f77b4", "late": "#d62728"},
    title="Interactive Latent Space",
    save_fig=True,
    save_as=str(output_dir / "latent_space_interactive"),
    show=False,
)

# 4. Generate and plot trajectories
early_mask = (clinical["stage"] == "early").to_numpy()
late_mask = (clinical["stage"] == "late").to_numpy()
trajectories = generate_trajectories(
    model=model,
    start_data=test_expr.values[early_mask],
    end_data=test_expr.values[late_mask],
    n_steps=50,
    device=device,
)

# Plot the first trajectory; gene names come from the expression columns
plot_trajectory(
    trajectory=trajectories[0],
    gene_names=test_expr.columns.tolist(),
    save_path=output_dir / "trajectory_001",
    title="Disease Progression Trajectory",
)

print(f"All figures saved to {output_dir}")
```
See Also
- Training API - Generate training history
- Prediction API - Generate predictions to plot
- Trajectories API - Generate trajectories
- Complete Pipeline Tutorial