Trajectories API
Functions for analyzing disease progression trajectories.
Overview
The trajectories module provides analysis tools for:
- Trajectory network construction
- Patient connectivity analysis
- Temporal pathway enrichment
- Transition probability calculation
- Trajectory visualization
Trajectory Generation
generate_trajectories
Generate smooth disease progression trajectories.
generate_trajectories(
    model: Module,
    source_samples: DataFrame,
    target_samples: DataFrame,
    n_steps: int = 50,
    method: str = "linear",
    output_dir: Optional[Path] = None,
    parallel: bool = False,
    n_workers: Optional[int] = None,
) -> Dict[str, pd.DataFrame]
Generate synthetic trajectories between source and target patient samples.
This function creates interpolated gene expression profiles in the latent space between pairs of patients at different cancer stages.
Args:
- model: Trained VAE model
- source_samples: Source patient samples (early stage)
- target_samples: Target patient samples (late stage)
- n_steps: Number of interpolation steps
- method: Interpolation method ("linear" or "spherical")
- output_dir: Optional directory to save trajectories
- parallel: Whether to use parallel processing
- n_workers: Number of parallel workers (None = use all CPUs)
Returns: Dictionary mapping patient pairs to trajectory DataFrames
Note: This function is not yet implemented; calling it raises NotImplementedError.
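Because the function body is still a stub, the following is only a sketch of the behaviour the docstring describes, not the package's implementation. It assumes both sample tables are genes × patients (the layout documented for generate_trajectory_data), replaces the VAE with hypothetical `encode` and `decode` callables, uses linear interpolation only, and keys the result with a hypothetical "{source}_to_{target}" format.

```python
from itertools import product
from typing import Callable, Dict

import numpy as np
import pandas as pd


def sketch_generate_trajectories(
    source_samples: pd.DataFrame,  # assumed layout: genes x patients
    target_samples: pd.DataFrame,  # assumed layout: genes x patients
    encode: Callable[[np.ndarray], np.ndarray],  # hypothetical: (genes,) -> (latent,)
    decode: Callable[[np.ndarray], np.ndarray],  # hypothetical: (n, latent) -> (n, genes)
    n_steps: int = 50,
) -> Dict[str, pd.DataFrame]:
    """Linearly interpolate in latent space for every source x target pair."""
    t = np.linspace(0.0, 1.0, n_steps)[:, None]  # (n_steps, 1) mixing weights
    trajectories: Dict[str, pd.DataFrame] = {}
    for src, tgt in product(source_samples.columns, target_samples.columns):
        z_src = encode(source_samples[src].to_numpy())
        z_tgt = encode(target_samples[tgt].to_numpy())
        z_path = (1.0 - t) * z_src + t * z_tgt  # (n_steps, latent)
        decoded = decode(z_path)  # (n_steps, genes)
        trajectories[f"{src}_to_{tgt}"] = pd.DataFrame(
            decoded,
            columns=source_samples.index,
            index=[f"{src}_to_{tgt}_t{i:03d}" for i in range(n_steps)],
        )
    return trajectories
```

With identity functions for `encode` and `decode`, each DataFrame is a straight line between the two raw expression profiles, which makes the output easy to inspect before plugging in a real model.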
Source code in renalprog/modeling/predict.py
def generate_trajectories(
    model: torch.nn.Module,
    source_samples: pd.DataFrame,
    target_samples: pd.DataFrame,
    n_steps: int = 50,
    method: str = "linear",
    output_dir: Optional[Path] = None,
    parallel: bool = False,
    n_workers: Optional[int] = None
) -> Dict[str, pd.DataFrame]:
    """
    Generate synthetic trajectories between source and target patient samples.
    This function creates interpolated gene expression profiles in the latent
    space between pairs of patients at different cancer stages.
    Args:
        model: Trained VAE model
        source_samples: Source patient samples (early stage)
        target_samples: Target patient samples (late stage)
        n_steps: Number of interpolation steps
        method: Interpolation method ("linear" or "spherical")
        output_dir: Optional directory to save trajectories
        parallel: Whether to use parallel processing
        n_workers: Number of parallel workers (None = use all CPUs)
    Returns:
        Dictionary mapping patient pairs to trajectory DataFrames
    """
    logger.info(f"Generating trajectories with {n_steps} steps using {method} interpolation")
    # TODO: Implement trajectory generation
    # Migrate from src_deseq_and_gsea_NCSR/synthetic_data_generation.py
    raise NotImplementedError(
        "generate_trajectories() needs implementation from "
        "src_deseq_and_gsea_NCSR/synthetic_data_generation.py and "
        "src/data/fun_interpol.py"
    )
Network Construction
build_trajectory_network
Build directed graph of patient transitions.
build_trajectory_network(
    patient_links: DataFrame,
) -> Tuple[Dict[str, List[str]], List[List[str]]]
Build trajectory network and find all complete disease progression paths.
Constructs a directed graph from patient links and identifies all possible complete trajectories from root nodes (earliest stage patients not appearing as targets) to leaf nodes (latest stage patients not appearing as sources).
Args:
- patient_links: DataFrame with 'source' and 'target' columns from linking functions
Returns: Tuple of:
- network: Dict mapping each source patient to a list of target patients
- trajectories: List of complete trajectories, where each trajectory is a list of patient IDs ordered from earliest to latest stage
Network Structure:
- Adjacency list representation: {source: [target1, target2, ...]}
- Directed edges from earlier to later stages
- Allows multiple outgoing edges (one patient → multiple next-stage patients)
- Only source patients have an entry, so len(network) counts source nodes, not leaf patients
Trajectory Discovery:
- Uses depth-first search from root nodes
- Root nodes: Patients in 'source' but not in 'target' (stage I or early)
- Leaf nodes: Patients in 'target' but not in 'source' (stage IV or late)
- Each trajectory represents a complete disease progression path
- If no root nodes exist, each source→target pair is returned as a 2-patient trajectory
Example:
>>> network, trajectories = build_trajectory_network(patient_links)
>>> print(f"Network has {len(network)} nodes")
>>> print(f"Found {len(trajectories)} complete trajectories")
>>> print(f"Example trajectory: {trajectories[0]}")
Network has 500 nodes
Found 234 complete trajectories
Example trajectory: ['PAT001', 'PAT045', 'PAT123', 'PAT289']
Trajectory Characteristics:
- Length varies based on how many stages the path spans
- Typical lengths: 2-4 patients for I→II→III→IV progressions
- Length 2 for early→late progressions
- Patients can appear in multiple trajectories
Note:
- Cycles are prevented during trajectory search
- All paths from root to leaf are enumerated
- Trajectories respect chronological disease progression
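The root-to-leaf search described above can be reproduced outside the package with a short, self-contained function. The name `enumerate_paths` is hypothetical, and the sketch swaps the recursive search used in the source for an explicit stack; it also omits the fallback for link tables without root nodes.

```python
from typing import Dict, List

import pandas as pd


def enumerate_paths(links: pd.DataFrame) -> List[List[str]]:
    """Enumerate every root-to-leaf path in a source -> target edge list."""
    network: Dict[str, List[str]] = {}
    for src, tgt in zip(links["source"], links["target"]):
        network.setdefault(src, []).append(tgt)

    # Roots appear as a source but never as a target
    roots = sorted(set(links["source"]) - set(links["target"]))

    paths: List[List[str]] = []
    for root in roots:
        stack = [[root]]  # each entry is the path walked so far
        while stack:
            path = stack.pop()
            node = path[-1]
            if node not in network:
                paths.append(path)  # leaf: no outgoing edges
                continue
            # Skip children already on the path to prevent cycles
            children = [c for c in network[node] if c not in path]
            for child in reversed(children):  # reversed keeps input order on pop
                stack.append(path + [child])
    return paths
```

On a small diamond-shaped graph, both branches from the single root are returned as separate trajectories, mirroring the "patients can appear in multiple trajectories" behaviour noted above.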
Source code in renalprog/modeling/predict.py
def build_trajectory_network(
    patient_links: pd.DataFrame
) -> Tuple[Dict[str, List[str]], List[List[str]]]:
    """
    Build trajectory network and find all complete disease progression paths.
    Constructs a directed graph from patient links and identifies all possible
    complete trajectories from root nodes (earliest stage patients not appearing
    as targets) to leaf nodes (latest stage patients not appearing as sources).
    Args:
        patient_links: DataFrame with 'source' and 'target' columns from linking functions
    Returns:
        Tuple of:
        - network: Dict mapping each source patient to list of target patients
        - trajectories: List of complete trajectories, where each trajectory is a
          list of patient IDs ordered from earliest to latest stage
    Network Structure:
        - Adjacency list representation: {source: [target1, target2, ...]}
        - Directed edges from earlier to later stages
        - Allows multiple outgoing edges (one patient → multiple next-stage patients)
    Trajectory Discovery:
        - Uses depth-first search from root nodes
        - Root nodes: Patients in 'source' but not in 'target' (stage I or early)
        - Leaf nodes: Patients in 'target' but not in 'source' (stage IV or late)
        - Each trajectory represents a complete disease progression path
    Example:
        >>> network, trajectories = build_trajectory_network(patient_links)
        >>> print(f"Network has {len(network)} nodes")
        >>> print(f"Found {len(trajectories)} complete trajectories")
        >>> print(f"Example trajectory: {trajectories[0]}")
        Network has 500 nodes
        Found 234 complete trajectories
        Example trajectory: ['PAT001', 'PAT045', 'PAT123', 'PAT289']
    Trajectory Characteristics:
        - Length varies based on how many stages the path spans
        - Typical lengths: 2-4 patients for I→II→III→IV progressions
        - Length 2 for early→late progressions
        - Patients can appear in multiple trajectories
    Note:
        - Cycles are prevented during trajectory search
        - All paths from root to leaf are enumerated
        - Trajectories respect chronological disease progression
    """
    logger.info("Building trajectory network from patient links")
    sources = patient_links['source']
    targets = patient_links['target']
    # Build network adjacency list
    network = {}
    for source, target in zip(sources, targets):
        if source not in network:
            network[source] = []
        network[source].append(target)
    logger.info(f"Network built: {len(network)} source nodes")
    # Find root nodes (patients who are sources but never targets)
    unique_sources = set(sources) - set(targets)
    logger.info(f"Found {len(unique_sources)} root nodes (earliest stage patients)")
    # Recursively find all trajectories from each root
    def find_trajectories(start_node: str, visited: Optional[List[str]] = None) -> List[List[str]]:
        """Depth-first search to find all paths from start_node to leaf nodes."""
        if visited is None:
            visited = []
        visited.append(start_node)
        # If node has no outgoing edges, this is a leaf node - return path
        if start_node not in network:
            return [visited]
        # Recursively explore all targets
        trajectories = []
        for target in network[start_node]:
            if target not in visited:  # Avoid cycles
                new_visited = visited.copy()
                trajectories.extend(find_trajectories(target, new_visited))
        return trajectories
    # Find all trajectories starting from each root
    all_trajectories = []
    if len(unique_sources) == 0:
        # No clear root nodes - this happens with early→late transitions where
        # patients can be both sources and targets. In this case, each source→target
        # pair is already a complete 2-patient trajectory.
        logger.info("No root nodes found (typical for early→late transitions).")
        logger.info("Using each source→target pair as a complete trajectory.")
        for source, target in zip(sources, targets):
            all_trajectories.append([source, target])
    else:
        # Standard case: multi-stage progressions (I→II→III→IV)
        for source in unique_sources:
            all_trajectories.extend(find_trajectories(source))
    logger.info(f"Discovered {len(all_trajectories)} complete disease progression trajectories")
    # Log trajectory length statistics only if we have trajectories
    if len(all_trajectories) > 0:
        traj_lengths = [len(t) for t in all_trajectories]
        logger.info(f"Trajectory lengths - Min: {min(traj_lengths)}, Max: {max(traj_lengths)}, "
                    f"Mean: {np.mean(traj_lengths):.1f}")
    else:
        logger.warning("No trajectories found!")
    return network, all_trajectories
Example Usage:
from renalprog.modeling.predict import build_trajectory_network
import pandas as pd
# Load patient links (must contain 'source' and 'target' columns)
connections = pd.read_csv("data/processed/patient_connections.csv")
# Build the adjacency-list network and enumerate complete trajectories
network, trajectories = build_trajectory_network(patient_links=connections)
# network is a plain dict {source: [targets]}, so count edges from the lists
n_edges = sum(len(targets) for targets in network.values())
print(f"Network has {len(network)} source nodes")
print(f"Network has {n_edges} edges")
print(f"Found {len(trajectories)} complete trajectories")
generate_trajectory_data
Generate complete trajectory dataset with metadata.
generate_trajectory_data(
    vae_model: Module,
    recnet_model: Optional[Module],
    trajectory: List[str],
    gene_data: DataFrame,
    n_timepoints: int = 50,
    interpolation_method: str = "linear",
    device: str = "cpu",
    save_path: Optional[Path] = None,
    scaler: Optional[MinMaxScaler] = None,
) -> pd.DataFrame
Generate synthetic gene expression data along a patient trajectory.
Creates n_timepoints interpolated time points between each pair of consecutive patients in a trajectory by interpolating in the VAE latent space and decoding back to gene expression space. Optionally applies a reconstruction network to refine the decoded output.
Args:
- vae_model: Trained VAE model for encoding/decoding
- recnet_model: Optional reconstruction network for refining VAE output
- trajectory: List of patient IDs in chronological progression order
- gene_data: Gene expression DataFrame (genes × patients)
- n_timepoints: Number of interpolation points between each patient pair
- interpolation_method: 'linear' or 'spherical' interpolation in latent space
- device: Torch device for computation
- save_path: Optional path to save trajectory CSV file
- scaler: Pre-fitted MinMaxScaler from VAE training. If None, a new scaler is fitted on gene_data.
Returns: DataFrame with synthetic gene expression profiles for all time points, with shape (n_timepoints * (len(trajectory) - 1), n_genes). The index holds time-point identifiers of the form "{source}_to_{target}_t{t:03d}", e.g. PAT001_to_PAT045_t000.
Workflow:
1. Extract gene expression for each patient in the trajectory
2. Normalize using the SAME scaler used during VAE training
3. Encode each patient to the VAE latent space
4. For each consecutive pair:
   a. Interpolate in latent space (linear or spherical)
   b. Decode the interpolated points and de-normalize them with the scaler's inverse_transform
   c. Optionally apply the reconstruction network, which operates on de-normalized (real-space) values
5. Concatenate all segments into the complete trajectory
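To make the data flow of these steps concrete, here is a sketch in which scikit-learn's PCA stands in for the VAE. PCA is a swapped-in linear encoder/decoder, not the package's model; the function name `sketch_trajectory_data` is hypothetical, there is no RecNet step, only linear interpolation is shown, and the scaler is fitted on the given data rather than reused from training.

```python
from typing import List

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler


def sketch_trajectory_data(
    trajectory: List[str],
    gene_data: pd.DataFrame,  # genes x patients, as documented above
    n_timepoints: int = 5,
    n_components: int = 2,
) -> pd.DataFrame:
    """Workflow sketch: PCA stands in for the VAE encoder/decoder (assumption)."""
    X = gene_data.T.values  # patients x genes
    scaler = MinMaxScaler().fit(X)  # step 2: one min/max per gene
    pca = PCA(n_components=n_components).fit(scaler.transform(X))
    t = np.linspace(0.0, 1.0, n_timepoints)[:, None]  # interpolation weights
    segments, index = [], []
    for src, tgt in zip(trajectory[:-1], trajectory[1:]):
        # Steps 1 and 3: extract, normalize and encode both endpoints
        z = pca.transform(scaler.transform(gene_data[[src, tgt]].T.values))
        z_path = (1.0 - t) * z[0] + t * z[1]  # step 4a: linear interpolation
        decoded = pca.inverse_transform(z_path)  # step 4b: decode
        segments.append(scaler.inverse_transform(decoded))  # back to real scale
        index += [f"{src}_to_{tgt}_t{i:03d}" for i in range(n_timepoints)]
    # Step 5: stack into one (n_timepoints * (len(trajectory) - 1), n_genes) frame
    return pd.DataFrame(np.vstack(segments), columns=gene_data.index, index=index)
```

With n_components equal to the number of genes (and at least as many patients as genes), PCA reconstructs its input exactly, so the first and last row of each segment reproduce the two real patients. A trained VAE only approximates its inputs, so the endpoints of a real trajectory will differ somewhat from the measured profiles.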
Interpolation Methods:
- linear: Straight-line interpolation in latent space, z(t) = (1-t)*z_source + t*z_target
- spherical: Spherical linear interpolation (SLERP). Preserves magnitude and interpolates on a hypersphere; recommended for normalized latent spaces.
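The two interpolation formulas can be written in a few lines of NumPy. The helpers `lerp` and `slerp` below are hypothetical stand-ins, not the package's `interpolate_latent_linear` and `interpolate_latent_spherical`, whose bodies are not shown on this page. Note that the standard SLERP formula keeps the norm constant only when both endpoints have the same norm.

```python
import numpy as np


def lerp(z0: np.ndarray, z1: np.ndarray, n: int) -> np.ndarray:
    """Linear interpolation: z(t) = (1 - t) * z0 + t * z1 for n values of t in [0, 1]."""
    t = np.linspace(0.0, 1.0, n)[:, None]
    return (1.0 - t) * z0 + t * z1


def slerp(z0: np.ndarray, z1: np.ndarray, n: int, eps: float = 1e-7) -> np.ndarray:
    """Spherical linear interpolation along the arc between z0 and z1."""
    t = np.linspace(0.0, 1.0, n)[:, None]
    cos_omega = np.dot(z0, z1) / (np.linalg.norm(z0) * np.linalg.norm(z1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))  # angle between the vectors
    if np.sin(omega) < eps:
        # (Anti)parallel vectors: the arc is undefined, so fall back to LERP
        return lerp(z0, z1, n)
    return (np.sin((1.0 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)
```

For two orthogonal unit vectors, the linear midpoint has norm of about 0.71, while every SLERP point stays on the unit circle.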
Note: CRITICAL: The scaler must be the same one used during VAE training. Using a different scaler will produce incorrect latent representations. If scaler=None, a new scaler is fitted on all of gene_data (all patients); this only approximates the training normalization, and the function logs a warning.
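The transpose mentioned in the source (gene_data.T before fitting) matters because MinMaxScaler scales each column independently. The toy example below uses made-up values: fitting on patients × genes gives each gene its own min/max, whereas fitting on gene_data without .T would learn one range per patient instead.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

# Toy expression matrix in the documented layout: genes x patients
gene_data = pd.DataFrame(
    [[1.0, 3.0, 5.0], [10.0, 20.0, 30.0]],
    index=["G1", "G2"],
    columns=["PAT001", "PAT002", "PAT003"],
)

# Fit on patients x genes so that each gene (column) gets its own range
scaler = MinMaxScaler()
scaler.fit(gene_data.T.values)

# A single patient becomes a (1, n_genes) row before transform
one_patient = gene_data["PAT002"].values.reshape(1, -1)
scaled = scaler.transform(one_patient)  # [[0.5, 0.5]]: midway for both genes
restored = scaler.inverse_transform(scaled)  # back to [[3.0, 20.0]]
```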
Source code in renalprog/modeling/predict.py
def generate_trajectory_data(
    vae_model: torch.nn.Module,
    recnet_model: Optional[torch.nn.Module],
    trajectory: List[str],
    gene_data: pd.DataFrame,
    n_timepoints: int = 50,
    interpolation_method: str = 'linear',
    device: str = 'cpu',
    save_path: Optional[Path] = None,
    scaler: Optional[MinMaxScaler] = None
) -> pd.DataFrame:
    """
    Generate synthetic gene expression data along a patient trajectory.
    Creates N interpolated time points between consecutive patients in a trajectory
    by performing interpolation in the VAE latent space, then decoding back to
    gene expression space. Optionally applies reconstruction network for refinement.
    Args:
        vae_model: Trained VAE model for encoding/decoding
        recnet_model: Optional reconstruction network for refining VAE output
        trajectory: List of patient IDs in chronological progression order
        gene_data: Gene expression DataFrame (genes × patients)
        n_timepoints: Number of interpolation points between each patient pair
        interpolation_method: 'linear' or 'spherical' interpolation in latent space
        device: Torch device for computation
        save_path: Optional path to save trajectory CSV file
        scaler: Pre-fitted MinMaxScaler from VAE training. If None, will fit on gene_data.
    Returns:
        DataFrame with synthetic gene expression profiles for all time points.
        Shape: (n_timepoints * (len(trajectory)-1), n_genes)
        Index contains time point identifiers
    Workflow:
        1. Extract gene expression for each patient in trajectory
        2. Normalize using the SAME scaler used during VAE training
        3. Encode each patient to VAE latent space
        4. For each consecutive pair:
           a. Interpolate in latent space (linear or spherical)
           b. Decode interpolated points back to gene space
           c. Optionally apply reconstruction network
        5. Concatenate all segments into complete trajectory
    Interpolation Methods:
        linear: Straight-line interpolation in latent space
            z(t) = (1-t)*z_source + t*z_target
        spherical: Spherical linear interpolation (SLERP)
            Preserves magnitude, interpolates on hypersphere
            Recommended for normalized latent spaces
    Note:
        CRITICAL: The scaler must be the same one used during VAE training.
        Using a different scaler will produce incorrect latent representations.
        If scaler=None, will fit on all gene_data (all patients), which approximates
        the training distribution.
    """
    logger.info(f"Generating trajectory data for {len(trajectory)} patients")
    logger.info(f"Interpolation: {n_timepoints} points × {len(trajectory)-1} segments")
    logger.info(f"Method: {interpolation_method}")
    # Set models to evaluation mode
    vae_model.eval()
    if recnet_model is not None:
        recnet_model.eval()
    vae_model = vae_model.to(device)
    if recnet_model is not None:
        recnet_model = recnet_model.to(device)
    # Use provided scaler or fit new one on all gene data
    if scaler is None:
        logger.warning("No scaler provided - fitting new scaler on all gene data")
        logger.warning("This may not match VAE training normalization!")
        scaler = MinMaxScaler()
        # gene_data is (genes × patients), need (patients × genes) for scaler
        scaler.fit(gene_data.T.values)
        logger.info(f"Fitted scaler on {gene_data.shape[1]} patients")
    else:
        logger.info("Using provided scaler from VAE training")
    # Select interpolation function
    if interpolation_method == 'linear':
        interp_func = interpolate_latent_linear
    elif interpolation_method == 'spherical':
        interp_func = interpolate_latent_spherical
    else:
        raise ValueError(f"Unknown interpolation method: {interpolation_method}")
    # Generate synthetic data for each segment of the trajectory
    all_segments = []
    with torch.no_grad():
        for i in range(len(trajectory) - 1):
            source_patient = trajectory[i]
            target_patient = trajectory[i + 1]
            logger.info(f"Segment {i+1}/{len(trajectory)-1}: {source_patient} → {target_patient}")
            # Get gene expression for source and target
            # gene_data is (genes × patients), so gene_data[patient] is a Series of gene values
            source_expr = gene_data[source_patient].values.reshape(1, -1)  # (1, genes)
            target_expr = gene_data[target_patient].values.reshape(1, -1)  # (1, genes)
            # Normalize data using the provided scaler
            # Scaler expects (n_samples, n_features) = (1, genes)
            source_norm = scaler.transform(source_expr)  # (1, genes)
            target_norm = scaler.transform(target_expr)  # (1, genes)
            # Encode to latent space
            source_tensor = torch.tensor(source_norm, dtype=torch.float32).to(device)
            target_tensor = torch.tensor(target_norm, dtype=torch.float32).to(device)
            _, _, _, z_source = vae_model(source_tensor)
            _, _, _, z_target = vae_model(target_tensor)
            # Interpolate in latent space
            z_source_np = z_source.cpu().numpy().flatten()
            z_target_np = z_target.cpu().numpy().flatten()
            interpolated_z = interp_func(z_source_np, z_target_np, n_timepoints)
            # Decode interpolated latent vectors
            interpolated_z_tensor = torch.tensor(interpolated_z, dtype=torch.float32).to(device)
            decoded = vae_model.decoder(interpolated_z_tensor)
            # Denormalize using the same scaler
            # decoded is (n_timepoints, genes), scaler expects (n_samples, n_features)
            decoded_np = decoded.cpu().numpy()  # (n_timepoints, genes)
            segment_data = scaler.inverse_transform(decoded_np)  # (n_timepoints, genes) - REAL SPACE
            # Apply reconstruction network if provided
            # CRITICAL: RecNet works on REAL SPACE data, not normalized!
            if recnet_model is not None:
                # Convert to tensor and apply RecNet directly to real space data
                segment_tensor = torch.tensor(segment_data, dtype=torch.float32).to(device)
                refined = recnet_model(segment_tensor)
                segment_data = refined.cpu().numpy()  # (n_timepoints, genes) - REAL SPACE
            all_segments.append(segment_data)
    # Concatenate all segments
    trajectory_data = np.vstack(all_segments)
    # Create DataFrame
    trajectory_df = pd.DataFrame(
        trajectory_data,
        columns=gene_data.index
    )
    # Create informative index
    time_indices = []
    for i in range(len(trajectory) - 1):
        for t in range(n_timepoints):
            time_indices.append(f"{trajectory[i]}_to_{trajectory[i+1]}_t{t:03d}")
    trajectory_df.index = time_indices
    logger.info(f"Generated trajectory data: {trajectory_df.shape}")
    # Save if path provided
    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        trajectory_df.to_csv(save_path)
        logger.info(f"Saved trajectory to: {save_path}")
    return trajectory_df
Patient Connectivity
create_patient_connections
Create optimal patient pairings for trajectories.
create_patient_connections(
    data: DataFrame,
    clinical: Series,
    method: str = "random",
    transition_type: str = "early_to_late",
    n_connections: Optional[int] = None,
    seed: int = 2023,
) -> pd.DataFrame
Create connections between patients for trajectory generation.
Args:
- data: Gene expression data
- clinical: Clinical stage information
- method: Method for creating connections ("random", "nearest", "all")
- transition_type: Type of transition ("early_to_late", "early_to_early", etc.)
- n_connections: Number of connections to create (None = all possible)
- seed: Random seed
Returns: DataFrame with columns: source, target, transition
Note: This function is not yet implemented; calling it raises NotImplementedError.
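Since the function body is a stub, the "random" method can only be sketched under assumptions. The sketch below assumes clinical is a Series mapping patient IDs to "early"/"late" labels, handles only early_to_late, ignores the expression data, and draws n_connections pairs with a seeded NumPy generator. The name `sketch_random_connections` is hypothetical and this is not the notebook logic the TODO refers to.

```python
from typing import Optional

import numpy as np
import pandas as pd


def sketch_random_connections(
    clinical: pd.Series,  # assumed: patient ID -> "early" / "late"
    n_connections: Optional[int] = None,
    seed: int = 2023,
) -> pd.DataFrame:
    """Hypothetical 'random' early_to_late pairing with a fixed seed."""
    rng = np.random.default_rng(seed)
    early = clinical[clinical == "early"].index.tolist()
    late = clinical[clinical == "late"].index.tolist()
    # Every possible early -> late pair
    pairs = [(s, t) for s in early for t in late]
    # Optionally keep a random subset; the seed makes the choice reproducible
    if n_connections is not None and n_connections < len(pairs):
        keep = np.sort(rng.choice(len(pairs), size=n_connections, replace=False))
        pairs = [pairs[i] for i in keep]
    links = pd.DataFrame(pairs, columns=["source", "target"])
    links["transition"] = "early_to_late"
    return links
```

Returning the documented source, target and transition columns keeps the output compatible with build_trajectory_network, which only reads 'source' and 'target'.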
Source code in renalprog/modeling/predict.py
def create_patient_connections(
    data: pd.DataFrame,
    clinical: pd.Series,
    method: str = "random",
    transition_type: str = "early_to_late",
    n_connections: Optional[int] = None,
    seed: int = 2023
) -> pd.DataFrame:
    """
    Create connections between patients for trajectory generation.
    Args:
        data: Gene expression data
        clinical: Clinical stage information
        method: Method for creating connections ("random", "nearest", "all")
        transition_type: Type of transition ("early_to_late", "early_to_early", etc.)
        n_connections: Number of connections to create (None = all possible)
        seed: Random seed
    Returns:
        DataFrame with columns: source, target, transition
    """
    logger.info(f"Creating patient connections: {transition_type} using {method} method")
    # TODO: Implement connection logic
    # Migrate from notebooks/4_1_trajectories.ipynb
    raise NotImplementedError(
        "create_patient_connections() needs implementation from "
        "notebooks/4_1_trajectories.ipynb"
    )
link_patients_closest
Link patients using closest latent space neighbors.
link_patients_closest(
    transitions_df: DataFrame,
    start_with_first_stage: bool = True,
    early_late: bool = False,
    closest: bool = True,
) -> pd.DataFrame
Link patients by selecting closest (or farthest) matches across stages.
For each patient at a source stage, this function identifies the closest (or farthest) patient at the target stage, considering metadata constraints (gender, race). This creates one-to-one patient linkages that form the basis for trajectory construction.
Args:
- transitions_df: DataFrame from calculate_all_possible_transitions() containing all possible patient pairs with distances
- start_with_first_stage: If True, build forward trajectories (early→late). If False, build backward trajectories (late→early)
- early_late: If True, uses early/late groupings. If False, uses I-IV stages
- closest: If True, connect closest patients. If False, connect farthest patients
Returns: DataFrame with the selected patient links: one row per anchored patient (each source in forward mode, each target in backward mode) with its optimal match. Includes all columns from transitions_df.
Selection Strategy:
- Forward (start_with_first_stage=True): For each source, find the optimal target
- Backward (start_with_first_stage=False): For each target, find the optimal source
- Closest (closest=True): Minimum distance match
- Farthest (closest=False): Maximum distance match
Metadata Stratification: Links are selected independently within each combination of:
- Gender (MALE, FEMALE)
- Race (ASIAN, BLACK OR AFRICAN AMERICAN, WHITE)
This ensures demographic consistency in trajectories. Patients whose race is not in this list are not linked.
Example:
>>> links = link_patients_closest(
...     transitions_df=all_transitions,
...     start_with_first_stage=True,
...     closest=True
... )
>>> print(f"Created {len(links)} patient links")
Created 234 patient links
Note:
- Forward mode processes transitions in the order 1_to_2, 2_to_3, 3_to_4 (I→II→III→IV)
- Backward mode processes them in reverse: 3_to_4, 2_to_3, 1_to_2
- Each anchored patient (source in forward mode, target in backward mode) appears at most once in the result
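The per-patient selection loop can be expressed as a sort followed by a grouped head(1). This is a hypothetical vectorised sketch (`pick_links`), not the package code: it groups over every gender and race actually present instead of the hard-coded lists, and ties in distance may resolve differently than with iloc[-1] on an ascending sort.

```python
import pandas as pd


def pick_links(transitions: pd.DataFrame, forward: bool = True, closest: bool = True) -> pd.DataFrame:
    """Keep one row per anchored patient: its closest (or farthest) partner."""
    anchor = "source" if forward else "target"
    keys = ["transition", "source_gender", "source_race", anchor]
    # Ascending for closest, descending for farthest; stable keeps input order on ties
    ordered = transitions.sort_values("distance", ascending=closest, kind="stable")
    # After sorting, the first row of each group is the best match
    return ordered.groupby(keys, sort=False).head(1).reset_index(drop=True)
```

Anchoring on 'target' in backward mode mirrors the use_column switch in the source, so each late-stage patient keeps exactly one earlier-stage partner.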
Source code in renalprog/modeling/predict.py
def link_patients_closest(
    transitions_df: pd.DataFrame,
    start_with_first_stage: bool = True,
    early_late: bool = False,
    closest: bool = True
) -> pd.DataFrame:
    """
    Link patients by selecting closest (or farthest) matches across stages.
    For each patient at a source stage, this function identifies the closest
    (or farthest) patient at the target stage, considering metadata constraints
    (gender, race). This creates one-to-one patient linkages that form the basis
    for trajectory construction.
    Args:
        transitions_df: DataFrame from calculate_all_possible_transitions()
            containing all possible patient pairs with distances
        start_with_first_stage: If True, build forward trajectories (early→late)
            If False, build backward trajectories (late→early)
        early_late: If True, uses early/late groupings. If False, uses I-IV stages
        closest: If True, connect closest patients. If False, connect farthest patients
    Returns:
        DataFrame with selected patient links, containing one row per source patient
        with their optimal target patient match. Includes all columns from transitions_df.
    Selection Strategy:
        - Forward (start_with_first_stage=True): For each source, find optimal target
        - Backward (start_with_first_stage=False): For each target, find optimal source
        - Closest (closest=True): Minimum distance match
        - Farthest (closest=False): Maximum distance match
    Metadata Stratification:
        Links are selected independently within each combination of:
        - Gender (MALE, FEMALE)
        - Race (ASIAN, BLACK OR AFRICAN AMERICAN, WHITE)
        This ensures demographic consistency in trajectories.
    Example:
        >>> links = link_patients_closest(
        ...     transitions_df=all_transitions,
        ...     start_with_first_stage=True,
        ...     closest=True
        ... )
        >>> print(f"Created {len(links)} patient links")
        Created 234 patient links
    Note:
        - Processes transitions in order for forward: I→II→III→IV
        - Processes in reverse for backward: IV→III→II→I
        - Each patient appears at most once as a source in the result
    """
    logger.info("Linking patients by closest/farthest matches")
    logger.info(f"Direction: {'Forward' if start_with_first_stage else 'Backward'}")
    logger.info(f"Strategy: {'Closest' if closest else 'Farthest'}")
    # Define transition order based on direction
    if start_with_first_stage and not early_late:
        transitions_possible = ['1_to_2', '2_to_3', '3_to_4']
    elif not start_with_first_stage and not early_late:
        transitions_possible = ['3_to_4', '2_to_3', '1_to_2']
    elif early_late:
        transitions_possible = ['early_to_late']
    # 0 for closest (smallest distance), -1 for farthest (largest distance)
    idx = 0 if closest else -1
    # Find closest/farthest patient for each source patient
    closest_list = []
    for transition_i in transitions_possible:
        transition_df_i = transitions_df[transitions_df['transition'] == transition_i]
        logger.info(f"Processing transition {transition_i}: {len(transition_df_i)} pairs")
        # Iterate through all metadata combinations
        for gender_i in ['FEMALE', 'MALE']:
            df_gender_i = transition_df_i.query(f"source_gender == '{gender_i}'")
            for race_i in ['ASIAN', 'BLACK OR AFRICAN AMERICAN', 'WHITE']:
                df_race_i = df_gender_i.query(f"source_race == '{race_i}'")
                if df_race_i.empty:
                    continue
                # Get unique patients to link
                unique_sources = df_race_i['source'].unique()
                unique_targets = df_race_i['target'].unique()
                use_uniques = unique_sources if start_with_first_stage else unique_targets
                use_column = 'source' if start_with_first_stage else 'target'
                # Find closest/farthest match for each patient
                for pat_i in use_uniques:
                    pat_matches = df_race_i[df_race_i[use_column] == pat_i]
                    if len(pat_matches) > 0:
                        # Sort by distance and select first (closest) or last (farthest)
                        best_match = pat_matches.sort_values('distance').iloc[idx]
                        closest_list.append(best_match)
    # Convert to DataFrame
    closest_df = pd.DataFrame(closest_list)
    closest_df.reset_index(drop=True, inplace=True)
    logger.info(f"Created {len(closest_df)} patient links")
    return closest_df
Example:
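from renalprog.modeling.predict import link_patients_closest
import pandas as pd
# Toy input in the format produced by calculate_all_possible_transitions()
transitions_df = pd.DataFrame({
    "source": ["E001", "E001", "E002", "E002"],
    "target": ["L001", "L002", "L001", "L002"],
    "source_gender": ["MALE"] * 4,
    "target_gender": ["MALE"] * 4,
    "source_race": ["WHITE"] * 4,
    "target_race": ["WHITE"] * 4,
    "transition": ["early_to_late"] * 4,
    "distance": [0.8, 0.3, 0.5, 0.9],
})
links = link_patients_closest(
    transitions_df=transitions_df,
    start_with_first_stage=True,
    early_late=True,
    closest=True,
)
# One row per source patient: E001 → L002 (0.3), E002 → L001 (0.5)
print(links[["source", "target", "distance"]])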
link_patients_random
Link patients randomly (control method).
link_patients_random(
    results_df: DataFrame,
    start_with_first_stage: bool = True,
    link_next: int = 5,
    transitions_possible: Optional[List[str]] = None,
) -> pd.DataFrame
Link patients to multiple random targets at the next stage.
Instead of linking each patient to only their closest match, this function randomly samples multiple patients at the next stage to link to each source patient. This creates a one-to-many mapping useful for generating multiple trajectory samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| results_df | DataFrame | DataFrame with possible sources and targets, their metadata, and distance. | required |
| start_with_first_stage | bool | If True, initiate trajectories with first stage as sources. If False, initiate trajectories with last stage as sources. | True |
| link_next | int | Number of patients at next stage to randomly link to each patient of current stage. | 5 |
| transitions_possible | list | List of transitions to process (e.g., ['1_to_2', '2_to_3']). If None, defaults to ['early_to_late']. | None |
Returns:
| Type | Description |
|---|---|
| DataFrame | DataFrame with randomly sampled patient links for each transition. Contains multiple rows per source patient (up to link_next). |
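Notes
- Random sampling is only performed for WHITE patients, the only group with enough samples
- If fewer than link_next targets are available, all available targets are selected
- Patients from other races are included with all their possible connections, without sampling and without filtering by transitions_possible
- Sampling uses no fixed random_state, so repeated calls can return different links
- An empty DataFrame is returned if no WHITE patients are found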
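The "up to link_next targets per source" rule can be written as a shuffle followed by a grouped head(). This is a hypothetical sketch (`sample_links`), not the package code: it does not apply the WHITE-only stratification, and unlike the source it fixes a random_state so the draw is reproducible.

```python
import pandas as pd


def sample_links(candidates: pd.DataFrame, link_next: int = 5, seed: int = 0) -> pd.DataFrame:
    """Keep up to `link_next` randomly chosen targets for each source patient."""
    shuffled = candidates.sample(frac=1.0, random_state=seed)  # random row order
    # head() keeps every row when a group has fewer than link_next candidates
    picked = shuffled.groupby("source", sort=False).head(link_next)
    return picked.sort_values(["source", "target"]).reset_index(drop=True)
```

Shuffling once and taking the first rows per group avoids the explicit branch for groups smaller than link_next.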
Source code in renalprog/modeling/predict.py
def link_patients_random(
    results_df: pd.DataFrame,
    start_with_first_stage: bool = True,
    link_next: int = 5,
    transitions_possible: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Link patients to multiple random targets at the next stage.
    Instead of linking each patient to only their closest match, this function randomly
    samples multiple patients at the next stage to link to each source patient. This
    creates a one-to-many mapping useful for generating multiple trajectory samples.
    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with possible sources and targets, their metadata, and distance.
    start_with_first_stage : bool, default=True
        If True, initiate trajectories with first stage as sources.
        If False, initiate trajectories with last stage as sources.
    link_next : int, default=5
        Number of patients at next stage to randomly link to each patient of current stage.
    transitions_possible : list, optional
        List of transitions to process (e.g., ['1_to_2', '2_to_3']).
        If None, defaults to ['early_to_late'].
    Returns
    -------
    pd.DataFrame
        DataFrame with randomly sampled patient links for each transition.
        Contains multiple rows per source patient (up to link_next).
    Notes
    -----
    - Random sampling is primarily performed for WHITE race patients due to sample size
    - If fewer than link_next targets are available, all available targets are selected
    - Patients from other races are included with all their possible connections
    - Empty DataFrame is returned if no WHITE patients are found
    """
    # Set default transitions if not provided
    if transitions_possible is None:
        transitions_possible = ['early_to_late']
    # Get unique genders and races
    unique_genders = results_df['source_gender'].unique().tolist()
    # Get unique races
    unique_races = results_df['source_race'].unique().tolist()
    if 'WHITE' in unique_races:
        unique_races.remove('WHITE')
    # transition:
    samples = []
    for transition_i in transitions_possible:
        transition_df_i = results_df[results_df['transition'] == transition_i]
        for gender_i in unique_genders:
            df_samples_i = transition_df_i.query(
                f"source_gender == '{gender_i}' & source_race == 'WHITE'")  # we can only do this for the whites since these are the only ones with enough samples
            if df_samples_i.empty:
                print(f"Warning: No WHITE patients found for gender {gender_i} in transition {transition_i}")
                continue
            unique_sources_i = np.unique(df_samples_i['source']).tolist()
            unique_targets_i = np.unique(df_samples_i['target']).tolist()
            use_uniques = unique_sources_i if start_with_first_stage else unique_targets_i
            use_source_target = 'source' if start_with_first_stage else 'target'
            for pat_i in use_uniques:
                sample_i = df_samples_i.loc[df_samples_i[use_source_target] == pat_i]
                if len(sample_i) >= link_next:
                    sample_i = sample_i.sample(
                        link_next)  # Sample a number of patients at next stage to link to each patient of current stage
                else:
                    sample_i = sample_i.sample(len(sample_i))  # Sample all available patients if less than link_next
                samples.append(sample_i)
    # Check if samples list is empty
    if not samples:
        print("Warning: No samples found for WHITE race. Returning empty DataFrame.")
        return pd.DataFrame(columns=results_df.columns)
    # Turn samples into dataframe:
    samples_df = pd.concat(samples)
    # Add the rest of the races
    if unique_races:
        samples_df = pd.concat(
            [
                samples_df,
                results_df[results_df['source_race'].isin(unique_races)]
            ]
        )
    samples_df.reset_index(drop=True, inplace=True)
    return samples_df
Transition Analysis
calculate_all_possible_transitions
Calculate metrics for all possible patient transitions.
calculate_all_possible_transitions(
    data: DataFrame,
    metadata_selection: DataFrame,
    distance: str = "wasserstein",
    early_late: bool = False,
    negative_trajectory: bool = False,
) -> pd.DataFrame
Calculate all possible patient-to-patient transitions for KIRC cancer.
This function computes pairwise distances between all patients at consecutive (or same) cancer stages, considering metadata constraints. Only patients with matching gender and race are considered as potential trajectory pairs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | Gene expression data with patients as columns. | required |
| metadata_selection | DataFrame | Clinical metadata with columns: histological_type, race, gender, stage. | required |
| distance | str ('wasserstein' or 'euclidean') | Distance metric used to compute patient-to-patient distances. | 'wasserstein' |
| early_late | bool | If True, uses early/late stage groupings. If False, uses I-IV stages. | False |
| negative_trajectory | bool | If True, generates same-stage transitions (negative controls). If False, generates progression transitions (positive trajectories). | False |
Returns:
| Type | Description |
|---|---|
| DataFrame | All possible transitions, with columns: source, target (patient IDs); source_gender, target_gender; source_race, target_race; transition (stage transition label, e.g. '1_to_2', 'early_to_late'); distance (distance between the two patients). Rows are sorted by gender, race, transition, and distance. |
Raises:
| Type | Description |
|---|---|
| ValueError | If distance is not 'wasserstein' or 'euclidean'. |
Notes
- For positive trajectories: links I→II, II→III, III→IV or early→late
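- For negative trajectories: links I→I, II→II, III→III, IV→IV or early→early, late→late; because both loops run over the same patients, each patient is also paired with itself (distance 0)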
- Only patients with identical gender and race are paired
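With distance='wasserstein', `scipy.stats.wasserstein_distance` compares the two patients' expression vectors as one-dimensional empirical distributions of values. Gene identity therefore does not enter the distance: shuffling the genes of one patient leaves it unchanged, whereas the Euclidean distance depends on gene order. For two vectors of equal length with uniform weights, the first Wasserstein distance reduces to the mean absolute difference between the sorted values. A NumPy-only sketch of that special case (the helper `wasserstein_1d_equal` is hypothetical, not part of renalprog or SciPy):

```python
import numpy as np


def wasserstein_1d_equal(u: np.ndarray, v: np.ndarray) -> float:
    """First Wasserstein distance between two equal-length 1-D samples.

    With uniform weights, W1 equals the mean absolute difference between
    the sorted values (i.e. between matched quantiles).
    """
    u = np.sort(np.asarray(u, dtype=float))
    v = np.sort(np.asarray(v, dtype=float))
    if u.shape != v.shape:
        raise ValueError("this shortcut needs samples of equal length")
    return float(np.mean(np.abs(u - v)))


patient_a = np.array([0.0, 1.0, 3.0])
patient_b = np.array([5.0, 6.0, 8.0])
shuffled_b = patient_b[[2, 0, 1]]  # same values, different gene order

print(wasserstein_1d_equal(patient_a, patient_b))   # 5.0
print(wasserstein_1d_equal(patient_a, shuffled_b))  # 5.0: gene order is ignored
print(np.linalg.norm(patient_a - patient_b))        # Euclidean distance...
print(np.linalg.norm(patient_a - shuffled_b))       # ...changes with gene order
```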
Source code in renalprog/modeling/predict.py
def calculate_all_possible_transitions(
    data: pd.DataFrame,
    metadata_selection: pd.DataFrame,
    distance: str = 'wasserstein',
    early_late: bool = False,
    negative_trajectory: bool = False
) -> pd.DataFrame:
    """
    Calculate all possible patient-to-patient transitions for KIRC cancer.

    This function computes pairwise distances between all patients at consecutive
    (or same) cancer stages, considering metadata constraints. Only patients with
    matching gender and race are considered as potential trajectory pairs.

    Parameters
    ----------
    data : pd.DataFrame
        Gene expression data with patients as columns.
    metadata_selection : pd.DataFrame
        Clinical metadata with columns: histological_type, race, gender, stage.
    distance : {'wasserstein', 'euclidean'}, default='wasserstein'
        Distance metric to use for calculating patient-to-patient distances.
    early_late : bool, default=False
        If True, uses early/late stage groupings. If False, uses I-IV stages.
    negative_trajectory : bool, default=False
        If True, generates same-stage transitions (negative controls).
        If False, generates progression transitions (positive trajectories).

    Returns
    -------
    pd.DataFrame
        DataFrame containing all possible transitions with columns:
        - source, target: Patient IDs
        - source_gender, target_gender: Gender
        - source_race, target_race: Race
        - transition: Stage transition label (e.g., '1_to_2', 'early_to_late')
        - distance: Calculated distance between patients
        Sorted by gender, race, transition, and distance.

    Raises
    ------
    ValueError
        If distance metric is not 'wasserstein' or 'euclidean'.

    Notes
    -----
    - For positive trajectories: links I→II, II→III, III→IV or early→late
    - For negative trajectories: links I→I, II→II, III→III, IV→IV or early→early, late→late
    - Only patients with identical gender and race are paired
    """
    # Select distance function
    if distance == 'wasserstein':
        from scipy.stats import wasserstein_distance
        distance_fun = wasserstein_distance
    elif distance == 'euclidean':
        from scipy.spatial.distance import euclidean
        distance_fun = euclidean
    else:
        raise ValueError('Distance function not implemented. Use either "wasserstein" or "euclidean".')

    # Define stage transitions based on parameters
    if early_late and not negative_trajectory:
        possible_transitions = ['early_to_late']
        stage_pairs = [['early', 'late']]
    elif early_late and negative_trajectory:
        possible_transitions = ['early_to_early', 'late_to_late']
        stage_pairs = [['early', 'early'], ['late', 'late']]
    elif not early_late and not negative_trajectory:
        possible_transitions = ['1_to_2', '2_to_3', '3_to_4']
        stage_pairs = [['I', 'II'], ['II', 'III'], ['III', 'IV']]
    elif not early_late and negative_trajectory:
        possible_transitions = ['1_to_1', '2_to_2', '3_to_3', '4_to_4']
        stage_pairs = [['I', 'I'], ['II', 'II'], ['III', 'III'], ['IV', 'IV']]

    # Calculate all possible transitions
    results = []
    for i_tr, transition in enumerate(possible_transitions):
        source_target_stage = stage_pairs[i_tr]
        # Iterate through all patient pairs at specified stages
        for pat_i in metadata_selection.index[metadata_selection['stage'] == source_target_stage[0]]:
            for pat_ii in metadata_selection.index[metadata_selection['stage'] == source_target_stage[1]]:
                # Extract metadata for both patients
                source_gender = metadata_selection.at[pat_i, 'gender']
                target_gender = metadata_selection.at[pat_ii, 'gender']
                source_race = metadata_selection.at[pat_i, 'race']
                target_race = metadata_selection.at[pat_ii, 'race']
                # Skip if metadata doesn't match (gender and race must match)
                if not (source_race == target_race and source_gender == target_gender):
                    continue
                # Store transition information
                results_i = {
                    'source': pat_i,
                    'target': pat_ii,
                    'source_gender': source_gender,
                    'target_gender': target_gender,
                    'source_race': source_race,
                    'target_race': target_race,
                    'transition': transition,
                    'distance': distance_fun(data[pat_i], data[pat_ii]),
                }
                results.append(results_i)

    # Convert to DataFrame and sort
    results_df = pd.DataFrame(results)
    results_df.sort_values(
        ['source_gender', 'target_gender', 'source_race', 'target_race',
         'transition', 'distance'],
        inplace=True,
        ignore_index=True
    )
    return results_df
Example Usage:
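from renalprog.modeling.predict import calculate_all_possible_transitions
# expression: genes x patients DataFrame (one column per patient ID)
# clinical: DataFrame indexed by patient ID with columns
#           histological_type, race, gender, stage ('early'/'late' here)
transitions = calculate_all_possible_transitions(
    data=expression,
    metadata_selection=clinical,
    distance="wasserstein",
    early_late=True,            # build early -> late pairs
    negative_trajectory=False,  # progression (positive) trajectories
)
# Summarize the distance distribution per transition type
print(transitions.groupby("transition")["distance"].describe())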
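The nested loop in `calculate_all_possible_transitions` visits every pair of patients at the requested stages and then discards pairs whose gender or race differ. The same candidate pairs can be built with one inner merge on gender and race per stage pair. A minimal sketch on toy metadata, assuming the column names used above (`stage`, `gender`, `race`, indexed by patient ID); `candidate_pairs` is a hypothetical helper, the distance column is omitted, and its transition labels use stage names directly (e.g. 'I_to_II') rather than the library's '1_to_2':

```python
import pandas as pd


def candidate_pairs(meta: pd.DataFrame, stage_pairs: list[tuple[str, str]]) -> pd.DataFrame:
    """Source/target pairs at the given stages with matching gender and race."""
    meta = meta.rename_axis("patient").reset_index()
    frames = []
    for src_stage, tgt_stage in stage_pairs:
        src = meta[meta["stage"] == src_stage]
        tgt = meta[meta["stage"] == tgt_stage]
        # The inner join on the matching constraints replaces the explicit skip
        pairs = src.merge(tgt, on=["gender", "race"], suffixes=("_src", "_tgt"))
        frames.append(pd.DataFrame({
            "source": pairs["patient_src"],
            "target": pairs["patient_tgt"],
            "gender": pairs["gender"],
            "race": pairs["race"],
            "transition": f"{src_stage}_to_{tgt_stage}",
        }))
    return pd.concat(frames, ignore_index=True)


meta = pd.DataFrame(
    {"stage": ["I", "I", "II", "II"],
     "gender": ["MALE", "FEMALE", "MALE", "MALE"],
     "race": ["WHITE", "WHITE", "WHITE", "ASIAN"]},
    index=["p1", "p2", "p3", "p4"],
)
pairs = candidate_pairs(meta, [("I", "II")])
# Only p1 -> p3 survives: p2 has no same-gender partner, p4 no same-race partner
print(pairs)
```

As in the loop version, a same-stage call such as `candidate_pairs(meta, [("II", "II")])` also pairs each patient with itself.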
Dynamic Enrichment
dynamic_enrichment_analysis
Perform pathway enrichment at each trajectory timepoint.
dynamic_enrichment_analysis(
trajectory_dir: Path,
pathways_file: Path,
output_dir: Path,
cancer_type: str = "kirc",
) -> pd.DataFrame
Perform dynamic enrichment analysis on synthetic trajectories.
This orchestrates:
1. DESeq2 analysis on each trajectory point
2. GSEA on the differential expression results
3. Aggregation of enrichment across trajectories
The current implementation is a placeholder: calling it raises NotImplementedError.
Args:
- trajectory_dir: Directory containing trajectory CSV files
- pathways_file: Path to pathway GMT file
- output_dir: Directory to save results
- cancer_type: Cancer type identifier
Returns: DataFrame with aggregated enrichment results
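Because the function is not implemented yet, there is no reference implementation of the aggregation step (3). One possible shape for it, purely as an assumption: collect one GSEA result table per interpolation step and pivot them into a pathway × step matrix of normalized enrichment scores, masking entries that are not significant. The column names `pathway`, `nes` and `fdr`, the helper `aggregate_enrichment`, and the toy values are all hypothetical:

```python
import pandas as pd


def aggregate_enrichment(per_step: dict[int, pd.DataFrame],
                         fdr_cutoff: float = 0.05) -> pd.DataFrame:
    """Pivot per-step GSEA tables into a pathway x step NES matrix.

    Hypothetical layout: each table has columns 'pathway', 'nes', 'fdr'.
    Entries with an FDR above the cutoff are set to NaN.
    """
    long = pd.concat(
        [table.assign(step=step) for step, table in per_step.items()],
        ignore_index=True,
    )
    # Keep the NES only where the result is significant
    long["nes"] = long["nes"].where(long["fdr"] <= fdr_cutoff)
    return long.pivot(index="pathway", columns="step", values="nes").sort_index(axis=1)


# Toy input for two trajectory steps
per_step = {
    0: pd.DataFrame({"pathway": ["Hypoxia", "Glycolysis"],
                     "nes": [0.4, 1.1], "fdr": [0.60, 0.20]}),
    1: pd.DataFrame({"pathway": ["Hypoxia", "Glycolysis"],
                     "nes": [1.8, 1.6], "fdr": [0.01, 0.04]}),
}
matrix = aggregate_enrichment(per_step)
print(matrix)  # step 0 is all NaN; both pathways are significant at step 1
```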
Source code in renalprog/modeling/predict.py
def dynamic_enrichment_analysis(
    trajectory_dir: Path,
    pathways_file: Path,
    output_dir: Path,
    cancer_type: str = "kirc"
) -> pd.DataFrame:
    """
    Perform dynamic enrichment analysis on synthetic trajectories.

    This orchestrates:
    1. DESeq2 analysis on each trajectory point
    2. GSEA on differential expression results
    3. Aggregation of enrichment across trajectories

    Args:
        trajectory_dir: Directory containing trajectory CSV files
        pathways_file: Path to pathway GMT file
        output_dir: Directory to save results
        cancer_type: Cancer type identifier

    Returns:
        DataFrame with aggregated enrichment results
    """
    logger.info(f"Running dynamic enrichment analysis for {cancer_type}")
    # TODO: Implement orchestration
    # Migrate from src_deseq_and_gsea_NCSR/full_bash.sh and related scripts
    raise NotImplementedError(
        "dynamic_enrichment_analysis() needs implementation. "
        "Migrate orchestration from src_deseq_and_gsea_NCSR/full_bash.sh, "
        "py_deseq.py, and trajectory_analysis.py"
    )
Example Usage:
from pathlib import Path
from renalprog.modeling.predict import dynamic_enrichment_analysis
# Analyze pathway dynamics along trajectories saved as CSV files
try:
    enrichment_results = dynamic_enrichment_analysis(
        trajectory_dir=Path("data/processed/trajectories"),
        pathways_file=Path("data/external/ReactomePathways.gmt"),
        output_dir=Path("reports/dynamic_enrichment"),
        cancer_type="kirc",
    )
except NotImplementedError as err:
    # The current release only contains a placeholder
    print(f"Not available yet: {err}")
else:
    # A single DataFrame of enrichment results aggregated across trajectories
    print(enrichment_results.head())
Interpolation Methods
interpolate_latent_linear
Linear interpolation between points.
interpolate_latent_linear(
z_source: ndarray, z_target: ndarray, n_steps: int = 50
) -> np.ndarray
Linear interpolation in latent space.
Args:
- z_source: Source latent vector
- z_target: Target latent vector
- n_steps: Number of interpolation steps
Returns: Array of interpolated latent vectors (n_steps x latent_dim)
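Linear interpolation evaluates z(α) = (1 − α)·z_source + α·z_target at `n_steps` evenly spaced values of α in [0, 1], so the first row is z_source and the last is z_target. The same array can be produced with NumPy broadcasting instead of a Python-level loop. A minimal sketch; `interpolate_linear_vectorized` is a hypothetical name, not part of renalprog:

```python
import numpy as np


def interpolate_linear_vectorized(z_source: np.ndarray, z_target: np.ndarray,
                                  n_steps: int = 50) -> np.ndarray:
    """Linear interpolation via broadcasting; returns (n_steps, latent_dim)."""
    alphas = np.linspace(0.0, 1.0, n_steps)[:, None]  # column vector of weights
    return (1.0 - alphas) * z_source[None, :] + alphas * z_target[None, :]


z0 = np.array([0.0, 2.0])
z1 = np.array([4.0, 6.0])
path = interpolate_linear_vectorized(z0, z1, n_steps=5)
print(path)
# Rows step evenly from z0 to z1: [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]
```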
Source code in renalprog/modeling/predict.py
def interpolate_latent_linear(
    z_source: np.ndarray,
    z_target: np.ndarray,
    n_steps: int = 50
) -> np.ndarray:
    """
    Linear interpolation in latent space.

    Args:
        z_source: Source latent vector
        z_target: Target latent vector
        n_steps: Number of interpolation steps

    Returns:
        Array of interpolated latent vectors (n_steps x latent_dim)
    """
    alphas = np.linspace(0, 1, n_steps)
    interpolated = np.array([
        (1 - alpha) * z_source + alpha * z_target
        for alpha in alphas
    ])
    return interpolated
interpolate_latent_spherical
Spherical interpolation (SLERP) for normalized spaces.
interpolate_latent_spherical(
z_source: ndarray, z_target: ndarray, n_steps: int = 50
) -> np.ndarray
Spherical (SLERP) interpolation in latent space.
Args:
- z_source: Source latent vector
- z_target: Target latent vector
- n_steps: Number of interpolation steps
Returns: Array of interpolated latent vectors (n_steps x latent_dim)
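Written out, the interpolation computes the angle between the unit-normalized endpoints and then weights the original (unnormalized) vectors, for α_k = k/(n_steps − 1), k = 0, …, n_steps − 1, falling back to linear interpolation when Ω < 10⁻⁸:

```latex
\hat z = \frac{z}{\lVert z \rVert}, \qquad
\Omega = \arccos\!\Big(\operatorname{clip}\big(\hat z_{\text{source}} \cdot \hat z_{\text{target}},\,-1,\,1\big)\Big)

z(\alpha_k) = \frac{\sin\big((1-\alpha_k)\,\Omega\big)}{\sin \Omega}\, z_{\text{source}}
            + \frac{\sin(\alpha_k\,\Omega)}{\sin \Omega}\, z_{\text{target}},
\qquad \alpha_k = \frac{k}{n_{\text{steps}} - 1}
```

At α = 0 and α = 1 the weights are (1, 0) and (0, 1), so the path starts and ends at the original vectors. Because Ω is measured between the normalized vectors but the weights multiply the unnormalized ones, the path stays on a sphere only when z_source and z_target have the same norm.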
Source code in renalprog/modeling/predict.py
def interpolate_latent_spherical(
    z_source: np.ndarray,
    z_target: np.ndarray,
    n_steps: int = 50
) -> np.ndarray:
    """
    Spherical (SLERP) interpolation in latent space.

    Args:
        z_source: Source latent vector
        z_target: Target latent vector
        n_steps: Number of interpolation steps

    Returns:
        Array of interpolated latent vectors (n_steps x latent_dim)
    """
    # Normalize vectors
    z_source_norm = z_source / np.linalg.norm(z_source)
    z_target_norm = z_target / np.linalg.norm(z_target)

    # Calculate angle between vectors
    omega = np.arccos(np.clip(np.dot(z_source_norm, z_target_norm), -1.0, 1.0))

    if omega < 1e-8:
        # Vectors are nearly identical, use linear interpolation
        return interpolate_latent_linear(z_source, z_target, n_steps)

    # SLERP formula
    alphas = np.linspace(0, 1, n_steps)
    interpolated = np.array([
        (np.sin((1 - alpha) * omega) / np.sin(omega)) * z_source +
        (np.sin(alpha * omega) / np.sin(omega)) * z_target
        for alpha in alphas
    ])
    return interpolated
Comparison:
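import numpy as np
from renalprog.modeling.predict import (
    interpolate_latent_linear,
    interpolate_latent_spherical
)
# Both functions expect 1-D latent vectors of shape (latent_dim,);
# a (1, 128) input breaks the np.dot call in the spherical version
z_start = np.random.randn(128)
z_end = np.random.randn(128)
# Linear interpolation: array of shape (50, 128)
traj_linear = interpolate_latent_linear(z_start, z_end, n_steps=50)
# Spherical interpolation: array of shape (50, 128); intermediate points keep
# a norm close to the endpoints', whereas linear interpolation shrinks it mid-path
traj_spherical = interpolate_latent_spherical(z_start, z_end, n_steps=50)
# SLERP is often preferred when latent codes lie near a hypersphere,
# as samples from a high-dimensional Gaussian prior do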
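The difference between the two methods shows up in the norm of intermediate points. For two orthogonal unit vectors, the linear midpoint lies inside the unit circle (norm √0.5 ≈ 0.707), while the SLERP midpoint stays on it. The sketch reimplements both formulas in plain NumPy so it runs without renalprog; `lerp` and `slerp` are illustrative names:

```python
import numpy as np


def lerp(z0: np.ndarray, z1: np.ndarray, alpha: float) -> np.ndarray:
    """Linear interpolation at a single weight alpha."""
    return (1 - alpha) * z0 + alpha * z1


def slerp(z0: np.ndarray, z1: np.ndarray, alpha: float) -> np.ndarray:
    """Spherical interpolation at a single weight alpha (same formula as above)."""
    cos_omega = np.dot(z0 / np.linalg.norm(z0), z1 / np.linalg.norm(z1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    if omega < 1e-8:
        return lerp(z0, z1, alpha)
    return (np.sin((1 - alpha) * omega) * z0 + np.sin(alpha * omega) * z1) / np.sin(omega)


e1 = np.array([1.0, 0.0])
e2 = np.array([0.0, 1.0])
print(np.linalg.norm(lerp(e1, e2, 0.5)))   # about 0.707: the chord dips inside the circle
print(np.linalg.norm(slerp(e1, e2, 0.5)))  # 1.0 up to rounding: stays on the unit circle
```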
Visualization
plot_trajectory
Visualize individual trajectory.
plot_trajectory(
trajectory: ndarray,
gene_names: Optional[List[str]] = None,
save_path: Optional[Path] = None,
title: str = "Gene Expression Trajectory",
n_genes_to_show: int = 20,
) -> go.Figure
Plot gene expression changes along a trajectory.
Args:
- trajectory: Array of shape (n_timepoints, n_genes)
- gene_names: Optional list of gene names
- save_path: Optional path to save figure
- title: Plot title
- n_genes_to_show: Number of top varying genes to display
Returns: Plotly Figure object
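`plot_trajectory` chooses which genes to draw by their variance across timepoints. `np.argsort` sorts in ascending order, so the last `n_genes_to_show` indices are the most variable genes. Because those column indices are used to look up `gene_names`, the list must hold one name per column of `trajectory`, not only the genes expected to appear. A standalone sketch of the selection step; the helper `top_varying_genes` and its length check are hypothetical additions, not part of renalprog:

```python
import numpy as np


def top_varying_genes(trajectory: np.ndarray, gene_names: list[str], k: int) -> list[str]:
    """Names of the k genes with the largest variance over timepoints, most variable first."""
    if len(gene_names) != trajectory.shape[1]:
        raise ValueError("need one gene name per column of the trajectory")
    variance = np.var(trajectory, axis=0)
    top_idx = np.argsort(variance)[-k:][::-1]  # argsort is ascending: take the tail, reverse it
    return [gene_names[i] for i in top_idx]


# Toy trajectory: 4 timepoints x 3 genes (GENE_A is constant)
traj = np.array([
    [1.0, 5.0, 0.0],
    [1.0, 3.0, 2.0],
    [1.0, 1.0, 4.0],
    [1.0, 0.0, 6.0],
])
print(top_varying_genes(traj, ["GENE_A", "GENE_B", "GENE_C"], k=2))  # ['GENE_C', 'GENE_B']
```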
Source code in renalprog/plots.py
def plot_trajectory(
    trajectory: np.ndarray,
    gene_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
    title: str = "Gene Expression Trajectory",
    n_genes_to_show: int = 20,
) -> go.Figure:
    """
    Plot gene expression changes along a trajectory.

    Args:
        trajectory: Array of shape (n_timepoints, n_genes)
        gene_names: Optional list of gene names
        save_path: Optional path to save figure
        title: Plot title
        n_genes_to_show: Number of top varying genes to display

    Returns:
        Plotly Figure object
    """
    n_timepoints, n_genes = trajectory.shape

    # Calculate variance for each gene
    gene_variance = np.var(trajectory, axis=0)
    top_genes_idx = np.argsort(gene_variance)[-n_genes_to_show:]

    if gene_names is None:
        gene_names = [f'Gene_{i}' for i in range(n_genes)]

    fig = go.Figure()
    timepoints = list(range(n_timepoints))
    for idx in top_genes_idx:
        fig.add_trace(go.Scatter(
            x=timepoints,
            y=trajectory[:, idx],
            mode='lines',
            name=gene_names[idx],
            line=dict(width=1.5)
        ))

    fig.update_layout(
        title=title,
        xaxis_title='Timepoint',
        yaxis_title='Expression Level',
        template=DEFAULT_TEMPLATE,
        width=DEFAULT_WIDTH,
        height=DEFAULT_HEIGHT,
        hovermode='x unified'
    )

    if save_path:
        save_plot(fig, save_path)

    return fig
Example:
from pathlib import Path
from renalprog.plots import plot_trajectory
# trajectory_data[0]: NumPy array of shape (n_steps, n_genes)
fig = plot_trajectory(
    trajectory=trajectory_data[0],
    gene_names=selected_genes,  # one name per column of the trajectory
    save_path=Path("reports/figures/trajectory_example.png"),
    title="Disease Progression Trajectory",
    n_genes_to_show=20,
)
Complete Workflow Example
import torch
import pandas as pd
from pathlib import Path
from renalprog.modeling.train import VAE
from renalprog.modeling.predict import (
    apply_vae,
    create_patient_connections,
    generate_trajectories,
    build_trajectory_network,
    dynamic_enrichment_analysis
)
from renalprog.plots import plot_trajectory

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load model and data (expr: patients as rows, genes as columns)
model = VAE(input_dim=20000, mid_dim=1024, features=128)
model.load_state_dict(torch.load("models/my_vae/best_model.pt", map_location=device))
expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)
clinical = pd.read_csv("data/interim/split/test_clinical.tsv", sep="\t", index_col=0)

# 2. Encode to latent space
results = apply_vae(model, expr.values, device=device)
latent = results['latent']

# 3. Split by stage (clinical is assumed to share patient IDs with expr)
stage = clinical.loc[expr.index, 'stage']
early_mask = (stage == 'early').to_numpy()
late_mask = (stage == 'late').to_numpy()

# 4. Create patient connections
connections = create_patient_connections(
    latent_early=latent[early_mask],
    latent_late=latent[late_mask],
    method='closest',
    output_path=Path("data/processed/connections.csv")
)

# 5. Generate trajectories (source/target samples are DataFrames)
trajectory_dir = Path("data/processed/trajectories")
trajectories = generate_trajectories(
    model=model,
    source_samples=expr.loc[early_mask],
    target_samples=expr.loc[late_mask],
    n_steps=50,
    method="spherical",
    output_dir=trajectory_dir,
)

# 6. Build trajectory network
network = build_trajectory_network(
    connections=connections,
    output_path=Path("data/processed/network.graphml")
)

# 7. Dynamic enrichment analysis (currently raises NotImplementedError)
try:
    enrichment = dynamic_enrichment_analysis(
        trajectory_dir=trajectory_dir,
        pathways_file=Path("data/external/ReactomePathways.gmt"),
        output_dir=Path("reports/enrichment"),
        cancer_type="kirc",
    )
except NotImplementedError:
    enrichment = None

# 8. Visualize the first trajectory (the result maps patient pairs to DataFrames)
pair_id, first_traj = next(iter(trajectories.items()))
plot_trajectory(
    trajectory=first_traj.to_numpy(),  # assumes rows = steps, columns = genes
    gene_names=first_traj.columns.tolist(),  # one name per column
    save_path=Path("reports/figures/trajectory_001.png"),
    title=f"Trajectory {pair_id}",
)

print(f"Generated {len(trajectories)} trajectories")
print(f"Network edges: {network.number_of_edges()}")
if enrichment is not None:
    print(f"Enrichment results: {len(enrichment)} rows")
See Also