
Trajectories API

Functions for analyzing disease progression trajectories.

Overview

The trajectories module provides analysis tools for:

  • Trajectory network construction
  • Patient connectivity analysis
  • Temporal pathway enrichment
  • Transition probability calculation
  • Trajectory visualization

Trajectory Generation

generate_trajectories

Generate smooth disease progression trajectories.

generate_trajectories

generate_trajectories(
    model: Module,
    source_samples: DataFrame,
    target_samples: DataFrame,
    n_steps: int = 50,
    method: str = "linear",
    output_dir: Optional[Path] = None,
    parallel: bool = False,
    n_workers: Optional[int] = None,
) -> Dict[str, pd.DataFrame]

Generate synthetic trajectories between source and target patient samples.

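This function is meant to create interpolated gene expression profiles in the latent space between pairs of patients at different cancer stages. It is not implemented yet: as the source listing below shows, calling it raises NotImplementedError.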

Args:

  • model: Trained VAE model
  • source_samples: Source patient samples (early stage)
  • target_samples: Target patient samples (late stage)
  • n_steps: Number of interpolation steps
  • method: Interpolation method ("linear" or "spherical")
  • output_dir: Optional directory to save trajectories
  • parallel: Whether to use parallel processing
  • n_workers: Number of parallel workers (None = use all CPUs)

Returns: Dictionary mapping patient pairs to trajectory DataFrames

Source code in renalprog/modeling/predict.py
def generate_trajectories(
    model: torch.nn.Module,
    source_samples: pd.DataFrame,
    target_samples: pd.DataFrame,
    n_steps: int = 50,
    method: str = "linear",
    output_dir: Optional[Path] = None,
    parallel: bool = False,
    n_workers: Optional[int] = None
) -> Dict[str, pd.DataFrame]:
    """
    Generate synthetic trajectories between source and target patient samples.

    This function creates interpolated gene expression profiles in the latent
    space between pairs of patients at different cancer stages.

    Args:
        model: Trained VAE model
        source_samples: Source patient samples (early stage)
        target_samples: Target patient samples (late stage)
        n_steps: Number of interpolation steps
        method: Interpolation method ("linear" or "spherical")
        output_dir: Optional directory to save trajectories
        parallel: Whether to use parallel processing
        n_workers: Number of parallel workers (None = use all CPUs)

    Returns:
        Dictionary mapping patient pairs to trajectory DataFrames
    """
    logger.info(f"Generating trajectories with {n_steps} steps using {method} interpolation")

    # TODO: Implement trajectory generation
    # Migrate from src_deseq_and_gsea_NCSR/synthetic_data_generation.py

    raise NotImplementedError(
        "generate_trajectories() needs implementation from "
        "src_deseq_and_gsea_NCSR/synthetic_data_generation.py and "
        "src/data/fun_interpol.py"
    )

Network Construction

build_trajectory_network

Build directed graph of patient transitions.

build_trajectory_network

build_trajectory_network(
    patient_links: DataFrame,
) -> Tuple[Dict[str, List[str]], List[List[str]]]

Build trajectory network and find all complete disease progression paths.

Constructs a directed graph from patient links and identifies all possible complete trajectories from root nodes (earliest stage patients not appearing as targets) to leaf nodes (latest stage patients not appearing as sources).

Args:

  • patient_links: DataFrame with 'source' and 'target' columns from linking functions

Returns:

Tuple of:

  • network: Dict mapping each source patient to list of target patients
  • trajectories: List of complete trajectories, where each trajectory is a list of patient IDs ordered from earliest to latest stage

Network Structure:

  • Adjacency list representation: {source: [target1, target2, ...]}
  • Directed edges from earlier to later stages
  • Allows multiple outgoing edges (one patient → multiple next-stage patients)

Trajectory Discovery:

  • Uses depth-first search from root nodes
  • Root nodes: Patients in 'source' but not in 'target' (stage I or early)
  • Leaf nodes: Patients in 'target' but not in 'source' (stage IV or late)
  • Each trajectory represents a complete disease progression path

Example:

    >>> network, trajectories = build_trajectory_network(patient_links)
    >>> print(f"Network has {len(network)} nodes")
    >>> print(f"Found {len(trajectories)} complete trajectories")
    >>> print(f"Example trajectory: {trajectories[0]}")
    Network has 500 nodes
    Found 234 complete trajectories
    Example trajectory: ['PAT001', 'PAT045', 'PAT123', 'PAT289']

Trajectory Characteristics:

  • Length varies based on how many stages the path spans
  • Typical lengths: 2-4 patients for I→II→III→IV progressions
  • Length 2 for early→late progressions
  • Patients can appear in multiple trajectories

Note:

  • Cycles are prevented during trajectory search
  • All paths from root to leaf are enumerated
  • Trajectories respect chronological disease progression
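The root/leaf rule and the depth-first enumeration described above can be tried on a toy graph without installing the package. The following is a minimal, self-contained sketch of the same idea, not the library code itself; the patient IDs and the helper name find_paths are made up for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical toy links: (source, target) pairs, earlier stage -> later stage.
links: List[Tuple[str, str]] = [
    ("P1", "P2"),
    ("P1", "P3"),
    ("P2", "P4"),
    ("P3", "P4"),
]

# Adjacency list: {source: [target1, target2, ...]}
network: Dict[str, List[str]] = {}
for source, target in links:
    network.setdefault(source, []).append(target)

sources = {s for s, _ in links}
targets = {t for _, t in links}
roots = sorted(sources - targets)   # appear as 'source' but never as 'target'
leaves = sorted(targets - sources)  # appear as 'target' but never as 'source'


def find_paths(node: str, path: Tuple[str, ...] = ()) -> List[List[str]]:
    """Depth-first enumeration of every path from `node` to a leaf."""
    path = path + (node,)
    if node not in network:         # no outgoing edges: a leaf was reached
        return [list(path)]
    paths: List[List[str]] = []
    for nxt in network[node]:
        if nxt not in path:         # skip nodes already on this path (no cycles)
            paths.extend(find_paths(nxt, path))
    return paths


trajectories = [p for root in roots for p in find_paths(root)]
print(roots, leaves)
print(trajectories)
```

Here P4 is reached through both P2 and P3, so it appears in two trajectories, which matches the statement that patients can appear in multiple trajectories.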

Source code in renalprog/modeling/predict.py
def build_trajectory_network(
    patient_links: pd.DataFrame
) -> Tuple[Dict[str, List[str]], List[List[str]]]:
    """
    Build trajectory network and find all complete disease progression paths.

    Constructs a directed graph from patient links and identifies all possible
    complete trajectories from root nodes (earliest stage patients not appearing
    as targets) to leaf nodes (latest stage patients not appearing as sources).

    Args:
        patient_links: DataFrame with 'source' and 'target' columns from linking functions

    Returns:
        Tuple of:
        - network: Dict mapping each source patient to list of target patients
        - trajectories: List of complete trajectories, where each trajectory is a
                        list of patient IDs ordered from earliest to latest stage

    Network Structure:
        - Adjacency list representation: {source: [target1, target2, ...]}
        - Directed edges from earlier to later stages
        - Allows multiple outgoing edges (one patient → multiple next-stage patients)

    Trajectory Discovery:
        - Uses depth-first search from root nodes
        - Root nodes: Patients in 'source' but not in 'target' (stage I or early)
        - Leaf nodes: Patients in 'target' but not in 'source' (stage IV or late)
        - Each trajectory represents a complete disease progression path

    Example:
        >>> network, trajectories = build_trajectory_network(patient_links)
        >>> print(f"Network has {len(network)} nodes")
        >>> print(f"Found {len(trajectories)} complete trajectories")
        >>> print(f"Example trajectory: {trajectories[0]}")
        Network has 500 nodes
        Found 234 complete trajectories
        Example trajectory: ['PAT001', 'PAT045', 'PAT123', 'PAT289']

    Trajectory Characteristics:
        - Length varies based on how many stages the path spans
        - Typical lengths: 2-4 patients for I→II→III→IV progressions
        - Length 2 for early→late progressions
        - Patients can appear in multiple trajectories

    Note:
        - Cycles are prevented during trajectory search
        - All paths from root to leaf are enumerated
        - Trajectories respect chronological disease progression
    """
    logger.info("Building trajectory network from patient links")

    sources = patient_links['source']
    targets = patient_links['target']

    # Build network adjacency list
    network = {}
    for source, target in zip(sources, targets):
        if source not in network:
            network[source] = []
        network[source].append(target)

    logger.info(f"Network built: {len(network)} source nodes")

    # Find root nodes (patients who are sources but never targets)
    unique_sources = set(sources) - set(targets)
    logger.info(f"Found {len(unique_sources)} root nodes (earliest stage patients)")

    # Recursively find all trajectories from each root
    def find_trajectories(start_node: str, visited: Optional[List[str]] = None) -> List[List[str]]:
        """Depth-first search to find all paths from start_node to leaf nodes."""
        if visited is None:
            visited = []

        visited.append(start_node)

        # If node has no outgoing edges, this is a leaf node - return path
        if start_node not in network:
            return [visited]

        # Recursively explore all targets
        trajectories = []
        for target in network[start_node]:
            if target not in visited:  # Avoid cycles
                new_visited = visited.copy()
                trajectories.extend(find_trajectories(target, new_visited))

        return trajectories

    # Find all trajectories starting from each root
    all_trajectories = []

    if len(unique_sources) == 0:
        # No clear root nodes - this happens with early→late transitions where
        # patients can be both sources and targets. In this case, each source→target
        # pair is already a complete 2-patient trajectory.
        logger.info("No root nodes found (typical for early→late transitions).")
        logger.info("Using each source→target pair as a complete trajectory.")
        for source, target in zip(sources, targets):
            all_trajectories.append([source, target])
    else:
        # Standard case: multi-stage progressions (I→II→III→IV)
        for source in unique_sources:
            all_trajectories.extend(find_trajectories(source))

    logger.info(f"Discovered {len(all_trajectories)} complete disease progression trajectories")

    # Log trajectory length statistics only if we have trajectories
    if len(all_trajectories) > 0:
        traj_lengths = [len(t) for t in all_trajectories]
        logger.info(f"Trajectory lengths - Min: {min(traj_lengths)}, Max: {max(traj_lengths)}, "
                    f"Mean: {np.mean(traj_lengths):.1f}")
    else:
        logger.warning("No trajectories found!")

    return network, all_trajectories

Example Usage:

from renalprog.modeling.predict import build_trajectory_network
import pandas as pd

# Load patient links; the CSV must provide 'source' and 'target' columns
patient_links = pd.read_csv("data/processed/patient_connections.csv")

# Build the network and enumerate all complete trajectories
network, trajectories = build_trajectory_network(patient_links)

print(f"Network has {len(network)} source nodes")
print(f"Found {len(trajectories)} complete trajectories")

generate_trajectory_data

Generate complete trajectory dataset with metadata.

generate_trajectory_data

generate_trajectory_data(
    vae_model: Module,
    recnet_model: Optional[Module],
    trajectory: List[str],
    gene_data: DataFrame,
    n_timepoints: int = 50,
    interpolation_method: str = "linear",
    device: str = "cpu",
    save_path: Optional[Path] = None,
    scaler: Optional[MinMaxScaler] = None,
) -> pd.DataFrame

Generate synthetic gene expression data along a patient trajectory.

Creates N interpolated time points between consecutive patients in a trajectory by performing interpolation in the VAE latent space, then decoding back to gene expression space. Optionally applies reconstruction network for refinement.

Args:

  • vae_model: Trained VAE model for encoding/decoding
  • recnet_model: Optional reconstruction network for refining VAE output
  • trajectory: List of patient IDs in chronological progression order
  • gene_data: Gene expression DataFrame (genes × patients)
  • n_timepoints: Number of interpolation points between each patient pair
  • interpolation_method: 'linear' or 'spherical' interpolation in latent space
  • device: Torch device for computation
  • save_path: Optional path to save trajectory CSV file
  • scaler: Pre-fitted MinMaxScaler from VAE training. If None, will fit on gene_data.

Returns:

DataFrame with synthetic gene expression profiles for all time points.

  • Shape: (n_timepoints * (len(trajectory)-1), n_genes)
  • Index contains time point identifiers
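To make the returned shape concrete, the snippet below rebuilds only the row labels, using the same label pattern as the source listing further down. The patient IDs and the small n_timepoints value are illustrative assumptions.

```python
trajectory = ["PAT001", "PAT045", "PAT123"]  # hypothetical patient IDs
n_timepoints = 3                             # kept small for readability

# One label per interpolated point, segment by segment.
labels = [
    f"{trajectory[i]}_to_{trajectory[i + 1]}_t{t:03d}"
    for i in range(len(trajectory) - 1)
    for t in range(n_timepoints)
]

expected_rows = n_timepoints * (len(trajectory) - 1)
print(expected_rows)          # number of rows in the returned DataFrame
print(labels[0], labels[-1])  # first and last time point identifiers
```

A trajectory of three patients has two segments, so three points per segment give six rows.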

Workflow:

  1. Extract gene expression for each patient in trajectory
  2. Normalize using the SAME scaler used during VAE training
  3. Encode each patient to VAE latent space
  4. For each consecutive pair:
     a. Interpolate in latent space (linear or spherical)
     b. Decode interpolated points back to gene space
     c. Optionally apply reconstruction network
  5. Concatenate all segments into complete trajectory

Interpolation Methods:

  • linear: Straight-line interpolation in latent space, z(t) = (1-t)*z_source + t*z_target
  • spherical: Spherical linear interpolation (SLERP). Preserves magnitude, interpolates on hypersphere. Recommended for normalized latent spaces.
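The two interpolation schemes can be compared on a pair of toy latent vectors. The helpers below, lerp and slerp, are hypothetical stand-ins written for this illustration; the project's own interpolate_latent_linear and interpolate_latent_spherical are not shown on this page and may differ, for example in whether the endpoints are included.

```python
import numpy as np


def lerp(z0: np.ndarray, z1: np.ndarray, n_steps: int) -> np.ndarray:
    """Straight-line interpolation: z(t) = (1 - t) * z0 + t * z1."""
    t = np.linspace(0.0, 1.0, n_steps)[:, None]  # assumption: endpoints included
    return (1.0 - t) * z0 + t * z1


def slerp(z0: np.ndarray, z1: np.ndarray, n_steps: int, eps: float = 1e-8) -> np.ndarray:
    """Spherical linear interpolation along the arc between z0 and z1."""
    u0 = z0 / np.linalg.norm(z0)
    u1 = z1 / np.linalg.norm(z1)
    omega = np.arccos(np.clip(np.dot(u0, u1), -1.0, 1.0))  # angle between the vectors
    if np.sin(omega) < eps:  # (anti)parallel vectors: fall back to the straight line
        return lerp(z0, z1, n_steps)
    t = np.linspace(0.0, 1.0, n_steps)[:, None]
    return (np.sin((1.0 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)


z0 = np.array([1.0, 0.0])
z1 = np.array([0.0, 1.0])
print(np.linalg.norm(lerp(z0, z1, 3), axis=1))   # midpoint norm drops below 1
print(np.linalg.norm(slerp(z0, z1, 3), axis=1))  # norm stays at 1 along the arc
```

For unit-length endpoints the linear midpoint moves toward the origin, while the spherical path keeps a constant norm, which is why SLERP is suggested for normalized latent spaces.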

Note: CRITICAL: The scaler must be the same one used during VAE training. Using a different scaler will produce incorrect latent representations. If scaler=None, will fit on all gene_data (all patients), which approximates the training distribution.
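The effect of a mismatched scaler is easy to see with a hand-written min-max transform. This sketch uses plain NumPy to mimic a (0, 1) min-max scaling, (x - min) / (max - min) per feature; the helper names and the toy matrix are assumptions made for illustration only.

```python
import numpy as np


def fit_min_max(x: np.ndarray):
    """Per-feature (column) minimum and range, as a (0, 1) min-max scaler would learn."""
    lo = x.min(axis=0)
    span = x.max(axis=0) - lo
    span[span == 0] = 1.0  # avoid division by zero for constant features
    return lo, span


def transform(x: np.ndarray, lo: np.ndarray, span: np.ndarray) -> np.ndarray:
    return (x - lo) / span


# Hypothetical expression matrix, already oriented as (patients x genes).
training = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
subset = training[:2]  # a different cohort than the one used for training

patient = np.array([[5.0, 20.0]])
same = transform(patient, *fit_min_max(training))     # scaler fitted on training data
different = transform(patient, *fit_min_max(subset))  # scaler fitted on other data
print(same, different)
```

The same patient maps to different normalized values depending on which data the scaler was fitted on, so the encoder would receive inputs it was not trained on. Note also that the scaler works on (patients × genes), whereas gene_data is stored as (genes × patients).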

Source code in renalprog/modeling/predict.py
def generate_trajectory_data(
    vae_model: torch.nn.Module,
    recnet_model: Optional[torch.nn.Module],
    trajectory: List[str],
    gene_data: pd.DataFrame,
    n_timepoints: int = 50,
    interpolation_method: str = 'linear',
    device: str = 'cpu',
    save_path: Optional[Path] = None,
    scaler: Optional[MinMaxScaler] = None
) -> pd.DataFrame:
    """
    Generate synthetic gene expression data along a patient trajectory.

    Creates N interpolated time points between consecutive patients in a trajectory
    by performing interpolation in the VAE latent space, then decoding back to
    gene expression space. Optionally applies reconstruction network for refinement.

    Args:
        vae_model: Trained VAE model for encoding/decoding
        recnet_model: Optional reconstruction network for refining VAE output
        trajectory: List of patient IDs in chronological progression order
        gene_data: Gene expression DataFrame (genes × patients)
        n_timepoints: Number of interpolation points between each patient pair
        interpolation_method: 'linear' or 'spherical' interpolation in latent space
        device: Torch device for computation
        save_path: Optional path to save trajectory CSV file
        scaler: Pre-fitted MinMaxScaler from VAE training. If None, will fit on gene_data.

    Returns:
        DataFrame with synthetic gene expression profiles for all time points.
        Shape: (n_timepoints * (len(trajectory)-1), n_genes)
        Index contains time point identifiers

    Workflow:
        1. Extract gene expression for each patient in trajectory
        2. Normalize using the SAME scaler used during VAE training
        3. Encode each patient to VAE latent space
        4. For each consecutive pair:
           a. Interpolate in latent space (linear or spherical)
           b. Decode interpolated points back to gene space
           c. Optionally apply reconstruction network
        5. Concatenate all segments into complete trajectory

    Interpolation Methods:
        linear: Straight-line interpolation in latent space
                z(t) = (1-t)*z_source + t*z_target

        spherical: Spherical linear interpolation (SLERP)
                   Preserves magnitude, interpolates on hypersphere
                   Recommended for normalized latent spaces

    Note:
        CRITICAL: The scaler must be the same one used during VAE training.
        Using a different scaler will produce incorrect latent representations.
        If scaler=None, will fit on all gene_data (all patients), which approximates
        the training distribution.
    """

    logger.info(f"Generating trajectory data for {len(trajectory)} patients")
    logger.info(f"Interpolation: {n_timepoints} points × {len(trajectory)-1} segments")
    logger.info(f"Method: {interpolation_method}")

    # Set models to evaluation mode
    vae_model.eval()
    if recnet_model is not None:
        recnet_model.eval()

    vae_model = vae_model.to(device)
    if recnet_model is not None:
        recnet_model = recnet_model.to(device)

    # Use provided scaler or fit new one on all gene data
    if scaler is None:
        logger.warning("No scaler provided - fitting new scaler on all gene data")
        logger.warning("This may not match VAE training normalization!")
        scaler = MinMaxScaler()
        # gene_data is (genes × patients), need (patients × genes) for scaler
        scaler.fit(gene_data.T.values)
        logger.info(f"Fitted scaler on {gene_data.shape[1]} patients")
    else:
        logger.info("Using provided scaler from VAE training")

    # Select interpolation function
    if interpolation_method == 'linear':
        interp_func = interpolate_latent_linear
    elif interpolation_method == 'spherical':
        interp_func = interpolate_latent_spherical
    else:
        raise ValueError(f"Unknown interpolation method: {interpolation_method}")

    # Generate synthetic data for each segment of the trajectory
    all_segments = []

    with torch.no_grad():
        for i in range(len(trajectory) - 1):
            source_patient = trajectory[i]
            target_patient = trajectory[i + 1]

            logger.info(f"Segment {i+1}/{len(trajectory)-1}: {source_patient} → {target_patient}")

            # Get gene expression for source and target
            # gene_data is (genes × patients), so gene_data[patient] is a Series of gene values
            source_expr = gene_data[source_patient].values.reshape(1, -1)  # (1, genes)
            target_expr = gene_data[target_patient].values.reshape(1, -1)  # (1, genes)

            # Normalize data using the provided scaler
            # Scaler expects (n_samples, n_features) = (1, genes)
            source_norm = scaler.transform(source_expr)  # (1, genes)
            target_norm = scaler.transform(target_expr)  # (1, genes)

            # Encode to latent space
            source_tensor = torch.tensor(source_norm, dtype=torch.float32).to(device)
            target_tensor = torch.tensor(target_norm, dtype=torch.float32).to(device)

            _, _, _, z_source = vae_model(source_tensor)
            _, _, _, z_target = vae_model(target_tensor)

            # Interpolate in latent space
            z_source_np = z_source.cpu().numpy().flatten()
            z_target_np = z_target.cpu().numpy().flatten()

            interpolated_z = interp_func(z_source_np, z_target_np, n_timepoints)

            # Decode interpolated latent vectors
            interpolated_z_tensor = torch.tensor(interpolated_z, dtype=torch.float32).to(device)
            decoded = vae_model.decoder(interpolated_z_tensor)

            # Denormalize using the same scaler
            # decoded is (n_timepoints, genes), scaler expects (n_samples, n_features)
            decoded_np = decoded.cpu().numpy()  # (n_timepoints, genes)
            segment_data = scaler.inverse_transform(decoded_np)  # (n_timepoints, genes) - REAL SPACE

            # Apply reconstruction network if provided
            # CRITICAL: RecNet works on REAL SPACE data, not normalized!
            if recnet_model is not None:
                # Convert to tensor and apply RecNet directly to real space data
                segment_tensor = torch.tensor(segment_data, dtype=torch.float32).to(device)
                refined = recnet_model(segment_tensor)
                segment_data = refined.cpu().numpy()  # (n_timepoints, genes) - REAL SPACE

            all_segments.append(segment_data)

    # Concatenate all segments
    trajectory_data = np.vstack(all_segments)

    # Create DataFrame
    trajectory_df = pd.DataFrame(
        trajectory_data,
        columns=gene_data.index
    )

    # Create informative index
    time_indices = []
    for i in range(len(trajectory) - 1):
        for t in range(n_timepoints):
            time_indices.append(f"{trajectory[i]}_to_{trajectory[i+1]}_t{t:03d}")
    trajectory_df.index = time_indices

    logger.info(f"Generated trajectory data: {trajectory_df.shape}")

    # Save if path provided
    if save_path is not None:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        trajectory_df.to_csv(save_path)
        logger.info(f"Saved trajectory to: {save_path}")

    return trajectory_df

Patient Connectivity

create_patient_connections

Create optimal patient pairings for trajectories.

create_patient_connections

create_patient_connections(
    data: DataFrame,
    clinical: Series,
    method: str = "random",
    transition_type: str = "early_to_late",
    n_connections: Optional[int] = None,
    seed: int = 2023,
) -> pd.DataFrame

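Create connections between patients for trajectory generation. This function is not implemented yet: as the source listing below shows, calling it raises NotImplementedError.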

Args:

  • data: Gene expression data
  • clinical: Clinical stage information
  • method: Method for creating connections ("random", "nearest", "all")
  • transition_type: Type of transition ("early_to_late", "early_to_early", etc.)
  • n_connections: Number of connections to create (None = all possible)
  • seed: Random seed

Returns: DataFrame with columns: source, target, transition

Source code in renalprog/modeling/predict.py
def create_patient_connections(
    data: pd.DataFrame,
    clinical: pd.Series,
    method: str = "random",
    transition_type: str = "early_to_late",
    n_connections: Optional[int] = None,
    seed: int = 2023
) -> pd.DataFrame:
    """
    Create connections between patients for trajectory generation.

    Args:
        data: Gene expression data
        clinical: Clinical stage information
        method: Method for creating connections ("random", "nearest", "all")
        transition_type: Type of transition ("early_to_late", "early_to_early", etc.)
        n_connections: Number of connections to create (None = all possible)
        seed: Random seed

    Returns:
        DataFrame with columns: source, target, transition
    """
    logger.info(f"Creating patient connections: {transition_type} using {method} method")

    # TODO: Implement connection logic
    # Migrate from notebooks/4_1_trajectories.ipynb

    raise NotImplementedError(
        "create_patient_connections() needs implementation from "
        "notebooks/4_1_trajectories.ipynb"
    )

link_patients_closest

Link patients using closest latent space neighbors.

link_patients_closest(
    transitions_df: DataFrame,
    start_with_first_stage: bool = True,
    early_late: bool = False,
    closest: bool = True,
) -> pd.DataFrame

Link patients by selecting closest (or farthest) matches across stages.

For each patient at a source stage, this function identifies the closest (or farthest) patient at the target stage, considering metadata constraints (gender, race). This yields one link per source patient (or per target patient when building backward), and these links form the basis for trajectory construction.

Args:

  • transitions_df: DataFrame from calculate_all_possible_transitions() containing all possible patient pairs with distances
  • start_with_first_stage: If True, build forward trajectories (early→late). If False, build backward trajectories (late→early)
  • early_late: If True, uses early/late groupings. If False, uses I-IV stages
  • closest: If True, connect closest patients. If False, connect farthest patients

Returns:

DataFrame with selected patient links, containing one row per source patient with their optimal target patient match. Includes all columns from transitions_df.

Selection Strategy:

  • Forward (start_with_first_stage=True): For each source, find optimal target
  • Backward (start_with_first_stage=False): For each target, find optimal source
  • Closest (closest=True): Minimum distance match
  • Farthest (closest=False): Maximum distance match

Metadata Stratification:

Links are selected independently within each combination of:

  • Gender (MALE, FEMALE)
  • Race (ASIAN, BLACK OR AFRICAN AMERICAN, WHITE)

This ensures demographic consistency in trajectories.

Example:

    >>> links = link_patients_closest(
    ...     transitions_df=all_transitions,
    ...     start_with_first_stage=True,
    ...     closest=True
    ... )
    >>> print(f"Created {len(links)} patient links")
    Created 234 patient links

Note:

  • Processes transitions in order for forward: I→II→III→IV
  • Processes in reverse for backward: IV→III→II→I
  • Each patient appears at most once as a source in the result
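The selection strategy can be sketched in plain Python on a handful of candidate pairs. This is a simplified illustration only: it leaves out the gender and race stratification and the per-transition loop, and the function name pick_links and the patient IDs are made up.

```python
from typing import Dict, List, Tuple

Pair = Tuple[str, str, float]  # (source, target, distance)

# Hypothetical candidate pairs for a single transition.
pairs: List[Pair] = [
    ("A", "X", 2.0),
    ("A", "Y", 0.5),
    ("B", "X", 1.0),
    ("B", "Y", 3.0),
]


def pick_links(pairs: List[Pair], closest: bool = True, forward: bool = True) -> Dict[str, Pair]:
    """Keep one pair per anchor patient: smallest or largest distance."""
    key_pos = 0 if forward else 1  # anchor on source (forward) or target (backward)
    grouped: Dict[str, List[Pair]] = {}
    for pair in pairs:
        grouped.setdefault(pair[key_pos], []).append(pair)
    idx = 0 if closest else -1     # first after sorting = closest, last = farthest
    return {
        anchor: sorted(candidates, key=lambda p: p[2])[idx]
        for anchor, candidates in grouped.items()
    }


print(pick_links(pairs))                 # closest target for each source
print(pick_links(pairs, closest=False))  # farthest target for each source
print(pick_links(pairs, forward=False))  # closest source for each target
```

Sorting by distance and taking the first or last element mirrors the sort-then-index approach used in the source listing below.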

Source code in renalprog/modeling/predict.py
def link_patients_closest(
    transitions_df: pd.DataFrame,
    start_with_first_stage: bool = True,
    early_late: bool = False,
    closest: bool = True
) -> pd.DataFrame:
    """
    Link patients by selecting closest (or farthest) matches across stages.

    For each patient at a source stage, this function identifies the closest
    (or farthest) patient at the target stage, considering metadata constraints
    (gender, race). This creates one-to-one patient linkages that form the basis
    for trajectory construction.

    Args:
        transitions_df: DataFrame from calculate_all_possible_transitions()
                        containing all possible patient pairs with distances
        start_with_first_stage: If True, build forward trajectories (early→late)
                                If False, build backward trajectories (late→early)
        early_late: If True, uses early/late groupings. If False, uses I-IV stages
        closest: If True, connect closest patients. If False, connect farthest patients

    Returns:
        DataFrame with selected patient links, containing one row per source patient
        with their optimal target patient match. Includes all columns from transitions_df.

    Selection Strategy:
        - Forward (start_with_first_stage=True): For each source, find optimal target
        - Backward (start_with_first_stage=False): For each target, find optimal source
        - Closest (closest=True): Minimum distance match
        - Farthest (closest=False): Maximum distance match

    Metadata Stratification:
        Links are selected independently within each combination of:
        - Gender (MALE, FEMALE)
        - Race (ASIAN, BLACK OR AFRICAN AMERICAN, WHITE)
        This ensures demographic consistency in trajectories.

    Example:
        >>> links = link_patients_closest(
        ...     transitions_df=all_transitions,
        ...     start_with_first_stage=True,
        ...     closest=True
        ... )
        >>> print(f"Created {len(links)} patient links")
        Created 234 patient links

    Note:
        - Processes transitions in order for forward: I→II→III→IV
        - Processes in reverse for backward: IV→III→II→I
        - Each patient appears at most once as a source in the result
    """
    logger.info("Linking patients by closest/farthest matches")
    logger.info(f"Direction: {'Forward' if start_with_first_stage else 'Backward'}")
    logger.info(f"Strategy: {'Closest' if closest else 'Farthest'}")

    # Define transition order based on direction
    if start_with_first_stage and not early_late:
        transitions_possible = ['1_to_2', '2_to_3', '3_to_4']
    elif not start_with_first_stage and not early_late:
        transitions_possible = ['3_to_4', '2_to_3', '1_to_2']
    elif early_late:
        transitions_possible = ['early_to_late']

    # 0 for closest (smallest distance), -1 for farthest (largest distance)
    idx = 0 if closest else -1

    # Find closest/farthest patient for each source patient
    closest_list = []
    for transition_i in transitions_possible:
        transition_df_i = transitions_df[transitions_df['transition'] == transition_i]

        logger.info(f"Processing transition {transition_i}: {len(transition_df_i)} pairs")

        # Iterate through all metadata combinations
        for gender_i in ['FEMALE', 'MALE']:
            df_gender_i = transition_df_i.query(f"source_gender == '{gender_i}'")

            for race_i in ['ASIAN', 'BLACK OR AFRICAN AMERICAN', 'WHITE']:
                df_race_i = df_gender_i.query(f"source_race == '{race_i}'")

                if df_race_i.empty:
                    continue

                # Get unique patients to link
                unique_sources = df_race_i['source'].unique()
                unique_targets = df_race_i['target'].unique()
                use_uniques = unique_sources if start_with_first_stage else unique_targets
                use_column = 'source' if start_with_first_stage else 'target'

                # Find closest/farthest match for each patient
                for pat_i in use_uniques:
                    pat_matches = df_race_i[df_race_i[use_column] == pat_i]
                    if len(pat_matches) > 0:
                        # Sort by distance and select first (closest) or last (farthest)
                        best_match = pat_matches.sort_values('distance').iloc[idx]
                        closest_list.append(best_match)

    # Convert to DataFrame
    closest_df = pd.DataFrame(closest_list)
    closest_df.reset_index(drop=True, inplace=True)

    logger.info(f"Created {len(closest_df)} patient links")

    return closest_df

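Example:

from renalprog.modeling.predict import link_patients_closest
import pandas as pd

# Hypothetical candidate pairs; in practice this table comes from
# calculate_all_possible_transitions()
all_transitions = pd.DataFrame({
    "source": ["E001", "E001", "E002"],
    "target": ["L001", "L002", "L001"],
    "distance": [1.2, 0.4, 0.9],
    "transition": ["early_to_late"] * 3,
    "source_gender": ["FEMALE"] * 3,
    "source_race": ["WHITE"] * 3,
})

links = link_patients_closest(
    transitions_df=all_transitions,
    start_with_first_stage=True,
    early_late=True,
    closest=True,
)

# Returns one row per source patient, with all columns of transitions_df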

link_patients_random

Link patients randomly (control method).

link_patients_random(
    results_df: DataFrame,
    start_with_first_stage: bool = True,
    link_next: int = 5,
    transitions_possible: Optional[List[str]] = None,
) -> pd.DataFrame

Link patients to multiple random targets at the next stage.

Instead of linking each patient to only their closest match, this function randomly samples multiple patients at the next stage to link to each source patient. This creates a one-to-many mapping useful for generating multiple trajectory samples.

Parameters:

  • results_df (DataFrame, required): DataFrame with possible sources and targets, their metadata, and distance.
  • start_with_first_stage (bool, default True): If True, initiate trajectories with first stage as sources. If False, initiate trajectories with last stage as sources.
  • link_next (int, default 5): Number of patients at next stage to randomly link to each patient of current stage.
  • transitions_possible (list, default None): List of transitions to process (e.g., ['1_to_2', '2_to_3']). If None, defaults to ['early_to_late'].

Returns:

  • DataFrame: DataFrame with randomly sampled patient links for each transition. Contains multiple rows per source patient (up to link_next).

Notes
  • Random sampling is primarily performed for WHITE race patients due to sample size
  • If fewer than link_next targets are available, all available targets are selected
  • Patients from other races are included with all their possible connections
  • Empty DataFrame is returned if no WHITE patients are found
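
The "up to link_next" rule from the notes above can be sketched with the standard library. This is an illustration under stated assumptions, not the function's actual sampling code, which is not fully shown on this page; sample_targets, its seed argument, and the patient IDs are hypothetical.

```python
import random
from typing import Dict, List


def sample_targets(
    candidates: Dict[str, List[str]],
    link_next: int = 5,
    seed: int = 2023,  # hypothetical: the documented function takes no seed
) -> Dict[str, List[str]]:
    """Pick up to `link_next` distinct targets at random for each source."""
    rng = random.Random(seed)
    links: Dict[str, List[str]] = {}
    for source, targets in candidates.items():
        k = min(link_next, len(targets))  # take all targets when fewer are available
        links[source] = rng.sample(targets, k)
    return links


# Hypothetical candidate targets per source patient.
candidates = {
    "S1": ["T1", "T2", "T3", "T4", "T5", "T6", "T7"],
    "S2": ["T1", "T2"],
}
links = sample_targets(candidates, link_next=5)
print({source: len(targets) for source, targets in links.items()})
```

S1 has more candidates than link_next and receives five of them, while S2 has only two candidates and keeps both.
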
Source code in renalprog/modeling/predict.py
def link_patients_random(
    results_df: pd.DataFrame,
    start_with_first_stage: bool = True,
    link_next: int = 5,
    transitions_possible: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Link patients to multiple random targets at the next stage.

    Instead of linking each patient to only their closest match, this function randomly
    samples multiple patients at the next stage to link to each source patient. This
    creates a one-to-many mapping useful for generating multiple trajectory samples.

    Parameters
    ----------
    results_df : pd.DataFrame
        DataFrame with possible sources and targets, their metadata, and distance.
    start_with_first_stage : bool, default=True
        If True, initiate trajectories with first stage as sources.
        If False, initiate trajectories with last stage as sources.
    link_next : int, default=5
        Number of patients at next stage to randomly link to each patient of current stage.
    transitions_possible : list, optional
        List of transitions to process (e.g., ['1_to_2', '2_to_3']).
        If None, defaults to ['early_to_late'].

    Returns
    -------
    pd.DataFrame
        DataFrame with randomly sampled patient links for each transition.
        Contains multiple rows per source patient (up to link_next).

    Notes
    -----
    - Random sampling is applied only to patients whose race is WHITE, the only group
      with enough samples to subsample
    - If fewer than link_next candidate links are available for a patient, all of them are kept
    - Patients of other races are included with all their possible connections, without sampling
    - An empty DataFrame is returned if no WHITE patients are found, even when other races are present
    """
    # Set default transitions if not provided
    if transitions_possible is None:
        transitions_possible = ['early_to_late']

    # Get unique genders
    unique_genders = results_df['source_gender'].unique().tolist()
    # Get unique races; WHITE is sampled below, the rest are appended unsampled
    unique_races = results_df['source_race'].unique().tolist()
    if 'WHITE' in unique_races:
        unique_races.remove('WHITE')
    # Sample links per transition and gender
    samples = []
    for transition_i in transitions_possible:
        transition_df_i = results_df[results_df['transition'] == transition_i]
        for gender_i in unique_genders:
            # Only WHITE patients are sampled: no other group has enough samples
            df_samples_i = transition_df_i.query(
                f"source_gender == '{gender_i}' & source_race == 'WHITE'")
            if df_samples_i.empty:
                print(f"Warning: No WHITE patients found for gender {gender_i} in transition {transition_i}")
                continue
            unique_sources_i = np.unique(df_samples_i['source']).tolist()
            unique_targets_i = np.unique(df_samples_i['target']).tolist()
            use_uniques = unique_sources_i if start_with_first_stage else unique_targets_i
            use_source_target = 'source' if start_with_first_stage else 'target'
            for pat_i in use_uniques:
                sample_i = df_samples_i.loc[df_samples_i[use_source_target] == pat_i]
                if len(sample_i) >= link_next:
                    sample_i = sample_i.sample(
                        link_next)  # Sample a number of patients at next stage to link to each patient of current stage
                else:
                    sample_i = sample_i.sample(len(sample_i))  # Sample all available patients if less than link_next
                samples.append(sample_i)

    # Check if samples list is empty
    if not samples:
        print("Warning: No samples found for WHITE race. Returning empty DataFrame.")
        return pd.DataFrame(columns=results_df.columns)

    # Turn samples into dataframe:
    samples_df = pd.concat(samples)
    # Add the rest of the races
    if unique_races:
        samples_df = pd.concat(
            [
                samples_df,
                results_df[results_df['source_race'].isin(unique_races)]
            ]
        )
    samples_df.reset_index(drop=True, inplace=True)
    return samples_df
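
The per-patient capping rule in the loop above can be tried on a toy table. This is a minimal sketch, not the library code: the frame `links` and the helper `sample_links` are hypothetical, and only the `source` and `target` column names are taken from the function.

```python
import pandas as pd

# Hypothetical candidate links: patient A has three targets, patient B has one.
links = pd.DataFrame({
    "source": ["A", "A", "A", "B"],
    "target": ["T1", "T2", "T3", "T4"],
})


def sample_links(df: pd.DataFrame, link_next: int, seed: int = 0) -> pd.DataFrame:
    """Keep at most `link_next` randomly chosen rows per source patient."""
    parts = [
        # Cap the sample size at the group size, as the if/else in the loop does.
        group.sample(n=min(len(group), link_next), random_state=seed)
        for _, group in df.groupby("source")
    ]
    return pd.concat(parts).reset_index(drop=True)


sampled = sample_links(links, link_next=2)
print(sampled.groupby("source").size().to_dict())  # {'A': 2, 'B': 1}
```

Note that link_patients_random itself calls DataFrame.sample without a random_state, so its output changes between runs. Seeding NumPy's global generator before the call is one approach that should make runs repeatable, though this is an assumption about pandas' default behavior rather than something the function guarantees.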

Transition Analysis

calculate_all_possible_transitions

Calculate metrics for all possible patient transitions.

calculate_all_possible_transitions

calculate_all_possible_transitions(
    data: DataFrame,
    metadata_selection: DataFrame,
    distance: str = "wasserstein",
    early_late: bool = False,
    negative_trajectory: bool = False,
) -> pd.DataFrame

Calculate all possible patient-to-patient transitions for KIRC cancer.

This function computes pairwise distances between all patients at consecutive (or same) cancer stages, considering metadata constraints. Only patients with matching gender and race are considered as potential trajectory pairs.

Parameters:

Name Type Description Default
data DataFrame

Gene expression data with patients as columns.

required
metadata_selection DataFrame

Clinical metadata with columns: histological_type, race, gender, stage.

required
distance str

Distance metric to use for calculating patient-to-patient distances: 'wasserstein' or 'euclidean'.

'wasserstein'
early_late bool

If True, uses early/late stage groupings. If False, uses I-IV stages.

False
negative_trajectory bool

If True, generates same-stage transitions (negative controls). If False, generates progression transitions (positive trajectories).

False

Returns:

Type Description
DataFrame

DataFrame containing all possible transitions with columns:
  • source, target: Patient IDs
  • source_gender, target_gender: Gender
  • source_race, target_race: Race
  • transition: Stage transition label (e.g., '1_to_2', 'early_to_late')
  • distance: Calculated distance between patients

Sorted by gender, race, transition, and distance.

Raises:

Type Description
ValueError

If distance metric is not 'wasserstein' or 'euclidean'.

Notes
  • For positive trajectories: links I→II, II→III, III→IV or early→late
  • For negative trajectories: links I→I, II→II, III→III, IV→IV or early→early, late→late
  • Only patients with identical gender and race are paired
Source code in renalprog/modeling/predict.py
def calculate_all_possible_transitions(
    data: pd.DataFrame,
    metadata_selection: pd.DataFrame,
    distance: str = 'wasserstein',
    early_late: bool = False,
    negative_trajectory: bool = False
) -> pd.DataFrame:
    """
    Calculate all possible patient-to-patient transitions for KIRC cancer.

    This function computes pairwise distances between all patients at consecutive
    (or same) cancer stages, considering metadata constraints. Only patients with
    matching gender and race are considered as potential trajectory pairs.

    Parameters
    ----------
    data : pd.DataFrame
        Gene expression data with patients as columns.
    metadata_selection : pd.DataFrame
        Clinical metadata with columns: histological_type, race, gender, stage.
    distance : {'wasserstein', 'euclidean'}, default='wasserstein'
        Distance metric to use for calculating patient-to-patient distances.
    early_late : bool, default=False
        If True, uses early/late stage groupings. If False, uses I-IV stages.
    negative_trajectory : bool, default=False
        If True, generates same-stage transitions (negative controls).
        If False, generates progression transitions (positive trajectories).

    Returns
    -------
    pd.DataFrame
        DataFrame containing all possible transitions with columns:
        - source, target: Patient IDs
        - source_gender, target_gender: Gender
        - source_race, target_race: Race
        - transition: Stage transition label (e.g., '1_to_2', 'early_to_late')
        - distance: Calculated distance between patients

        Sorted by gender, race, transition, and distance.

    Raises
    ------
    ValueError
        If distance metric is not 'wasserstein' or 'euclidean'.

    Notes
    -----
    - For positive trajectories: links I→II, II→III, III→IV or early→late
    - For negative trajectories: links I→I, II→II, III→III, IV→IV or early→early, late→late
    - Only patients with identical gender and race are paired
    """
    # Select distance function
    if distance == 'wasserstein':
        from scipy.stats import wasserstein_distance
        distance_fun = wasserstein_distance
    elif distance == 'euclidean':
        from scipy.spatial.distance import euclidean
        distance_fun = euclidean
    else:
        raise ValueError('Distance function not implemented. Use either "wasserstein" or "euclidean".')

    # Define stage transitions based on parameters
    if early_late and not negative_trajectory:
        possible_transitions = ['early_to_late']
        stage_pairs = [['early', 'late']]
    elif early_late and negative_trajectory:
        possible_transitions = ['early_to_early', 'late_to_late']
        stage_pairs = [['early', 'early'], ['late', 'late']]
    elif not early_late and not negative_trajectory:
        possible_transitions = ['1_to_2', '2_to_3', '3_to_4']
        stage_pairs = [['I', 'II'], ['II', 'III'], ['III', 'IV']]
    elif not early_late and negative_trajectory:
        possible_transitions = ['1_to_1', '2_to_2', '3_to_3', '4_to_4']
        stage_pairs = [['I', 'I'], ['II', 'II'], ['III', 'III'], ['IV', 'IV']]

    # Calculate all possible transitions
    results = []
    for i_tr, transition in enumerate(possible_transitions):
        source_target_stage = stage_pairs[i_tr]

        # Iterate through all patient pairs at specified stages
        for pat_i in metadata_selection.index[metadata_selection['stage'] == source_target_stage[0]]:
            for pat_ii in metadata_selection.index[metadata_selection['stage'] == source_target_stage[1]]:
                # Extract metadata for both patients
                source_gender = metadata_selection.at[pat_i, 'gender']
                target_gender = metadata_selection.at[pat_ii, 'gender']
                source_race = metadata_selection.at[pat_i, 'race']
                target_race = metadata_selection.at[pat_ii, 'race']

                # Skip if metadata doesn't match (gender and race must match)
                if not (source_race == target_race and source_gender == target_gender):
                    continue

                # Store transition information
                results_i = {
                    'source': pat_i,
                    'target': pat_ii,
                    'source_gender': source_gender,
                    'target_gender': target_gender,
                    'source_race': source_race,
                    'target_race': target_race,
                    'transition': transition,
                    'distance': distance_fun(data[pat_i], data[pat_ii]),
                }
                results.append(results_i)

    # Convert to DataFrame and sort
    results_df = pd.DataFrame(results)
    results_df.sort_values(
        ['source_gender', 'target_gender', 'source_race', 'target_race',
         'transition', 'distance'],
        inplace=True,
        ignore_index=True
    )
    return results_df
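
One consequence of the nested loops above is worth checking when negative_trajectory=True: both loops iterate over the same stage and there is no check that the two patients differ, so each patient appears to be paired with itself, and each distinct pair appears in both directions. The sketch below reproduces only the pairing rule in plain Python; the `metadata` dictionary and the helper `same_stage_pairs` are hypothetical.

```python
from itertools import product

# Hypothetical metadata: patient ID -> (stage, gender, race).
metadata = {
    "P1": ("I", "FEMALE", "WHITE"),
    "P2": ("I", "FEMALE", "WHITE"),
    "P3": ("I", "MALE", "WHITE"),
}


def same_stage_pairs(meta: dict, stage: str) -> list:
    """Enumerate (source, target) pairs the way the nested loops do."""
    patients = [p for p, (s, _, _) in meta.items() if s == stage]
    return [
        (src, tgt)
        for src, tgt in product(patients, repeat=2)
        # Gender and race must match, mirroring the `continue` guard.
        if meta[src][1:] == meta[tgt][1:]
    ]


pairs = same_stage_pairs(metadata, "I")
self_pairs = [(s, t) for s, t in pairs if s == t]
distinct_pairs = [(s, t) for s, t in pairs if s != t]
print(len(pairs), len(self_pairs), len(distinct_pairs))  # 5 3 2
```

If self-pairs are not wanted as negative controls, they can be dropped from the returned frame with a filter such as results_df[results_df['source'] != results_df['target']].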

Example Usage:

from renalprog.modeling.predict import calculate_all_possible_transitions

# data: gene expression DataFrame with patients as columns
# metadata_selection: clinical DataFrame indexed by patient ID, with
# columns histological_type, race, gender, stage
transitions = calculate_all_possible_transitions(
    data=data,
    metadata_selection=metadata_selection,
    distance="wasserstein",
    early_late=True,  # stage column holds 'early'/'late' instead of I-IV
    negative_trajectory=False
)

# Analyze transition patterns
print(transitions.describe())
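
The two metrics answer different questions. SciPy's wasserstein_distance treats each input as a sample from a one-dimensional distribution, so it ignores which gene carries which value, whereas the Euclidean distance compares profiles gene by gene. The sketch below illustrates this with hypothetical profiles. To stay NumPy-only it uses a hand-rolled equivalent that is valid for equal-length, equally weighted samples (the mean absolute difference of the sorted values); the helper names are illustrative, not part of renalprog.

```python
import numpy as np


def wasserstein_1d(u: np.ndarray, v: np.ndarray) -> float:
    """1-D Wasserstein distance for equal-length, equally weighted samples."""
    return float(np.mean(np.abs(np.sort(u) - np.sort(v))))


def euclidean(u: np.ndarray, v: np.ndarray) -> float:
    """Position-wise Euclidean distance."""
    return float(np.linalg.norm(u - v))


# Hypothetical expression profiles: patient_b is patient_a with genes shuffled.
patient_a = np.array([1.0, 2.0, 3.0, 4.0])
patient_b = np.array([4.0, 3.0, 2.0, 1.0])

print(wasserstein_1d(patient_a, patient_b))  # 0.0: identical value distributions
print(euclidean(patient_a, patient_b))       # about 4.47: genes differ position-wise
```

With 'wasserstein', two patients whose expression values are permutations of each other are at distance zero, so 'euclidean' may be the better fit when gene identity should drive the pairing.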

Dynamic Enrichment

dynamic_enrichment_analysis

Perform pathway enrichment at each trajectory timepoint. The function is not implemented yet: the current source raises NotImplementedError.

dynamic_enrichment_analysis

dynamic_enrichment_analysis(
    trajectory_dir: Path,
    pathways_file: Path,
    output_dir: Path,
    cancer_type: str = "kirc",
) -> pd.DataFrame

Perform dynamic enrichment analysis on synthetic trajectories.

This orchestrates:
  1. DESeq2 analysis on each trajectory point
  2. GSEA on differential expression results
  3. Aggregation of enrichment across trajectories

Args:
  • trajectory_dir: Directory containing trajectory CSV files
  • pathways_file: Path to pathway GMT file
  • output_dir: Directory to save results
  • cancer_type: Cancer type identifier

Returns: DataFrame with aggregated enrichment results

Source code in renalprog/modeling/predict.py
def dynamic_enrichment_analysis(
    trajectory_dir: Path,
    pathways_file: Path,
    output_dir: Path,
    cancer_type: str = "kirc"
) -> pd.DataFrame:
    """
    Perform dynamic enrichment analysis on synthetic trajectories.

    This orchestrates:
    1. DESeq2 analysis on each trajectory point
    2. GSEA on differential expression results
    3. Aggregation of enrichment across trajectories

    Args:
        trajectory_dir: Directory containing trajectory CSV files
        pathways_file: Path to pathway GMT file
        output_dir: Directory to save results
        cancer_type: Cancer type identifier

    Returns:
        DataFrame with aggregated enrichment results
    """
    logger.info(f"Running dynamic enrichment analysis for {cancer_type}")

    # TODO: Implement orchestration
    # Migrate from src_deseq_and_gsea_NCSR/full_bash.sh and related scripts

    raise NotImplementedError(
        "dynamic_enrichment_analysis() needs implementation. "
        "Migrate orchestration from src_deseq_and_gsea_NCSR/full_bash.sh, "
        "py_deseq.py, and trajectory_analysis.py"
    )

Example Usage:

from renalprog.modeling.predict import dynamic_enrichment_analysis
from pathlib import Path

# Analyze pathway dynamics along trajectories.
# Note: the function is not implemented yet and currently raises
# NotImplementedError; this shows the intended call.
enrichment_results = dynamic_enrichment_analysis(
    trajectory_dir=Path("data/processed/trajectories"),  # example path
    pathways_file=Path("data/external/ReactomePathways.gmt"),
    output_dir=Path("reports/dynamic_enrichment"),
    cancer_type="kirc"
)

# The result is a DataFrame of aggregated enrichment results
print(enrichment_results.head())

Interpolation Methods

interpolate_latent_linear

Linear interpolation between points.

interpolate_latent_linear

interpolate_latent_linear(
    z_source: ndarray, z_target: ndarray, n_steps: int = 50
) -> np.ndarray

Linear interpolation in latent space.

Args:
  • z_source: Source latent vector
  • z_target: Target latent vector
  • n_steps: Number of interpolation steps

Returns: Array of interpolated latent vectors (n_steps x latent_dim)

Source code in renalprog/modeling/predict.py
def interpolate_latent_linear(
    z_source: np.ndarray,
    z_target: np.ndarray,
    n_steps: int = 50
) -> np.ndarray:
    """
    Linear interpolation in latent space.

    Args:
        z_source: Source latent vector
        z_target: Target latent vector
        n_steps: Number of interpolation steps

    Returns:
        Array of interpolated latent vectors (n_steps x latent_dim)
    """
    alphas = np.linspace(0, 1, n_steps)
    interpolated = np.array([
        (1 - alpha) * z_source + alpha * z_target
        for alpha in alphas
    ])
    return interpolated

interpolate_latent_spherical

Spherical interpolation (SLERP) for normalized spaces.

interpolate_latent_spherical

interpolate_latent_spherical(
    z_source: ndarray, z_target: ndarray, n_steps: int = 50
) -> np.ndarray

Spherical (SLERP) interpolation in latent space.

Args:
  • z_source: Source latent vector
  • z_target: Target latent vector
  • n_steps: Number of interpolation steps

Returns: Array of interpolated latent vectors (n_steps x latent_dim)

Source code in renalprog/modeling/predict.py
def interpolate_latent_spherical(
    z_source: np.ndarray,
    z_target: np.ndarray,
    n_steps: int = 50
) -> np.ndarray:
    """
    Spherical (SLERP) interpolation in latent space.

    Args:
        z_source: Source latent vector
        z_target: Target latent vector
        n_steps: Number of interpolation steps

    Returns:
        Array of interpolated latent vectors (n_steps x latent_dim)
    """
    # Normalize vectors
    z_source_norm = z_source / np.linalg.norm(z_source)
    z_target_norm = z_target / np.linalg.norm(z_target)

    # Calculate angle between vectors
    omega = np.arccos(np.clip(np.dot(z_source_norm, z_target_norm), -1.0, 1.0))

    if omega < 1e-8:
        # Vectors are nearly identical, use linear interpolation
        return interpolate_latent_linear(z_source, z_target, n_steps)

    # SLERP formula
    alphas = np.linspace(0, 1, n_steps)
    interpolated = np.array([
        (np.sin((1 - alpha) * omega) / np.sin(omega)) * z_source +
        (np.sin(alpha * omega) / np.sin(omega)) * z_target
        for alpha in alphas
    ])

    return interpolated

Comparison:

from renalprog.modeling.predict import (
    interpolate_latent_linear,
    interpolate_latent_spherical
)
import numpy as np

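# Use 1-D latent vectors: the spherical variant calls np.dot on its inputs,
# which fails for two arrays of shape (1, 128)
z_start = np.random.randn(128)
z_end = np.random.randn(128)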

# Linear interpolation
traj_linear = interpolate_latent_linear(z_start, z_end, n_steps=50)

# Spherical interpolation (norm stays constant when both endpoints have equal norm)
traj_spherical = interpolate_latent_spherical(z_start, z_end, n_steps=50)

# Spherical is preferred for normalized latent spaces
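
The difference between the two methods is easiest to see on unit vectors. The sketch below re-implements both formulas from the source listings in a few lines (the helpers `lerp` and `slerp` are illustrative stand-ins, and `slerp` omits the small-angle fallback) and compares the norm of each interpolated point.

```python
import numpy as np


def lerp(z0: np.ndarray, z1: np.ndarray, n_steps: int) -> np.ndarray:
    """Straight-line interpolation between z0 and z1."""
    alphas = np.linspace(0, 1, n_steps)
    return np.array([(1 - a) * z0 + a * z1 for a in alphas])


def slerp(z0: np.ndarray, z1: np.ndarray, n_steps: int) -> np.ndarray:
    """Spherical interpolation; assumes z0 and z1 are not (anti)parallel."""
    cos_omega = np.dot(z0 / np.linalg.norm(z0), z1 / np.linalg.norm(z1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    alphas = np.linspace(0, 1, n_steps)
    return np.array([
        (np.sin((1 - a) * omega) * z0 + np.sin(a * omega) * z1) / np.sin(omega)
        for a in alphas
    ])


# Two orthogonal unit vectors as hypothetical latent codes.
z0 = np.array([1.0, 0.0])
z1 = np.array([0.0, 1.0])

lin_norms = np.linalg.norm(lerp(z0, z1, 5), axis=1)
sph_norms = np.linalg.norm(slerp(z0, z1, 5), axis=1)
print(lin_norms.round(3))  # dips to 0.707 at the midpoint
print(sph_norms.round(3))  # stays at 1.0 along the whole path
```

The linear path cuts through the interior of the sphere, so intermediate points have smaller norm than the endpoints; the spherical path stays on the arc between them.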

Visualization

plot_trajectory

Visualize individual trajectory.

plot_trajectory

plot_trajectory(
    trajectory: ndarray,
    gene_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
    title: str = "Gene Expression Trajectory",
    n_genes_to_show: int = 20,
) -> go.Figure

Plot gene expression changes along a trajectory.

Args:
  • trajectory: Array of shape (n_timepoints, n_genes)
  • gene_names: Optional list of gene names
  • save_path: Optional path to save figure
  • title: Plot title
  • n_genes_to_show: Number of top varying genes to display

Returns: Plotly Figure object

Source code in renalprog/plots.py
def plot_trajectory(
    trajectory: np.ndarray,
    gene_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
    title: str = "Gene Expression Trajectory",
    n_genes_to_show: int = 20,
) -> go.Figure:
    """
    Plot gene expression changes along a trajectory.

    Args:
        trajectory: Array of shape (n_timepoints, n_genes)
        gene_names: Optional list of gene names
        save_path: Optional path to save figure
        title: Plot title
        n_genes_to_show: Number of top varying genes to display

    Returns:
        Plotly Figure object
    """
    n_timepoints, n_genes = trajectory.shape

    # Calculate variance for each gene
    gene_variance = np.var(trajectory, axis=0)
    top_genes_idx = np.argsort(gene_variance)[-n_genes_to_show:]

    if gene_names is None:
        gene_names = [f'Gene_{i}' for i in range(n_genes)]

    fig = go.Figure()

    timepoints = list(range(n_timepoints))

    for idx in top_genes_idx:
        fig.add_trace(go.Scatter(
            x=timepoints,
            y=trajectory[:, idx],
            mode='lines',
            name=gene_names[idx],
            line=dict(width=1.5)
        ))

    fig.update_layout(
        title=title,
        xaxis_title='Timepoint',
        yaxis_title='Expression Level',
        template=DEFAULT_TEMPLATE,
        width=DEFAULT_WIDTH,
        height=DEFAULT_HEIGHT,
        hovermode='x unified'
    )

    if save_path:
        save_plot(fig, save_path)

    return fig

Example:

from renalprog.plots import plot_trajectory
from pathlib import Path

# Plot single trajectory
fig = plot_trajectory(
    trajectory=trajectory_data[0],  # Shape: (n_timepoints, n_genes)
    gene_names=selected_genes,      # one name per column of the trajectory
    save_path=Path("reports/figures/trajectory_example.png"),
    title="Disease Progression Trajectory"
)
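
How the plotted genes are chosen follows from the source listing: the variance of each gene across timepoints is computed, and the indices of the n_genes_to_show largest variances are kept. A minimal sketch on a hypothetical 4 x 3 trajectory:

```python
import numpy as np

# Hypothetical trajectory: 4 timepoints (rows) x 3 genes (columns).
trajectory = np.array([
    [0.0, 5.0, 1.0],
    [0.0, 5.0, 2.0],
    [0.0, 5.0, 3.0],
    [10.0, 5.0, 4.0],
])

# Variance of each gene along the trajectory.
gene_variance = np.var(trajectory, axis=0)

# Indices of the two most variable genes, in ascending order of variance.
top_idx = np.argsort(gene_variance)[-2:]
print(gene_variance.tolist())  # [18.75, 0.0, 1.25]
print(top_idx.tolist())        # [2, 0]
```

Because the selected indices refer to columns of the full trajectory, gene_names needs one entry per gene in the trajectory, not just the genes that end up being shown.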

Complete Workflow Example

import torch
import pandas as pd
import numpy as np
from pathlib import Path
from renalprog.modeling.train import VAE
from renalprog.modeling.predict import (
    apply_vae,
    create_patient_connections,
    generate_trajectories,
    build_trajectory_network,
    dynamic_enrichment_analysis
)
from renalprog.plots import plot_trajectory, plot_latent_space

# 1. Load model and data
model = VAE(input_dim=20000, mid_dim=1024, features=128)
model.load_state_dict(torch.load("models/my_vae/best_model.pt"))

expr = pd.read_csv("data/interim/split/test_expression.tsv", sep="\t", index_col=0)
clinical = pd.read_csv("data/interim/split/test_clinical.tsv", sep="\t", index_col=0)

# 2. Encode to latent space
results = apply_vae(model, expr.values, device='cuda')
latent = results['latent']

# 3. Split by stage
early_mask = clinical['stage'] == 'early'
late_mask = clinical['stage'] == 'late'

# 4. Create patient connections
connections = create_patient_connections(
    latent_early=latent[early_mask],
    latent_late=latent[late_mask],
    method='closest',
    output_path=Path("data/processed/connections.csv")
)

# 5. Generate trajectories (takes DataFrames, returns a dict of DataFrames)
trajectories = generate_trajectories(
    model=model,
    source_samples=expr.loc[early_mask.values],
    target_samples=expr.loc[late_mask.values],
    n_steps=50,
    method='spherical',
    output_dir=Path("data/processed/trajectories")
)

# 6. Build trajectory network
network = build_trajectory_network(
    connections=connections,
    output_path=Path("data/processed/network.graphml")
)

# 7. Dynamic enrichment analysis
# (not implemented yet: the call currently raises NotImplementedError)
enrichment = dynamic_enrichment_analysis(
    trajectory_dir=Path("data/processed/trajectories"),
    pathways_file=Path("data/external/ReactomePathways.gmt"),
    output_dir=Path("reports/enrichment")
)

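# 8. Visualize the first trajectory
# (assumes each trajectory DataFrame has timepoints as rows and genes as columns)
first_pair = next(iter(trajectories))
plot_trajectory(
    trajectory=trajectories[first_pair].to_numpy(),
    gene_names=expr.columns.tolist(),
    save_path=Path("reports/figures/trajectory_001.png")
)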

print(f"Generated {len(trajectories)} trajectories")
print(f"Network edges: {network.number_of_edges()}")
print(f"Enrichment result rows: {len(enrichment)}")

See Also