Enrichment Analysis

Pathway enrichment analysis using PyDESeq2 and GSEA.

Overview

The enrichment module provides tools for:

  1. Differential Expression Analysis: Using PyDESeq2 to identify significantly changed genes
  2. Gene Set Enrichment Analysis (GSEA): Pathway-level analysis using the GSEA CLI
  3. Pathway Visualization: Heatmap generation for pathway enrichment across trajectories

Main Classes

EnrichmentPipeline

EnrichmentPipeline(
    trajectory_dir: str,
    output_dir: str,
    cancer_type: str = "kirc",
    data_dir: Optional[str] = None,
    metadata_dir: Optional[str] = None,
    control_data_dir: Optional[str] = None,
    control_metadata_dir: Optional[str] = None,
    gsea_path: str = "./GSEA_4.3.2/gsea-cli.sh",
    pathways_file: str = "data/external/ReactomePathways.gmt",
    n_threads: int = 4,
)

Main pipeline for dynamic enrichment analysis using PyDESeq2 and GSEA.

This class orchestrates the complete enrichment analysis workflow:

  1. PyDESeq2 Differential Expression Analysis
     - Converts log2(RSEM+1) data back to integer RSEM counts
     - Runs PyDESeq2 for each trajectory timepoint vs controls
     - Generates statistically valid log2FoldChange values

  2. GSEA Pathway Enrichment
     - Creates ranked gene lists (.rnk files) from DESeq2 results
     - Executes GSEA in parallel with ReactomePathways gene sets
     - Collects pathway enrichment scores (NES, p-values, FDR)

  3. Result Processing and Visualization
     - Combines GSEA results across all trajectories and timepoints
     - Generates pathway enrichment heatmaps

IMPORTANT: This pipeline uses PyDESeq2 for proper differential expression. DO NOT bypass this with simple fold-change calculations.

Initialize enrichment pipeline.

Args:
    trajectory_dir: Directory containing trajectory CSV files
    output_dir: Output directory for enrichment results
    cancer_type: Cancer type ('kirc', 'lobular', 'ductal')
    data_dir: Path to preprocessed RNA-seq data
    metadata_dir: Path to clinical metadata
    control_data_dir: Path to control RNA-seq data
    control_metadata_dir: Path to control metadata
    gsea_path: Path to GSEA CLI tool
    pathways_file: Path to pathways GMT file
    n_threads: Number of parallel threads
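
The `pathways_file` argument points at a GMT file. As a rough illustration of that format (tab-separated: gene set name, description, then gene symbols), the sketch below parses a small made-up file. The helper name `parse_gmt` and the toy gene sets are assumptions for illustration; this is not the package's own `load_pathways_from_gmt`.

```python
from pathlib import Path
from typing import Dict, List
import tempfile


def parse_gmt(path: Path) -> Dict[str, List[str]]:
    """Parse a GMT file into {gene_set_name: [gene, ...]} (hypothetical helper)."""
    gene_sets: Dict[str, List[str]] = {}
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 3:
                continue  # skip blank or malformed lines
            name, _description, *genes = fields
            gene_sets[name] = [gene for gene in genes if gene]
    return gene_sets


# Tiny made-up file, not real Reactome content
with tempfile.TemporaryDirectory() as tmp:
    gmt_path = Path(tmp) / "toy.gmt"
    gmt_path.write_text(
        "Toy pathway A\tdesc-a\tGENE1\tGENE2\tGENE3\n"
        "Toy pathway B\tdesc-b\tGENE2\tGENE4\n",
        encoding="utf-8",
    )
    toy_sets = parse_gmt(gmt_path)

print(sorted(toy_sets))
```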

Source code in renalprog/enrichment.py
def __init__(
    self,
    trajectory_dir: str,
    output_dir: str,
    cancer_type: str = 'kirc',
    data_dir: Optional[str] = None,
    metadata_dir: Optional[str] = None,
    control_data_dir: Optional[str] = None,
    control_metadata_dir: Optional[str] = None,
    gsea_path: str = './GSEA_4.3.2/gsea-cli.sh',
    pathways_file: str = 'data/external/ReactomePathways.gmt',
    n_threads: int = 4
):
    """
    Initialize enrichment pipeline.

    Args:
        trajectory_dir: Directory containing trajectory CSV files
        output_dir: Output directory for enrichment results
        cancer_type: Cancer type ('kirc', 'lobular', 'ductal')
        data_dir: Path to preprocessed RNA-seq data
        metadata_dir: Path to clinical metadata
        control_data_dir: Path to control RNA-seq data
        control_metadata_dir: Path to control metadata
        gsea_path: Path to GSEA CLI tool
        pathways_file: Path to pathways GMT file
        n_threads: Number of parallel threads
    """
    self.trajectory_dir = Path(trajectory_dir)
    self.output_dir = Path(output_dir)
    self.cancer_type = cancer_type
    self.n_threads = n_threads
    self.gsea_path = Path(gsea_path)
    self.pathways_file = Path(pathways_file)

    # Set default data paths if not provided
    if data_dir is None:
        # Find latest preprocessed data
        data_dir = self._find_latest_preprocessed_data()
    if metadata_dir is None:
        metadata_dir = self._find_latest_preprocessed_metadata()
    # NOTE: the default control paths below point at the KIRC controls,
    # whatever value cancer_type has
    if control_data_dir is None:
        control_data_dir = PATHS['processed'] / 'controls' / 'KIRC' / 'rnaseq_control.csv'
    if control_metadata_dir is None:
        control_metadata_dir = PATHS['processed'] / 'controls' / 'KIRC' / 'clinical_control.csv'

    self.data_dir = Path(data_dir)
    self.metadata_dir = Path(metadata_dir)
    self.control_data_dir = Path(control_data_dir)
    self.control_metadata_dir = Path(control_metadata_dir)

    # Create output directories
    self.output_dir.mkdir(parents=True, exist_ok=True)
    self.deseq_dir = self.output_dir / 'deseq'
    self.gsea_dir = self.output_dir / 'gsea'
    self.deseq_dir.mkdir(exist_ok=True)
    self.gsea_dir.mkdir(exist_ok=True)

    logger.info(f"Initialized EnrichmentPipeline for {cancer_type}")
    logger.info(f"  Trajectory dir: {self.trajectory_dir}")
    logger.info(f"  Output dir: {self.output_dir}")
    logger.info(f"  Threads: {self.n_threads}")
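
The constructor above creates `output_dir` together with `deseq` and `gsea` subdirectories. A minimal stand-alone sketch of that layout, using a temporary directory and a hypothetical helper name (`make_output_layout`), could look like this:

```python
from pathlib import Path
import tempfile
from typing import Dict


def make_output_layout(output_dir: str) -> Dict[str, Path]:
    """Create output_dir with 'deseq' and 'gsea' subdirectories (illustrative only)."""
    root = Path(output_dir)
    root.mkdir(parents=True, exist_ok=True)  # parents=True allows nested output paths
    layout = {"root": root, "deseq": root / "deseq", "gsea": root / "gsea"}
    for name in ("deseq", "gsea"):
        layout[name].mkdir(exist_ok=True)  # safe to call again on reruns
    return layout


with tempfile.TemporaryDirectory() as tmp:
    layout = make_output_layout(str(Path(tmp) / "results" / "enrichment"))
    created = sorted(p.name for p in layout["root"].iterdir())

print(created)
```

Because both `mkdir` calls pass `exist_ok=True`, constructing the pipeline a second time against the same output directory does not fail.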

Functions

run

run(skip_deseq: bool = False, skip_gsea: bool = False, cleanup: bool = False)

Run the complete enrichment pipeline.

Args:
    skip_deseq: Skip DESeq processing (use if already completed)
    skip_gsea: Skip GSEA analysis (use if already completed)
    cleanup: Remove intermediate files after processing

Source code in renalprog/enrichment.py
def run(
    self,
    skip_deseq: bool = False,
    skip_gsea: bool = False,
    cleanup: bool = False
):
    """
    Run the complete enrichment pipeline.

    Args:
        skip_deseq: Skip DESeq processing (use if already completed)
        skip_gsea: Skip GSEA analysis (use if already completed)
        cleanup: Remove intermediate files after processing
    """
    logger.info("Starting enrichment analysis pipeline")

    # Step 1: Process trajectories for DESeq
    if not skip_deseq:
        logger.info("Step 1/4: Processing trajectories for DESeq...")
        self._run_deseq_processing()
    else:
        logger.info("Step 1/4: Skipping DESeq processing")

    # Step 2: Run GSEA in parallel
    if not skip_gsea:
        logger.info("Step 2/4: Running GSEA analysis...")
        self._run_gsea_parallel()
    else:
        logger.info("Step 2/4: Skipping GSEA analysis")

    # Step 3: Combine GSEA results
    logger.info("Step 3/4: Combining GSEA results...")
    final_df = self._combine_gsea_results()

    # Save final dataset
    output_path = self.output_dir / 'trajectory_enrichment.csv'
    final_df.to_csv(output_path, index=False)
    logger.info(f"Final enrichment dataset saved to: {output_path}")
    logger.info(f"Dataset shape: {final_df.shape}")

    # Step 4: Generate pathway enrichment heatmap
    logger.info("Step 4/4: Generating pathway enrichment heatmap...")
    try:
        heatmap_data, heatmap_fig = generate_pathway_heatmap(
            enrichment_df=final_df,
            output_dir=self.output_dir
        )
        logger.info(f"Heatmap generated with {heatmap_data.shape[0]} significant pathways")
    except Exception as e:
        logger.error(f"Failed to generate heatmap: {e}", exc_info=True)
        logger.warning("Continuing without heatmap...")

    # Cleanup if requested
    if cleanup:
        logger.info("Cleaning up intermediate files...")
        self._cleanup()

    return final_df

_run_deseq_processing

_run_deseq_processing()

Process all trajectory files for DESeq analysis.

Source code in renalprog/enrichment.py
def _run_deseq_processing(self):
    """Process all trajectory files for DESeq analysis."""
    # Find all trajectory CSV files
    trajectory_files = list(self.trajectory_dir.glob('*.csv'))

    if not trajectory_files:
        raise ValueError(f"No trajectory files found in {self.trajectory_dir}")

    logger.info(f"Found {len(trajectory_files)} trajectory files")

    # Load data once
    rnaseq_data, clinical_data, control_data, control_metadata, gene_list = self._load_data()

    # Process files in parallel
    with ProcessPoolExecutor(max_workers=self.n_threads) as executor:
        futures = []
        for traj_file in trajectory_files:
            future = executor.submit(
                process_trajectory_file,
                traj_file=traj_file,
                rnaseq_data=rnaseq_data,
                clinical_data=clinical_data,
                control_data=control_data,
                control_metadata=control_metadata,
                gene_list=gene_list,
                output_dir=self.deseq_dir,
                cancer_type=self.cancer_type,
                gsea_path=self.gsea_path,
                pathways_file=self.pathways_file
            )
            futures.append(future)

        # Track progress
        for future in tqdm(as_completed(futures), total=len(futures), desc="DESeq processing"):
            try:
                future.result()
            except Exception as e:
                logger.error(f"Error processing file: {e}")

    logger.info("DESeq processing complete")

    # Validate that files were created
    rnk_files = list(self.deseq_dir.rglob('*.rnk'))
    cmd_files = list(self.deseq_dir.glob('*.cmd'))

    logger.info(f"Created {len(rnk_files)} .rnk files")
    logger.info(f"Created {len(cmd_files)} .cmd files")

    if len(rnk_files) == 0:
        logger.error("No .rnk files were created during DESeq processing")
        raise ValueError("DESeq processing failed to create rank files")

    if len(cmd_files) == 0:
        logger.error("No .cmd files were created during DESeq processing")
        raise ValueError("DESeq processing failed to create command files")
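
In the listing above, a failed worker is logged without saying which trajectory file caused it. One possible variation, sketched here under assumptions, keeps a future-to-input mapping so each failure can be attributed. The names `run_all` and `toy_worker` are hypothetical, and the sketch uses `ThreadPoolExecutor` so it runs anywhere; the same submit/`as_completed` pattern applies to `ProcessPoolExecutor`.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, Iterable, List, Tuple


def run_all(
    items: Iterable[str],
    worker: Callable[[str], str],
    max_workers: int = 4,
) -> Tuple[Dict[str, str], List[Tuple[str, str]]]:
    """Run worker on every item; return (results, failures) keyed by item."""
    results: Dict[str, str] = {}
    failures: List[Tuple[str, str]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember which input each future belongs to
        future_to_item = {executor.submit(worker, item): item for item in items}
        for future in as_completed(future_to_item):
            item = future_to_item[future]
            try:
                results[item] = future.result()
            except Exception as exc:  # record the failure and keep going
                failures.append((item, str(exc)))
    return results, failures


def toy_worker(name: str) -> str:
    """Stand-in for per-file processing; fails for names starting with 'bad'."""
    if name.startswith("bad"):
        raise ValueError(f"cannot process {name}")
    return name.upper()


results, failures = run_all(["a.csv", "bad.csv", "c.csv"], toy_worker)
print(results, failures)
```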

_combine_gsea_results

_combine_gsea_results() -> pd.DataFrame

Combine all GSEA results into a single dataset.

Source code in renalprog/enrichment.py
def _combine_gsea_results(self) -> pd.DataFrame:
    """Combine all GSEA results into a single dataset."""
    # Load pathways
    pathways = load_pathways_from_gmt(self.pathways_file)

    # Diagnostic: show directory structure
    logger.info(f"Scanning GSEA results in: {self.deseq_dir}")
    logger.info("Directory structure:")
    for item in self.deseq_dir.iterdir():
        if item.is_dir():
            logger.info(f"  {item.name}/")
            for subitem in item.iterdir():
                if subitem.is_dir():
                    gsea_dirs = list(subitem.glob('gsea_tp*'))
                    logger.info(f"    {subitem.name}/ ({len(gsea_dirs)} gsea_tp* dirs)")
                    # Show deeper structure for debugging
                    for gsea_dir in gsea_dirs[:2]:  # Show first 2
                        subdirs = [d.name for d in gsea_dir.iterdir() if d.is_dir()]
                        files = [f.name for f in gsea_dir.iterdir() if f.is_file()]
                        logger.debug(f"      {gsea_dir.name}/ subdirs={subdirs[:3]}, files={files[:3]}")
        else:
            logger.info(f"  {item.name}")

    # Find all trajectory directories (transition/patient structure)
    # First, look for transition directories
    transition_dirs = [d for d in self.deseq_dir.iterdir() if d.is_dir() and not d.name.startswith('.')]

    logger.info(f"Found {len(transition_dirs)} transition directories")

    all_results = []
    failed_trajectories = []
    skipped_trajectories = []

    for transition_dir in transition_dirs:
        # Find patient directories within each transition
        patient_dirs = [d for d in transition_dir.iterdir() if d.is_dir() and not d.name.startswith('.')]
        logger.info(f"  Found {len(patient_dirs)} patient directories in {transition_dir.name}")

        for patient_dir in tqdm(patient_dirs, desc=f"Processing {transition_dir.name}"):
            try:
                result = process_trajectory_results(patient_dir, pathways)
                if result is not None:
                    all_results.append(result)
                else:
                    skipped_trajectories.append(str(patient_dir))
                    logger.warning(f"No results for {patient_dir}")
            except Exception as e:
                failed_trajectories.append((str(patient_dir), str(e)))
                logger.error(f"Error processing {patient_dir}: {e}", exc_info=True)

    # Report statistics
    logger.info(f"Successfully processed: {len(all_results)} trajectories")
    logger.info(f"Skipped (no GSEA dirs): {len(skipped_trajectories)} trajectories")
    logger.info(f"Failed (exceptions): {len(failed_trajectories)} trajectories")

    if failed_trajectories:
        logger.warning("Failed trajectories:")
        for path, error in failed_trajectories[:10]:  # Show first 10
            logger.warning(f"  {path}: {error}")

    if not all_results:
        logger.error(f"No GSEA results found in {self.deseq_dir}")
        logger.error("Directory structure should be: deseq_dir/transition/patient/gsea_tp*/")
        logger.error(f"Transition dirs found: {len(transition_dirs)}")
        logger.error(f"Total patient dirs scanned: {len(skipped_trajectories) + len(failed_trajectories)}")

        # Provide more specific error message
        if failed_trajectories:
            logger.error(f"All trajectories failed with errors. First error: {failed_trajectories[0][1]}")
        elif skipped_trajectories:
            logger.error("All trajectories were skipped (no gsea_tp* directories found)")

        raise ValueError("No GSEA results found to combine")

    # Combine all results
    final_df = pd.concat(all_results, axis=0, ignore_index=True)

    return final_df
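
The method above expects results under `deseq_dir/transition/patient/gsea_tp*/`. The following sketch builds a toy tree of that shape and lists the `gsea_tp*` directories per patient, which can help when diagnosing a "No GSEA results found" error. The helper name `find_gsea_dirs` and the directory names are made up for illustration.

```python
from pathlib import Path
import tempfile
from typing import Dict, List


def visible_dirs(parent: Path) -> List[Path]:
    """Return non-hidden subdirectories of parent, sorted by name."""
    return sorted(p for p in parent.iterdir() if p.is_dir() and not p.name.startswith("."))


def find_gsea_dirs(deseq_dir: Path) -> Dict[str, List[str]]:
    """Map 'transition/patient' to the names of its gsea_tp* directories."""
    found: Dict[str, List[str]] = {}
    for transition in visible_dirs(deseq_dir):
        for patient in visible_dirs(transition):
            tp_dirs = sorted(d.name for d in patient.glob("gsea_tp*") if d.is_dir())
            found[f"{transition.name}/{patient.name}"] = tp_dirs
    return found


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "early_to_late" / "patient_01" / "gsea_tp0").mkdir(parents=True)
    (root / "early_to_late" / "patient_01" / "gsea_tp1").mkdir()
    (root / "early_to_late" / "patient_02").mkdir()  # patient without GSEA output
    (root / ".cache").mkdir()  # hidden directories are ignored
    gsea_layout = find_gsea_dirs(root)

print(gsea_layout)
```

A patient that maps to an empty list corresponds to the "skipped" case in the log output of the method above.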

Functions

Differential Expression

run_deseq2_analysis

run_deseq2_analysis(
    sample_data: DataFrame,
    control_data: DataFrame,
    control_metadata: DataFrame,
    gene_list: ndarray,
    sample_name: str,
    stage_transition: str,
) -> pd.Series

Perform DESeq2 differential expression analysis between sample and controls.

This function properly:

  1. Converts log2(RSEM+1) data back to RSEM integer counts
  2. Runs PyDESeq2 analysis to get statistically valid log2FoldChange values
  3. Returns ranked gene list for GSEA

Args:
    sample_data: Sample expression data (genes x 1) in log2(RSEM+1) format
    control_data: Control expression data (genes x samples) in log2(RSEM+1) format
    control_metadata: Control clinical metadata with stage information
    gene_list: List of genes
    sample_name: Name/ID of the sample
    stage_transition: Stage transition label (e.g., 'early_to_late', 'I_to_II')

Returns:
    Series of log2FoldChange values sorted for GSEA input
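
The first step, reversing the log2(RSEM+1) transform, can be sketched in isolation. The helper name `log2_rsem_to_counts` and the input values are illustrative; the arithmetic (invert, clip at zero, round to integers) follows the description above.

```python
import numpy as np


def log2_rsem_to_counts(log_values: np.ndarray) -> np.ndarray:
    """Invert y = log2(x + 1), clip negatives to 0, and round to integer counts."""
    counts = np.power(2.0, log_values) - 1.0
    counts = np.clip(counts, 0, None)  # slightly negative inputs would give counts below 0
    return np.round(counts).astype(int)


# The last value mimics a slightly negative log value, e.g. from interpolation
example = np.array([0.0, 1.0, 3.0, 10.0, -0.2])
counts = log2_rsem_to_counts(example)
print(counts.tolist())  # [0, 1, 7, 1023, 0]
```

Rounding means the conversion is not exactly reversible: distinct log values that map to the same integer become indistinguishable afterwards.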

Source code in renalprog/enrichment.py
def run_deseq2_analysis(
    sample_data: pd.DataFrame,
    control_data: pd.DataFrame,
    control_metadata: pd.DataFrame,
    gene_list: np.ndarray,
    sample_name: str,
    stage_transition: str
) -> pd.Series:
    """
    Perform DESeq2 differential expression analysis between sample and controls.

    This function properly:
    1. Converts log2(RSEM+1) data back to RSEM integer counts
    2. Runs PyDESeq2 analysis to get statistically valid log2FoldChange values
    3. Returns ranked gene list for GSEA

    Args:
        sample_data: Sample expression data (genes x 1) in log2(RSEM+1) format
        control_data: Control expression data (genes x samples) in log2(RSEM+1) format
        control_metadata: Control clinical metadata with stage information
        gene_list: List of genes
        sample_name: Name/ID of the sample
        stage_transition: Stage transition label (e.g., 'early_to_late', 'I_to_II')

    Returns:
        Series of log2FoldChange values sorted for GSEA input
    """
    # Debug logging
    logger.debug(f"DESeq2 analysis for {sample_name}")
    logger.debug(f"  Sample data shape: {sample_data.shape}")
    logger.debug(f"  Control data shape: {control_data.shape}")
    logger.debug(f"  Gene list length: {len(gene_list)}")

    # 1. Convert from log2(RSEM+1) back to RSEM counts
    # The preprocessed data is in log2(RSEM+1) format, so we reverse this transformation
    # and round to integers as required by DESeq2
    sample_counts = 2 ** sample_data.values.flatten() - 1
    control_counts = 2 ** control_data.values - 1

    # Clip negative values to 0 (can occur due to numerical precision or interpolation)
    sample_counts = np.clip(sample_counts, 0, None)
    control_counts = np.clip(control_counts, 0, None)

    # Round to integers as required by DESeq2
    sample_counts = np.round(sample_counts).astype(int)
    control_counts = np.round(control_counts).astype(int)

    # Ensure control_counts is 2D (genes x samples)
    if control_counts.ndim == 1:
        control_counts = control_counts[:, np.newaxis]

    # Validate counts
    if np.any(sample_counts < 0):
        logger.error(f"Sample has negative counts after conversion: min={sample_counts.min()}")
        raise ValueError("Sample counts contain negative values after conversion")
    if np.any(control_counts < 0):
        logger.error(f"Controls have negative counts after conversion: min={control_counts.min()}")
        raise ValueError("Control counts contain negative values after conversion")

    logger.debug(f"  Sample counts: min={sample_counts.min()}, max={sample_counts.max()}, mean={sample_counts.mean():.1f}")
    logger.debug(f"  Control counts: min={control_counts.min()}, max={control_counts.max()}, mean={control_counts.mean():.1f}")

    # Ensure sample_counts has the same number of genes as control_counts
    if len(sample_counts) != control_counts.shape[0]:
        raise ValueError(
            f"Sample has {len(sample_counts)} genes but controls have {control_counts.shape[0]} genes. "
            f"Sample shape: {sample_data.shape}, Control shape: {control_data.shape}"
        )

    # 2. Combine sample with controls
    # Create count matrix: genes (rows) x samples (columns)
    # sample_counts is 1D (n_genes,), control_counts is 2D (n_genes, n_controls)
    # We need to reshape sample_counts to (n_genes, 1) to stack horizontally
    counts_matrix = np.column_stack([sample_counts.reshape(-1, 1), control_counts])

    counts_df = pd.DataFrame(
        counts_matrix,
        index=gene_list,
        columns=[sample_name] + list(control_data.columns)
    )

    # 3. Create metadata DataFrame with condition labels
    metadata_df = pd.DataFrame({
        'condition': [stage_transition] + list(control_metadata['ajcc_pathologic_tumor_stage'].values)
    }, index=counts_df.columns)

    # 4. Run PyDESeq2 analysis
    # Initialize DESeqDataSet
    dds = DeseqDataSet(
        counts=counts_df.T,  # DESeq2 expects samples as rows, genes as columns
        metadata=metadata_df,
        design_factors='condition',
        refit_cooks=True,
        quiet=True  # Suppress verbose output in parallel processing
    )

    # Fit dispersions and log-fold changes
    dds.deseq2()

    # Statistical analysis
    stat_res = DeseqStats(
        dds,
        alpha=0.05,
        cooks_filter=True,
        independent_filter=True,
        quiet=True
    )
    stat_res.summary()

    # 5. Extract log2FoldChange values
    results_df = stat_res.results_df

    # Ensure we have log2FoldChange column
    if 'log2FoldChange' not in results_df.columns:
        raise ValueError(f"DESeq2 results missing 'log2FoldChange' column. Available columns: {results_df.columns.tolist()}")

    # Extract log2FoldChange (used as the GSEA ranking metric)
    log2fc = results_df['log2FoldChange'].copy()

    # Replace NaN values with 0 (genes with no change or insufficient data)
    log2fc = log2fc.fillna(0)

    # Sort by signed value, descending (most up-regulated first); GSEA expects a ranked list
    log2fc = log2fc.sort_values(ascending=False)

    logger.debug(f"  DESeq2 complete: {len(log2fc)} genes, range [{log2fc.min():.2f}, {log2fc.max():.2f}]")

    return log2fc
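
The returned Series is later turned into a `.rnk` file, although the listing above does not show how. As a hedged sketch, assuming the usual two-column, tab-separated layout (gene identifier, ranking score) without a header row, the conversion might look like this. The helper name `to_rnk_text`, the gene names, and the number formatting are assumptions.

```python
import pandas as pd


def to_rnk_text(log2fc: pd.Series) -> str:
    """Render a log2FoldChange Series as tab-separated 'gene<TAB>score' lines."""
    ranked = log2fc.fillna(0).sort_values(ascending=False)  # signed, descending
    return "".join(f"{gene}\t{float(score):.6g}\n" for gene, score in ranked.items())


toy_lfc = pd.Series({"GENE_A": -1.5, "GENE_B": 2.0, "GENE_C": float("nan")})
rnk_text = to_rnk_text(toy_lfc)
print(rnk_text, end="")
```

Note that filling NaN with 0 places genes without an estimate in the middle of the ranking, between up- and down-regulated genes.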

GSEA Analysis

run_gsea_command

run_gsea_command(cmd: str) -> bool

Run a single GSEA command.

Args:
    cmd: GSEA command string

Returns:
    True if successful, False otherwise

Source code in renalprog/enrichment.py
def run_gsea_command(cmd: str) -> bool:
    """
    Run a single GSEA command.

    Args:
        cmd: GSEA command string

    Returns:
        True if successful, False otherwise
    """
    try:
        result = subprocess.run(
            cmd,
            shell=True,
            capture_output=True,
            text=True,
            timeout=600  # 10 minute timeout per command
        )

        if result.returncode != 0:
            logger.error(f"GSEA command failed with return code {result.returncode}")
            logger.error(f"Command: {cmd[:200]}...")  # Truncate long commands
            logger.error(f"STDERR: {result.stderr[:500]}")  # First 500 chars
            if result.stdout:
                logger.error(f"STDOUT: {result.stdout[:500]}")
            return False

        # Check for Java errors in stdout (GSEA sometimes returns 0 even on error)
        if result.stdout and ("error" in result.stdout.lower() or "exception" in result.stdout.lower()):
            logger.warning("GSEA command may have failed (found error in output)")
            logger.warning(f"Command: {cmd[:200]}...")
            logger.warning(f"Output: {result.stdout[:500]}")  # First 500 chars
            return False

        # Check for GSEA-specific success indicators
        if result.stdout and "Enrichment score" in result.stdout:
            logger.debug("GSEA command completed successfully")
            return True
        elif result.stdout:
            # Log a sample of output for debugging
            logger.debug(f"GSEA output sample: {result.stdout[:200]}")

        return True

    except subprocess.TimeoutExpired:
        logger.error("GSEA command timed out after 600s")
        logger.error(f"Command: {cmd[:200]}...")
        return False
    except Exception as e:
        logger.error(f"Error running GSEA command: {e}")
        logger.error(f"Command: {cmd[:200]}...")
        return False
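
The timeout and return-code handling can be tried without GSEA installed. The sketch below is a simplified variant, not the function above: it takes an argument list instead of a shell string (so `shell=True` is not needed) and uses the Python interpreter as a stand-in command. The name `run_command` is hypothetical.

```python
import subprocess
import sys
from typing import Sequence


def run_command(args: Sequence[str], timeout_s: float = 600.0) -> bool:
    """Run a command; return True only if it exits with code 0 within the timeout."""
    try:
        result = subprocess.run(
            list(args),
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return False  # the child process is killed when the timeout expires
    except OSError:
        return False  # e.g. the executable does not exist
    return result.returncode == 0


ok = run_command([sys.executable, "-c", "print('done')"])
failed = run_command([sys.executable, "-c", "import sys; sys.exit(3)"])
timed_out = run_command([sys.executable, "-c", "import time; time.sleep(5)"], timeout_s=0.5)
print(ok, failed, timed_out)
```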

Visualization

generate_pathway_heatmap

generate_pathway_heatmap(
    enrichment_df: DataFrame,
    output_dir: str,
    fdr_threshold: float = 0.05,
    colorbar: bool = True,
    legend: bool = False,
    yticks_fontsize: int = 12,
    show: bool = False,
) -> Tuple[pd.DataFrame, Dict[str, matplotlib.figure.Figure]]

Generate multiple pathway enrichment heatmaps from GSEA results.

This function creates several heatmaps showing the sum of NES (Normalized Enrichment Score) across all trajectories for each pathway at each timepoint:

  1. Top 50 most changing pathways (first vs last timepoint)
  2. Top 50 most upregulated pathways (average NES > 0)
  3. Top 50 most downregulated pathways (average NES < 0)
  4. Selected pathways (high-level Reactome + literature pathways)

The heatmaps have:

  - Rows: Pathway names
  - Columns: Timepoints (pseudo-time from early to late)
  - Values: Sum of NES across all trajectories at each timepoint

Args:
    enrichment_df: DataFrame with columns [Patient, Idx, Transition, NAME, ES, NES, FDR q-val]
    output_dir: Output directory for heatmap files
    fdr_threshold: FDR q-value threshold for significance (default: 0.05)
    colorbar: Whether to show colorbar (default: True)
    legend: Whether to show legend (default: False)
    yticks_fontsize: Font size for y-axis tick labels (default: 12)
    show: Whether to display the plot (default: False)

Returns:
    Tuple of (heatmap_data, figures_dict):
        - heatmap_data: DataFrame with summed NES values (pathways × timepoints)
        - figures_dict: Dictionary mapping figure names to Matplotlib Figure objects

Example:
    >>> enrichment_df = pd.read_csv('trajectory_enrichment.csv')
    >>> heatmap_data, figs = generate_pathway_heatmap(
    ...     enrichment_df=enrichment_df,
    ...     output_dir='results/',
    ...     fdr_threshold=0.05
    ... )
    >>> print(f"Generated {len(figs)} heatmaps")
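
The "sum of NES" matrix described above can be reproduced on a toy table to see what the heatmap values mean. The sketch follows the documented steps (filter on FDR, sum NES per timepoint and pathway, pivot to pathways × timepoints); the helper name `summed_nes_matrix` and all data values are made up.

```python
import pandas as pd

# Made-up enrichment rows using the documented column names
toy_enrichment = pd.DataFrame({
    "Patient": ["p1", "p2", "p1", "p2", "p1"],
    "Idx": [0, 0, 1, 1, 1],
    "Transition": ["early_to_late"] * 5,
    "NAME": ["APOPTOSIS", "APOPTOSIS", "APOPTOSIS", "APOPTOSIS", "GLYCOLYSIS"],
    "NES": [1.5, 2.0, -1.0, 1.8, 2.2],
    "FDR q-val": [0.01, 0.03, 0.04, 0.20, 0.001],
})


def summed_nes_matrix(df: pd.DataFrame, fdr_threshold: float = 0.05) -> pd.DataFrame:
    """Sum NES over trajectories per (timepoint, pathway), keeping only FDR < threshold."""
    significant = df[df["FDR q-val"] < fdr_threshold]
    summed = significant.groupby(["Idx", "NAME"])["NES"].sum().reset_index()
    # Pathways become rows, timepoints become columns; absent combinations become 0
    return summed.pivot(index="NAME", columns="Idx", values="NES").fillna(0)


matrix = summed_nes_matrix(toy_enrichment)
print(matrix)
```

In this toy case the p2 row at timepoint 1 is dropped (FDR 0.20), so APOPTOSIS at timepoint 1 is -1.0 rather than 0.8. Because missing combinations are filled with 0, a cell of 0 can mean either "no significant enrichment" or "enrichments that cancel out".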

Source code in renalprog/enrichment.py
def generate_pathway_heatmap(
    enrichment_df: pd.DataFrame,
    output_dir: str,
    fdr_threshold: float = 0.05,
    colorbar: bool = True,
    legend: bool = False,
    yticks_fontsize: int = 12,
    show: bool = False
) -> Tuple[pd.DataFrame, Dict[str, 'matplotlib.figure.Figure']]:
    """
    Generate multiple pathway enrichment heatmaps from GSEA results.

    This function creates several heatmaps showing the sum of NES (Normalized Enrichment Score)
    across all trajectories for each pathway at each timepoint:

    1. Top 50 most changing pathways (first vs last timepoint)
    2. Top 50 most upregulated pathways (average NES > 0)
    3. Top 50 most downregulated pathways (average NES < 0)
    4. Selected pathways (high-level Reactome + literature pathways)

    The heatmaps have:
    - Rows: Pathway names
    - Columns: Timepoints (pseudo-time from early to late)
    - Values: Sum of NES across all trajectories at each timepoint

    Args:
        enrichment_df: DataFrame with columns [Patient, Idx, Transition, NAME, ES, NES, FDR q-val]
        output_dir: Output directory for heatmap files
        fdr_threshold: FDR q-value threshold for significance (default: 0.05)
        colorbar: Whether to show colorbar (default: True)
        legend: Whether to show legend (default: False)
        yticks_fontsize: Font size for y-axis tick labels (default: 12)
        show: Whether to display the plot (default: False)

    Returns:
        Tuple of (heatmap_data, figures_dict):
            - heatmap_data: DataFrame with summed NES values (pathways × timepoints)
            - figures_dict: Dictionary mapping figure names to Matplotlib Figure objects

    Example:
        >>> enrichment_df = pd.read_csv('trajectory_enrichment.csv')
        >>> heatmap_data, figs = generate_pathway_heatmap(
        ...     enrichment_df=enrichment_df,
        ...     output_dir='results/',
        ...     fdr_threshold=0.05
        ... )
        >>> print(f"Generated {len(figs)} heatmaps")
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import matplotlib.colors as mcolors

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Generating pathway enrichment heatmaps (FDR < {fdr_threshold})...")

    # Define pathway lists
    highest_pathways = [
        "Autophagy",
        "Cell Cycle",
        "Cell-Cell communication",
        "Cellular responses to stimuli",
        "Chromatin organization",
        "Circadian Clock",
        "DNA Repair",
        "DNA Replication",
        "Developmental Biology",
        "Digestion and absorption",
        "Disease",
        "Drug ADME",
        "Extracellular matrix organization",
        "Gene expression (Transcription)",
        "Hemostasis",
        "Immune System",
        "Metabolism",
        "Metabolism of RNA",
        "Metabolism of proteins",
        "Muscle contraction",
        "Neuronal System",
        "Organelle biogenesis and maintenance",
        "Programmed Cell Death",
        "Protein localization",
        "Reproduction",
        "Sensory Perception",
        "Signal Transduction",
        "Transport of small molecules",
        "Vesicle-mediated transport",
    ]

    pathways_literature = [
        # VHL/HIF pathway
        "CELLULAR RESPONSE TO HYPOXIA",
        "OXYGEN-DEPENDENT PROLINE HYDROXYLATION OF HYPOXIA-INDUCIBLE FACTOR ALPHA",
        "REGULATION OF GENE EXPRESSION BY HYPOXIA-INDUCIBLE FACTOR",
        # PI3K/AKT/MTOR Pathway
        "PI3K/AKT ACTIVATION",
        "PI3K/AKT SIGNALING IN CANCER",
        "MTOR SIGNALLING",
        # Warburg effect
        "TP53 REGULATES METABOLIC GENES",
        "GLYCOLYSIS",
        "GLUCOSE METABOLISM",
        # TCA/Krebs cycle
        "CITRIC ACID CYCLE (TCA CYCLE)",
        "THE CITRIC ACID (TCA) CYCLE AND RESPIRATORY ELECTRON TRANSPORT",
        # Pentose phosphate pathway
        "NFE2L2 REGULATES PENTOSE PHOSPHATE PATHWAY GENES",
        "PENTOSE PHOSPHATE PATHWAY",
        "PENTOSE PHOSPHATE PATHWAY DISEASE",
        # Fatty Acid Metabolism
        "FATTY ACID METABOLISM",
        # Glutamine metabolism
        "GLUTAMATE AND GLUTAMINE METABOLISM",
        # EGFR
        "SIGNALING BY EGFR",
        "SIGNALING BY EGFR IN CANCER",
        "EGFR DOWNREGULATION",
        # TGF-β signaling
        "SIGNALING BY TGF-BETA RECEPTOR COMPLEX",
        "TGF-BETA RECEPTOR SIGNALING IN EMT (EPITHELIAL TO MESENCHYMAL TRANSITION)",
        "SIGNALING BY TGF-BETA RECEPTOR COMPLEX IN CANCER",
        "SIGNALING BY TGFB FAMILY MEMBERS",
        "TGF-BETA RECEPTOR SIGNALING ACTIVATES SMADS",
        # Wnt/β-catenin pathway
        "BETA-CATENIN INDEPENDENT WNT SIGNALING",
        "SIGNALING BY WNT",
        # SLIT-2-ROBO1 pathways
        "REGULATION OF EXPRESSION OF SLITS AND ROBOS",
        # DNA repair
        "DNA REPAIR",
        # Energy homeostasis
        "ION HOMEOSTASIS",
        # Apoptosis
        "APOPTOSIS",
        # Angiogenesis
        "SIGNALING BY VEGF"
    ]

    # Step 1: Ensure numeric types for NES and FDR q-val
    enrichment_df = enrichment_df.copy()
    enrichment_df['NES'] = pd.to_numeric(enrichment_df['NES'], errors='coerce')
    enrichment_df['FDR q-val'] = pd.to_numeric(enrichment_df['FDR q-val'], errors='coerce')
    enrichment_df['Idx'] = pd.to_numeric(enrichment_df['Idx'], errors='coerce')

    # Log any rows that had non-numeric values
    invalid_nes = enrichment_df['NES'].isna().sum()
    invalid_fdr = enrichment_df['FDR q-val'].isna().sum()
    if invalid_nes > 0:
        logger.warning(f"Found {invalid_nes} rows with non-numeric NES values (converted to NaN)")
    if invalid_fdr > 0:
        logger.warning(f"Found {invalid_fdr} rows with non-numeric FDR q-val values (converted to NaN)")

    # Step 2: Filter by FDR threshold
    significant = enrichment_df[enrichment_df['FDR q-val'] < fdr_threshold].copy()

    logger.info(f"Found {significant.shape[0]} significant pathway enrichments (FDR < {fdr_threshold})")

    if significant.empty:
        logger.warning("No significant pathways found. Cannot generate heatmap.")
        # Return empty results
        empty_df = pd.DataFrame()
        empty_dict = {}
        return empty_df, empty_dict

    # Step 3: Group by Timepoint (Idx) and Pathway (NAME), sum NES across all trajectories
    pathway_summary = significant.groupby(['Idx', 'NAME'])['NES'].sum().reset_index()

    logger.info(f"Aggregated results for {pathway_summary['Idx'].nunique()} timepoints "
                f"and {pathway_summary['NAME'].nunique()} pathways")

    # Step 4: Pivot to create matrix (pathways × timepoints)
    heatmap_data = pathway_summary.pivot(
        index='NAME',
        columns='Idx',
        values='NES'
    ).fillna(0)  # Fill missing with 0

    logger.info(f"Full heatmap dimensions: {heatmap_data.shape[0]} pathways × {heatmap_data.shape[1]} timepoints")

    # Save full summary data
    summary_file = output_dir / 'pathway_nes_summary.csv'
    heatmap_data.to_csv(summary_file)
    logger.info(f"Saved pathway NES summary to: {summary_file}")

    # Dictionary to store all figures
    figures = {}

    # Helper function to create and save a heatmap
    def plot_heatmap_regulation(df_plot, unique_pathways, cmap_here='viridis',
                               save_name=None, colorbar_title='Sum of NES'):
        """Plot heatmap following paper_figures.ipynb style"""
        # Generate a range of locations for the ticks
        tick_locations = range(len(unique_pathways))
        z_min, z_max = df_plot.min().min(), df_plot.max().max()

        # Make the range symmetric around 0
        if z_min < 0 and z_max > 0:
            abs_max = max(abs(z_min), abs(z_max))
            z_min, z_max = -abs_max, abs_max
            norm = mcolors.TwoSlopeNorm(vmin=z_min, vcenter=0, vmax=z_max)
        else:
            # If all values are positive or all negative, use regular normalization
            norm = mcolors.Normalize(vmin=z_min, vmax=z_max)
            logger.warning(f"Cannot center colormap at zero: range [{z_min:.3f}, {z_max:.3f}] does not cross zero")

        fig, ax = plt.subplots(figsize=(30, 10))

        # Make heatmap
        cax = ax.imshow(df_plot.values, cmap=cmap_here, norm=norm, aspect='auto')

        # Set the y-ticks
        plt.yticks(tick_locations, unique_pathways, fontsize=yticks_fontsize)

        # Set the x-ticks at specific positions
        num_timepoints = df_plot.shape[1]
        ax.set_xticks([0, num_timepoints - 1])
        ax.set_xticklabels(['early', 'late'], fontsize=yticks_fontsize*1.33, rotation=45)
        ax.set_xlabel('Pseudo-Time', fontsize=yticks_fontsize*1.33)

        # Get x and y axis range
        ymin, ymax = ax.get_ylim()
        xmin, xmax = ax.get_xlim()

        # Custom x and y ticks
        x_custom = np.arange(xmin, xmax, step=1)
        y_custom = np.arange(ymax, ymin, step=1)

        # set minor ticks at custom locations:
        ax.set_xticks(x_custom, minor=True)
        ax.set_yticks(y_custom, minor=True)

        # Draw grid lines on minor ticks only (cell borders); keep the major grid off
        plt.grid(False, which='major')
        plt.grid(True, which='minor', color='black', linestyle='-', linewidth=1)

        # Remove ticks
        ax.tick_params(axis='both', which='minor', length=0)

        # Add colorbar
        if colorbar:
            cbar = plt.colorbar(cax, shrink=0.7)
            cbar.set_label(colorbar_title, rotation=270, labelpad=20, fontsize=16)

        # Add legend
        if legend:
            colors = np.append(plt.get_cmap(cmap_here)([0, 0.5, 1]), np.array([[1, 1, 1, 1]]), axis=0)
            labels = ['Downregulated', 'No change', 'Upregulated', 'No data']
            patches = [
                mpatches.Patch(facecolor=colors[i], label=labels[i], edgecolor='black')
                for i in range(len(labels))
            ]
            ax.legend(handles=patches, bbox_to_anchor=(1.05, 1),
                      loc=2, borderaxespad=0., title='Regulation',
                      fontsize=yticks_fontsize*2, title_fontsize=24)

        # Save figures
        if save_name:
            plt.savefig(output_dir / f'{save_name}.pdf', bbox_inches='tight')
            plt.savefig(output_dir / f'{save_name}.png', bbox_inches='tight', dpi=600)
            plt.savefig(output_dir / f'{save_name}.svg', bbox_inches='tight',
                       format='svg', transparent=True)
            logger.info(f"Saved heatmap to: {output_dir / save_name}.{{pdf,png,svg}}")

        if show:
            plt.show()
        else:
            plt.close()

        return fig

    # 1. Top 50 most changing pathways (first vs last timepoint)
    logger.info("Creating heatmap 1/5: Top 50 most changing pathways...")
    first_col = heatmap_data.columns[0]
    last_col = heatmap_data.columns[-1]
    change = (heatmap_data[last_col] - heatmap_data[first_col]).abs()
    top_changing = change.nlargest(50).index.tolist()

    df_top_changing = heatmap_data.loc[top_changing]
    fig1 = plot_heatmap_regulation(
        df_top_changing,
        top_changing,
        cmap_here='RdBu_r',
        save_name='top50_most_changing_pathways',
        colorbar_title='Sum of NES'
    )
    figures['top50_changing'] = fig1

    # 2. Top 50 most upregulated pathways (average NES > 0)
    logger.info("Creating heatmap 2/5: Top 50 most upregulated pathways...")
    avg_nes = heatmap_data.mean(axis=1)
    upregulated = avg_nes[avg_nes > 0].nlargest(50).index.tolist()

    df_upregulated = heatmap_data.loc[upregulated]
    fig2 = plot_heatmap_regulation(
        df_upregulated,
        upregulated,
        cmap_here='YlGn',
        save_name='top50_most_upregulated_pathways',
        colorbar_title='Sum of NES'
    )
    figures['top50_upregulated'] = fig2

    # 3. Top 50 most downregulated pathways (average NES < 0)
    logger.info("Creating heatmap 3/5: Top 50 most downregulated pathways...")
    downregulated = avg_nes[avg_nes < 0].nsmallest(50).index.tolist()

    df_downregulated = heatmap_data.loc[downregulated]
    fig3 = plot_heatmap_regulation(
        df_downregulated,
        downregulated,
        cmap_here='YlOrBr',
        save_name='top50_most_downregulated_pathways',
        colorbar_title='Sum of NES'
    )
    figures['top50_downregulated'] = fig3

    # 4. High-level pathways (29 pathways from Reactome highest level)
    logger.info("Creating heatmap 4/5: High-level pathways...")
    available_highest = [p for p in highest_pathways if p in heatmap_data.index]

    if available_highest:
        df_highest = heatmap_data.loc[available_highest]
        fig4 = plot_heatmap_regulation(
            df_highest,
            available_highest,
            cmap_here='RdBu_r',
            save_name='selected_pathways_highest_level',
            colorbar_title='Sum of NES'
        )
        figures['selected_highest_level'] = fig4
        logger.info(f"Found {len(available_highest)}/{len(highest_pathways)} high-level pathways in data")
    else:
        logger.warning("No high-level pathways found in the data")

    # 5. Literature pathways (33 pathways from literature review)
    logger.info("Creating heatmap 5/5: Literature pathways...")
    available_literature = [p for p in pathways_literature if p in heatmap_data.index]

    if available_literature:
        df_literature = heatmap_data.loc[available_literature]
        fig5 = plot_heatmap_regulation(
            df_literature,
            available_literature,
            cmap_here='RdBu_r',
            save_name='selected_pathways_literature',
            colorbar_title='Sum of NES'
        )
        figures['selected_literature'] = fig5
        logger.info(f"Found {len(available_literature)}/{len(pathways_literature)} literature pathways in data")
    else:
        logger.warning("No literature pathways found in the data")

    logger.info(f"Pathway heatmap generation complete. Created {len(figures)} heatmaps.")

    return heatmap_data, figures

Usage Examples

Running Complete Enrichment Pipeline

from renalprog.enrichment import EnrichmentPipeline
from pathlib import Path

# Initialize pipeline
pipeline = EnrichmentPipeline(
    cancer_type="kirc",
    trajectory_dir=Path("data/processed/trajectories"),
    output_dir=Path("data/processed/enrichment"),
    gsea_path=Path("./GSEA_4.3.2/gsea-cli.sh"),
    pathways_file=Path("data/external/ReactomePathways.gmt"),
    n_threads=8,
    n_threads_per_deseq=8
)

# Run full pipeline
results = pipeline.run()

Running DESeq2 Analysis

from renalprog.enrichment import run_deseq2_analysis
import pandas as pd

# Load trajectory data
trajectory_data = pd.read_csv("trajectory_001.csv", index_col=0)
control_data = pd.read_csv("control.csv", index_col=0)

# Run DESeq2
results_df = run_deseq2_analysis(
    trajectory_samples=trajectory_data,
    control_samples=control_data,
    n_threads=8
)

# Results contain: log2FoldChange, pvalue, padj, etc.
print(results_df.head())

Creating RNK Files for GSEA

from renalprog.enrichment import create_rnk_file

# Create ranked gene list from DESeq2 results
rnk_file = create_rnk_file(
    deseq_results=results_df,
    output_path="analysis/genes.rnk"
)
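For orientation, an .rnk file is a two-column, tab-separated list of gene identifiers and ranking scores. The sketch below is not the implementation of create_rnk_file; it assumes the ranking metric is log2FoldChange, and the helper name write_rnk_sketch and the gene labels are hypothetical.

```python
import tempfile
from pathlib import Path

import pandas as pd


def write_rnk_sketch(deseq_results, output_path, score_column="log2FoldChange"):
    """Write a ranked gene list (gene <TAB> score, no header) and return the scores."""
    # Assumption: genes are ranked by log2FoldChange; rows without an estimate are dropped.
    scores = deseq_results[score_column].dropna().sort_values(ascending=False)
    scores.to_csv(output_path, sep="\t", header=False)
    return scores


# Toy DESeq2-like table with made-up values, indexed by gene label.
toy_results = pd.DataFrame(
    {"log2FoldChange": [0.5, -1.2, float("nan"), 2.0]},
    index=["GENE_A", "GENE_B", "GENE_C", "GENE_D"],
)

with tempfile.TemporaryDirectory() as tmp:
    rnk_path = Path(tmp) / "toy_genes.rnk"
    ranked = write_rnk_sketch(toy_results, rnk_path)
    rnk_text = rnk_path.read_text()

print(rnk_text)
```

Other ranking metrics (for example a signed test statistic) are common choices too; check which column the pipeline actually uses before comparing results.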

Running GSEA

from renalprog.enrichment import run_gsea_command

# Run GSEA on ranked gene list
gsea_output = run_gsea_command(
    rnk_file="analysis/genes.rnk",
    gmt_file="data/external/ReactomePathways.gmt",
    output_dir="analysis/gsea_results",
    label="trajectory_001",
    gsea_path="./GSEA_4.3.2/gsea-cli.sh"
)

Generating Pathway Heatmaps

from renalprog.enrichment import generate_pathway_heatmap

# Generate heatmaps from enrichment results
heatmap_data, figures = generate_pathway_heatmap(
    enrichment_file="data/processed/enrichment/trajectory_enrichment.csv",
    output_dir="data/processed/enrichment",
    fdr_threshold=0.05,
    n_timepoints=50
)

# figures is a dict keyed by heatmap name:
# - "top50_changing": Largest absolute NES change between first and last timepoint
# - "top50_upregulated": Highest positive mean NES
# - "top50_downregulated": Lowest negative mean NES
# - "selected_highest_level": Reactome high-level pathways (only if any are present)
# - "selected_literature": Literature-curated pathways (only if any are present)
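The matrix returned as heatmap_data comes from three steps in the source listing: an FDR filter, a per-timepoint sum of NES, and a pivot. A minimal sketch of those steps on made-up values (the toy numbers are illustrative only):

```python
import pandas as pd

# Toy GSEA results: several trajectories, two timepoints, two pathways (made-up values).
enrichment_df = pd.DataFrame({
    "Idx": [0, 0, 0, 1, 1, 1],
    "NAME": ["APOPTOSIS", "APOPTOSIS", "SIGNALING BY VEGF",
             "APOPTOSIS", "SIGNALING BY VEGF", "SIGNALING BY VEGF"],
    "NES": [1.5, 1.0, -2.0, 2.0, 1.2, 0.8],
    "FDR q-val": [0.01, 0.03, 0.20, 0.04, 0.01, 0.02],
})
fdr_threshold = 0.05

# Keep only significant rows, then sum NES per (timepoint, pathway) across trajectories.
significant = enrichment_df[enrichment_df["FDR q-val"] < fdr_threshold]
pathway_summary = significant.groupby(["Idx", "NAME"])["NES"].sum().reset_index()

# Pathways become rows and timepoints become columns; absent combinations become 0.
heatmap_data = pathway_summary.pivot(index="NAME", columns="Idx", values="NES").fillna(0)
print(heatmap_data)
```

Because missing combinations are filled with 0, a zero cell can mean "no significant enrichment at this timepoint" rather than a summed NES of exactly zero.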

Configuration

EnrichmentPipeline Parameters

  • cancer_type: Cancer type identifier ('kirc', 'lobular', or 'ductal'; default: "kirc")
  • trajectory_dir: Directory containing trajectory CSV files
  • output_dir: Directory for output files
  • gsea_path: Path to GSEA CLI executable
  • pathways_file: Path to GMT file with pathway definitions
  • n_threads: Number of parallel threads for processing
  • n_threads_per_deseq: Number of threads per DESeq2 job
  • memory_per_job_gb: Memory limit per DESeq2 job (default: 12 GB)
  • total_memory_gb: Total available memory (default: 224 GB)

GSEA Parameters

  • nperm: Number of permutations (default: 1000)
  • set_min: Minimum gene set size (default: 15)
  • set_max: Maximum gene set size (default: 500)
  • scoring_scheme: GSEA scoring method (default: "weighted")
  • norm: Normalization method (default: "meandiv")
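As a rough illustration of how these parameters map onto a GSEA command line, the sketch below assembles an argument list for the GSEAPreranked tool. The flags are standard GSEA CLI options; whether run_gsea_command passes exactly this set is an assumption, and build_gsea_command_sketch is a hypothetical helper.

```python
def build_gsea_command_sketch(rnk_file, gmt_file, output_dir, label,
                              gsea_path="./GSEA_4.3.2/gsea-cli.sh",
                              nperm=1000, set_min=15, set_max=500,
                              scoring_scheme="weighted", norm="meandiv"):
    """Return the argument list for a GSEAPreranked run (nothing is executed here)."""
    return [
        str(gsea_path), "GSEAPreranked",
        "-rnk", str(rnk_file),          # ranked gene list
        "-gmx", str(gmt_file),          # gene set definitions
        "-out", str(output_dir),        # where GSEA writes its report folder
        "-rpt_label", str(label),       # label used in the report folder name
        "-nperm", str(nperm),
        "-set_min", str(set_min),
        "-set_max", str(set_max),
        "-scoring_scheme", scoring_scheme,
        "-norm", norm,
        "-collapse", "No_Collapse",     # assumption: gene IDs already match the GMT file
    ]


cmd = build_gsea_command_sketch(
    rnk_file="analysis/genes.rnk",
    gmt_file="data/external/ReactomePathways.gmt",
    output_dir="analysis/gsea_results",
    label="trajectory_001",
)
print(" ".join(cmd))
```

Passing the list to subprocess.run(cmd, check=True) would launch GSEA; building the list separately makes it easy to log the exact command when a run fails.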

Pathway Collections

High-Level Reactome Pathways

29 top-level biological processes:

  • Autophagy
  • Cell Cycle
  • DNA Repair
  • Immune System
  • Metabolism
  • Signal Transduction
  • And more...

Literature-Curated Pathways

33 pathways from literature review:

  • VHL/HIF pathway
  • PI3K/AKT/MTOR pathway
  • Warburg effect
  • TCA cycle
  • And more...

Output Files

DESeq2 Results

  • {trajectory_id}_deseq_results.csv: Complete DESeq2 output
  • {trajectory_id}.rnk: Ranked gene list for GSEA

GSEA Results

  • {trajectory_id}/gsea_report_for_na_pos_{timestamp}.tsv: Positive enrichment
  • {trajectory_id}/gsea_report_for_na_neg_{timestamp}.tsv: Negative enrichment
  • {trajectory_id}/ranked_gene_list_{timestamp}.tsv: Ranked genes with scores
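The positive and negative report tables can be gathered with a glob over the file pattern above. This is a minimal sketch, not the pipeline's own collection step; collect_gsea_reports_sketch is a hypothetical helper, and the timestamps and values in the toy files are made up.

```python
import csv
import tempfile
from pathlib import Path


def collect_gsea_reports_sketch(trajectory_dir):
    """Gather rows from the positive and negative GSEA report tables of one run."""
    rows = []
    for report in sorted(Path(trajectory_dir).glob("gsea_report_for_na_*.tsv")):
        direction = "pos" if "_pos_" in report.name else "neg"
        with report.open(newline="") as handle:
            for row in csv.DictReader(handle, delimiter="\t"):
                # Values stay as strings here; numeric coercion can happen downstream.
                rows.append({"NAME": row["NAME"], "NES": row["NES"],
                             "FDR q-val": row["FDR q-val"], "direction": direction})
    return rows


# Toy run directory with one pathway per report.
with tempfile.TemporaryDirectory() as tmp:
    header = "NAME\tNES\tFDR q-val\n"
    Path(tmp, "gsea_report_for_na_pos_1700000000000.tsv").write_text(
        header + "APOPTOSIS\t1.8\t0.01\n")
    Path(tmp, "gsea_report_for_na_neg_1700000000000.tsv").write_text(
        header + "SIGNALING BY VEGF\t-1.6\t0.03\n")
    combined = collect_gsea_reports_sketch(tmp)

print(combined)
```

Keeping the values as strings mirrors the heatmap code, which converts NES and FDR q-val with pd.to_numeric(..., errors='coerce') and logs any entries that fail to parse.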

Combined Results

  • trajectory_enrichment.csv: All GSEA results combined
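  • pathway_nes_summary.csv: Summed NES matrix (pathways × timepoints) behind the heatmaps
  • top50_most_changing_pathways.{pdf,png,svg}, top50_most_upregulated_pathways.{pdf,png,svg}, top50_most_downregulated_pathways.{pdf,png,svg}: Top-50 heatmaps
  • selected_pathways_highest_level.{pdf,png,svg}, selected_pathways_literature.{pdf,png,svg}: Curated-set heatmaps (written only when matching pathways are present)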

Performance Considerations

Memory Management

DESeq2 analysis is memory-intensive. The pipeline automatically:

  • Limits concurrent jobs based on available memory
  • Allocates memory per job (default: 12 GB)
  • Monitors memory usage

For large datasets:

pipeline = EnrichmentPipeline(
    ...,
    n_threads=4,  # Reduce parallelism
    n_threads_per_deseq=4,  # Reduce threads per job
    memory_per_job_gb=16  # Increase memory per job
)
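One plausible way to reason about these settings is that the number of jobs running at once is bounded by both n_threads and the memory budget. The pipeline's actual scheduling rule is not shown here, so the helper below is a hypothetical sketch under that assumption.

```python
def max_concurrent_jobs_sketch(n_threads, memory_per_job_gb=12, total_memory_gb=224):
    """Estimate how many DESeq2 jobs fit at once under a simple memory budget."""
    # Assumption: each job reserves memory_per_job_gb and jobs never share memory.
    memory_bound = total_memory_gb // memory_per_job_gb
    # Never report fewer than one job, even if a single job exceeds the budget.
    return max(1, min(n_threads, memory_bound))


print(max_concurrent_jobs_sketch(n_threads=8))                       # defaults: 12 GB per job, 224 GB total
print(max_concurrent_jobs_sketch(n_threads=8, memory_per_job_gb=16,
                                 total_memory_gb=64))                # smaller machine
```

Under this model, raising memory_per_job_gb on a small machine lowers the effective parallelism even when n_threads stays the same.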

CPU Utilization

  • DESeq2 jobs run in parallel (up to n_threads)
  • Each job uses n_threads_per_deseq threads
  • Total CPU usage ≈ n_threads × n_threads_per_deseq

Recommended settings:

  • Small dataset (<100 trajectories): n_threads=8, n_threads_per_deseq=8
  • Large dataset (>100 trajectories): n_threads=4, n_threads_per_deseq=12
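Before launching, it can help to compare the requested thread count with the cores actually available. The check below is a hypothetical helper, not part of the module; it only applies the n_threads × n_threads_per_deseq estimate given above.

```python
import os


def thread_budget_sketch(n_threads, n_threads_per_deseq, available_cpus=None):
    """Return (requested_threads, fits) for a given parallelism setting."""
    if available_cpus is None:
        # os.cpu_count() can return None on some platforms; fall back to 1.
        available_cpus = os.cpu_count() or 1
    requested = n_threads * n_threads_per_deseq
    return requested, requested <= available_cpus


requested, fits = thread_budget_sketch(n_threads=8, n_threads_per_deseq=8)
print(f"Requested {requested} threads; fits on this machine: {fits}")
```

Oversubscribing cores usually still runs, just more slowly, so this is a sanity check rather than a hard limit.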

Troubleshooting

Out of Memory Errors

# Reduce parallel jobs
pipeline = EnrichmentPipeline(..., n_threads=4)

# Increase memory per job
pipeline = EnrichmentPipeline(..., memory_per_job_gb=20)

GSEA Not Found

Ensure GSEA is installed and that gsea_path points to the CLI script:

# Test GSEA
./GSEA_4.3.2/gsea-cli.sh --help
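The same check can be done from Python before constructing the pipeline. This is a small sketch with a hypothetical helper name; it only inspects the file system and does not invoke GSEA.

```python
import os
from pathlib import Path


def gsea_cli_problem_sketch(gsea_path):
    """Return a short description of what is wrong with gsea_path, or None if it looks usable."""
    path = Path(gsea_path)
    if not path.is_file():
        return f"not found: {path}"
    if not os.access(path, os.X_OK):
        return f"not executable: {path} (try chmod +x)"
    return None


problem = gsea_cli_problem_sketch("./GSEA_4.3.2/gsea-cli.sh")
print(problem or "GSEA CLI looks usable")
```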

No Results Generated

Check logs for errors:

import logging
logging.basicConfig(level=logging.DEBUG)

See Also