Skip to content

Dataset API

The dataset module provides functions for downloading, loading, and preparing RNA-seq data for analysis.

Overview

This module handles:

  • Downloading TCGA Pan-Cancer Atlas data
  • Processing and filtering by cancer type
  • Creating train/test splits
  • Loading preprocessed data for modeling

Core Functions

download_data

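Download the TCGA Pan-Cancer Atlas datasets (gene expression, phenotype, and survival tables) from which the KIRC cohort is later extracted by process_downloaded_data.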

download_data

download_data(
    destination: Path = "data/raw", remove_gz: bool = True, timeout: int = 300
) -> Tuple[Path, Path, Path]

Download the KIRC datasets to the specified destination.

Args: destination: Path to directory where dataset should be saved remove_gz: Whether to remove .gz files after extraction timeout: Request timeout in seconds

Source code in renalprog/dataset.py
def download_data(
        destination: Path = "data/raw",
        remove_gz: bool = True,
        timeout: int = 300) -> Tuple[Path,Path,Path]:
    """
    Download the TCGA Pan-Cancer Atlas datasets to the specified destination.

    Three files are fetched from the Xena hub: the batch-corrected RNA-seq
    expression matrix, the phenotype table, and the survival table. Files whose
    name ends in ``.gz`` are decompressed next to the download; the survival
    table is served uncompressed and is stored as-is.

    Args:
        destination: Path to directory where dataset should be saved
            (created if it does not exist)
        remove_gz: Whether to remove .gz files after extraction
        timeout: Request timeout in seconds

    Returns:
        Tuple ``(rnaseq_path, clinical_path, phenotype_path)`` pointing at the
        extracted files inside ``destination``.

    Raises:
        requests.RequestException: If a download fails (logged, then re-raised)
        IOError: If writing or extracting a file fails (logged, then re-raised)
    """
    # Local import so this block does not depend on the module's import list
    import shutil

    # Ensure save directory exists
    destination = Path(destination)
    destination.mkdir(parents=True, exist_ok=True)

    datasets = [
        ("https://tcga-pancan-atlas-hub.s3.us-east-1.amazonaws.com/download/"
         "EB%2B%2BAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena.gz",
         "Gene expression RNAseq - Batch effects normalized mRNA data"),
        ("https://tcga-pancan-atlas-hub.s3.us-east-1.amazonaws.com/download/"
         "TCGA_phenotype_denseDataOnlyDownload.tsv.gz",
         "TCGA phenotype data"),
        ("https://tcga-pancan-atlas-hub.s3.us-east-1.amazonaws.com/download/"
         "Survival_SupplementalTable_S1_20171025_xena_sp",
         "Clinical survival data")
    ]

    for url, description in datasets:
        try:
            # The local name is the last URL segment verbatim, so the
            # percent-encoding ("%2B%2B") is kept in the file name on disk
            filename = url.split('/')[-1]
            file_path = destination / filename
            is_gzipped = filename.endswith('.gz')

            logger.info(f"Downloading dataset: {description}...")
            # Context manager releases the streamed connection even when the
            # transfer is interrupted part-way through
            with requests.get(url, timeout=timeout, stream=True) as response:
                response.raise_for_status()

                # 0 means the server sent no Content-Length; progress is then skipped
                total_size = int(response.headers.get('content-length', 0))
                with open(file_path, 'wb') as file:
                    downloaded = 0
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            file.write(chunk)
                            downloaded += len(chunk)
                            if total_size > 0:
                                percent = (downloaded / total_size) * 100
                                print(f"\rDownloading {filename}: {percent:.1f}%", end='', flush=True)

            print(f"\nFile downloaded successfully: {file_path}")

            if is_gzipped:
                # Strip only the trailing ".gz" (str.replace would hit every occurrence)
                extracted_path = destination / filename[:-len('.gz')]
                logger.info("Extracting compressed file...")
                with gzip.open(file_path, 'rb') as f_in, open(extracted_path, 'wb') as f_out:
                    # Stream in chunks instead of holding the whole
                    # decompressed file in memory at once
                    shutil.copyfileobj(f_in, f_out)
                logger.info(f"Successfully extracted file to: {extracted_path}")

                if remove_gz:
                    file_path.unlink()
                    logger.info("Removed compressed .gz file")

        except requests.RequestException as e:
            logger.error(f"Failed to download {description}: {e}")
            raise
        except IOError as e:
            logger.error(f"File I/O error for {description}: {e}")
            raise

    # Return paths to downloaded files (note: order differs from download order)
    rnaseq_path = destination / "EB%2B%2BAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena"
    clinical_path = destination / "Survival_SupplementalTable_S1_20171025_xena_sp"
    phenotype_path = destination / "TCGA_phenotype_denseDataOnlyDownload.tsv"

    return rnaseq_path, clinical_path, phenotype_path

Example Usage:

from renalprog.dataset import download_data
from pathlib import Path

# Fetch all three tables into the default raw-data folder
raw_dir = Path("data/raw")
rnaseq_path, clinical_path, phenotype_path = download_data(
    destination=raw_dir,
    remove_gz=True,
    timeout=300,
)

# Show where each extracted file ended up
print(f"RNA-seq data: {rnaseq_path}")
print(f"Clinical data: {clinical_path}")
print(f"Phenotype data: {phenotype_path}")

process_downloaded_data

Process downloaded TCGA data for a specific cancer type.

process_downloaded_data

process_downloaded_data(
    rnaseq_path: Path = "data/raw/EB%2B%2BAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena",
    clinical_path: Path = "data/raw/Survival_SupplementalTable_S1_20171025_xena_sp",
    phenotype_path: Path = "data/raw/TCGA_phenotype_denseDataOnlyDownload.tsv",
    cancer_type: str = "KIRC",
    output_dir: Path = "data/raw",
    early_late: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]

Process TCGA Pan-Cancer Atlas data for a specific cancer type.

This function performs comprehensive data processing and quality control on TCGA datasets, including cancer type filtering, sample type selection, stage harmonization, and optional binary classification mapping. The processing pipeline follows TCGA best practices for multi-omics data integration.

Processing Steps: 1. Load raw RNA-seq, clinical, and phenotype data from TCGA Xena Hub 2. Filter samples by cancer type (KIRC or BRCA) 3. Select primary tumor samples only (exclude metastatic/recurrent) 4. Remove ambiguous stage annotations (Stage X, discrepancies, missing) 5. Harmonize substages (e.g., Stage IA/IB/IC → Stage I) 6. Optionally map stages to binary early/late classification 7. Ensure consistency across all three datasets 8. Save processed data in CSV format

Args: rnaseq_path: Path to RNA-seq expression matrix (genes × samples). Expected format: tab-delimited file with gene IDs as rows and sample IDs (TCGA barcodes) as columns. clinical_path: Path to clinical survival data. Expected format: tab-delimited file with survival information, stage annotations, and patient metadata. phenotype_path: Path to phenotype annotations. Expected format: tab-delimited file with sample type information (Primary Tumor, Metastatic, etc.). cancer_type: Cancer type abbreviation. Supported values: - "KIRC": Kidney Renal Clear Cell Carcinoma - "BRCA": Breast Invasive Carcinoma (filters to female patients only) output_dir: Directory where processed CSV files will be saved. early_late: If True, map AJCC stages to binary classification: - "early": Stage I and Stage II - "late": Stage III and Stage IV If False, retain original stage granularity (I, II, III, IV).

Returns: Tuple of three DataFrames (the same tables are also written to output_dir as CSV files): - rnaseq: Processed RNA-seq expression matrix - clinical: Processed clinical annotations - phenotype: Processed phenotype data

Raises: FileNotFoundError: If input files do not exist KeyError: If required columns are missing from input data. Note: an unrecognised cancer_type does not raise an error; a warning is logged and the data are filtered by that abbreviation.

Notes: - For BRCA, only female patients with ductal or lobular carcinomas are retained - AJCC pathologic tumor stage is used as the primary staging system - Substages (A, B, C) are collapsed to main stages for statistical power - All three output datasets maintain consistent sample identifiers

Examples: >>> # Process KIRC data with 4-stage classification >>> paths = process_downloaded_data( ... rnaseq_path="data/raw/expression.xena", ... clinical_path="data/raw/survival.tsv", ... phenotype_path="data/raw/phenotype.tsv", ... cancer_type="KIRC", ... output_dir="data/processed", ... early_late=False ... )

>>> # Process KIRC data with binary early/late classification
>>> paths = process_downloaded_data(
...     cancer_type="KIRC",
...     output_dir="data/processed",
...     early_late=True
... )

References: - TCGA Research Network: https://www.cancer.gov/tcga - UCSC Xena Browser: https://xenabrowser.net/ - AJCC Cancer Staging Manual, 8th Edition

Source code in renalprog/dataset.py
def process_downloaded_data(
        rnaseq_path: Path = "data/raw/EB%2B%2BAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena",
        clinical_path: Path = "data/raw/Survival_SupplementalTable_S1_20171025_xena_sp",
        phenotype_path: Path = "data/raw/TCGA_phenotype_denseDataOnlyDownload.tsv",
        cancer_type: str = "KIRC",
        output_dir: Path = "data/raw",
        early_late: bool = False) -> Tuple[pd.DataFrame,pd.DataFrame,pd.DataFrame]:
    """
    Process TCGA Pan-Cancer Atlas data for a specific cancer type.

    This function filters the pan-cancer tables down to one cancer type, keeps
    primary tumor samples, removes ambiguous stage annotations, harmonizes
    substages, optionally maps stages to a binary early/late label, and writes
    the three synchronized tables to CSV.

    Processing Steps:
        1. Load raw RNA-seq, clinical, and phenotype data
        2. Filter samples by cancer type (KIRC or BRCA)
        3. Select primary tumor samples only
        4. Remove ambiguous stage annotations (Stage X, discrepancies, missing)
        5. Harmonize substages (e.g., Stage IA/IB/IC → Stage I)
        6. BRCA only: keep infiltrating ductal and lobular carcinomas
        7. Optionally map stages to binary early/late classification
        8. Log summary statistics and save processed data in CSV format

    Args:
        rnaseq_path: Path to RNA-seq expression matrix (genes × samples).
            Expected format: tab-delimited file with gene IDs as rows and
            sample IDs (TCGA barcodes) as columns.
        clinical_path: Path to clinical survival data.
            Expected format: tab-delimited file with survival information,
            stage annotations, and patient metadata.
        phenotype_path: Path to phenotype annotations.
            Expected format: tab-delimited file with sample type information
            (Primary Tumor, Metastatic, etc.).
        cancer_type: Cancer type abbreviation. Supported values:
            - "KIRC": Kidney Renal Clear Cell Carcinoma
            - "BRCA": Breast Invasive Carcinoma (filters to female patients only)
            Any other value is filtered on as-is after logging a warning.
        output_dir: Directory where processed CSV files will be saved.
        early_late: If True, map AJCC stages to a binary classification via
            map_stages_to_early_late and drop samples whose stage cannot be
            mapped. If False, retain the harmonized stage labels.

    Returns:
        Tuple of processed DataFrames, restricted to the same samples:
            - rnaseq: Processed RNA-seq expression matrix (genes × samples)
            - clinical: Processed clinical annotations
            - phenotype: Processed phenotype data
        The same tables are written to ``output_dir`` as
        ``{cancer_type}_rnaseq.csv``, ``{cancer_type}_clinical.csv`` and
        ``{cancer_type}_phenotype.csv``.

    Raises:
        FileNotFoundError: If input files do not exist
        KeyError: If required columns are missing from input data

    Notes:
        - For BRCA, only female patients with ductal or lobular carcinomas are retained
        - AJCC pathologic tumor stage is used as the primary staging system
        - Substages (A, B, C) are collapsed to main stages for statistical power
        - All three output datasets maintain consistent sample identifiers

    Examples:
        >>> # Process KIRC data with 4-stage classification
        >>> rnaseq, clinical, pheno = process_downloaded_data(
        ...     rnaseq_path="data/raw/expression.xena",
        ...     clinical_path="data/raw/survival.tsv",
        ...     phenotype_path="data/raw/phenotype.tsv",
        ...     cancer_type="KIRC",
        ...     output_dir="data/processed",
        ...     early_late=False
        ... )

        >>> # Process KIRC data with binary early/late classification
        >>> rnaseq, clinical, pheno = process_downloaded_data(
        ...     cancer_type="KIRC",
        ...     output_dir="data/processed",
        ...     early_late=True
        ... )

    References:
        - TCGA Research Network: https://www.cancer.gov/tcga
        - UCSC Xena Browser: https://xenabrowser.net/
        - AJCC Cancer Staging Manual, 8th Edition
    """
    # =========================================================================
    # STEP 1: Load raw data from TCGA sources
    # =========================================================================
    logger.info("="*80)
    logger.info(f"Starting data processing pipeline for {cancer_type}")
    logger.info("="*80)

    rnaseq = pd.read_table(rnaseq_path, index_col=0)
    clinical = pd.read_csv(clinical_path, sep="\t", index_col=0)
    pheno = pd.read_table(phenotype_path, sep="\t", index_col=0)

    # Remove redundant _PATIENT column from clinical data
    clinical.drop("_PATIENT", axis=1, inplace=True, errors='ignore')

    logger.info("\n[INITIAL DATA SHAPES]")
    logger.info(f"  RNA-seq:   {rnaseq.shape[0]:>6,} genes × {rnaseq.shape[1]:>5,} samples")
    logger.info(f"  Clinical:  {clinical.shape[0]:>6,} patients × {clinical.shape[1]:>3,} features")
    logger.info(f"  Phenotype: {pheno.shape[0]:>6,} samples × {pheno.shape[1]:>3,} features")

    # =========================================================================
    # STEP 2: Filter by cancer type
    # =========================================================================
    logger.info(f"\n[FILTERING BY CANCER TYPE: {cancer_type}]")

    if cancer_type == 'BRCA':
        # For breast cancer, only include female patients
        clinical = clinical[
            (clinical["gender"] == "FEMALE") &
            (clinical["cancer type abbreviation"] == cancer_type)
        ]
        logger.info(f"  Filter: Female patients with {cancer_type}")
    elif cancer_type == 'KIRC':
        clinical = clinical[clinical["cancer type abbreviation"] == cancer_type]
        logger.info(f"  Filter: Patients with {cancer_type}")
    else:
        # Best effort: unknown abbreviations are still filtered on, not rejected
        logger.warning(f"  Cancer type '{cancer_type}' may not be fully supported")
        clinical = clinical[clinical["cancer type abbreviation"] == cancer_type]

    # Synchronize datasets to common samples
    pheno = pheno[pheno.index.isin(clinical.index)]
    rnaseq = rnaseq.loc[:, rnaseq.columns.isin(clinical.index)]
    pheno = pheno[pheno.index.isin(rnaseq.columns)]
    clinical = clinical[clinical.index.isin(rnaseq.columns)]

    logger.info("\n  Post-filter shapes:")
    logger.info(f"    RNA-seq:   {rnaseq.shape[0]:>6,} genes × {rnaseq.shape[1]:>5,} samples")
    logger.info(f"    Clinical:  {clinical.shape[0]:>6,} patients")
    logger.info(f"    Phenotype: {pheno.shape[0]:>6,} samples")

    # =========================================================================
    # STEP 3: Select primary tumor samples only
    # =========================================================================
    logger.info("\n[FILTERING BY SAMPLE TYPE: Primary Tumor]")

    # Log sample type distribution before filtering
    sample_type_counts = pheno["sample_type"].value_counts()
    logger.info("  Sample type distribution:")
    for sample_type, count in sample_type_counts.items():
        logger.info(f"    {sample_type:<30} {count:>5,} samples")

    pheno = pheno[pheno["sample_type"] == "Primary Tumor"]
    clinical = clinical[clinical.index.isin(pheno.index)]
    rnaseq = rnaseq.loc[:, rnaseq.columns.isin(pheno.index)]

    logger.info(f"\n  Retained {len(pheno):,} primary tumor samples")
    logger.info(f"    RNA-seq:   {rnaseq.shape[0]:>6,} genes × {rnaseq.shape[1]:>5,} samples")
    logger.info(f"    Clinical:  {clinical.shape[0]:>6,} patients")
    logger.info(f"    Phenotype: {pheno.shape[0]:>6,} samples")

    # =========================================================================
    # STEP 4: Remove ambiguous stage annotations
    # =========================================================================
    logger.info("\n[FILTERING BY STAGE QUALITY]")

    stages_remove = ['Stage X', '[Discrepancy]', np.nan]
    initial_stage_counts = clinical["ajcc_pathologic_tumor_stage"].value_counts(dropna=False)

    logger.info("  Initial stage distribution:")
    _log_stage_table(initial_stage_counts)

    # Filter out problematic stage annotations.
    # .copy() makes clinical_redux an independent frame so the column
    # assignments below do not trigger SettingWithCopyWarning.
    clinical_redux = clinical[~clinical["ajcc_pathologic_tumor_stage"].isin(stages_remove)].copy()
    removed_count = len(clinical) - len(clinical_redux)
    logger.info(f"\n  Removed {removed_count} samples with ambiguous staging")

    # =========================================================================
    # STEP 5: Harmonize substages (collapse A/B/C to main stage)
    # =========================================================================
    logger.info("\n[HARMONIZING SUBSTAGES]")

    stages_available = clinical_redux["ajcc_pathologic_tumor_stage"].unique()
    has_substages = any(
        str(stage).endswith(("A", "B", "C"))
        for stage in stages_available
        if pd.notna(stage)
    )

    if has_substages:
        logger.info("  Substages detected (e.g., Stage IA, Stage IIB)")
        logger.info("  Collapsing substages to main stages for statistical power")

        # Collapse substages by removing trailing A/B/C letters
        stages_clump = [
            str(stage)[:-1] if str(stage).endswith(("A", "B", "C")) else stage
            for stage in clinical_redux["ajcc_pathologic_tumor_stage"]
        ]
        clinical_redux["ajcc_pathologic_tumor_stage"] = stages_clump

        post_clump_counts = clinical_redux["ajcc_pathologic_tumor_stage"].value_counts().sort_index()
        logger.info("\n  Stage distribution after harmonization:")
        _log_stage_table(post_clump_counts)
    else:
        logger.info("  No substages detected; stage labels are already harmonized")

    # Synchronize datasets after stage filtering
    pheno_redux = pheno[pheno.index.isin(clinical_redux.index)]
    rnaseq_redux = rnaseq.loc[:, rnaseq.columns.isin(clinical_redux.index)]

    # =========================================================================
    # STEP 6: Additional cancer-specific filtering (BRCA only)
    # =========================================================================
    if cancer_type == 'BRCA':
        logger.info("\n[BRCA-SPECIFIC FILTERING: Ductal and Lobular Carcinomas]")
        ductal_lobular = ['Infiltrating Ductal Carcinoma', 'Infiltrating Lobular Carcinoma']
        patients_ductal_lobular = clinical_redux.index[
            clinical_redux["histological_type"].isin(ductal_lobular)
        ]

        clinical_redux = clinical_redux[clinical_redux.index.isin(patients_ductal_lobular)]
        pheno_redux = pheno_redux[pheno_redux.index.isin(patients_ductal_lobular)]
        rnaseq_redux = rnaseq_redux.loc[:, rnaseq_redux.columns.isin(patients_ductal_lobular)]

        logger.info(f"  Retained {len(patients_ductal_lobular):,} ductal/lobular samples")

    # =========================================================================
    # STEP 7: Optional binary classification mapping (early vs. late)
    # =========================================================================
    if early_late:
        logger.info("\n[MAPPING TO BINARY CLASSIFICATION: Early vs. Late]")
        logger.info("  Mapping scheme:")
        logger.info("    Early: Stage I, Stage II")
        logger.info("    Late:  Stage III, Stage IV")

        mapped_stages = map_stages_to_early_late(clinical_redux["ajcc_pathologic_tumor_stage"])
        # Samples whose stage has no early/late equivalent are dropped
        valid_mask = mapped_stages.notna()

        # .copy() again: the frame was re-sliced, and a column is assigned next
        clinical_redux = clinical_redux.loc[valid_mask, :].copy()
        clinical_redux["ajcc_pathologic_tumor_stage"] = mapped_stages[valid_mask]
        pheno_redux = pheno_redux[pheno_redux.index.isin(clinical_redux.index)]
        rnaseq_redux = rnaseq_redux.loc[:, rnaseq_redux.columns.isin(clinical_redux.index)]

        final_stage_counts = clinical_redux["ajcc_pathologic_tumor_stage"].value_counts().sort_index()
        logger.info("\n  Final binary classification distribution:")
        _log_stage_table(final_stage_counts)
    else:
        final_stage_counts = clinical_redux["ajcc_pathologic_tumor_stage"].value_counts().sort_index()
        logger.info("\n  Final stage distribution (4-class):")
        _log_stage_table(final_stage_counts)

    # =========================================================================
    # STEP 8: Final statistics and data saving
    # =========================================================================
    logger.info("\n[FINAL PROCESSED DATA SHAPES]")
    logger.info(f"  RNA-seq:   {rnaseq_redux.shape[0]:>6,} genes × {rnaseq_redux.shape[1]:>5,} samples")
    logger.info(f"  Clinical:  {clinical_redux.shape[0]:>6,} patients × {clinical_redux.shape[1]:>3,} features")
    logger.info(f"  Phenotype: {pheno_redux.shape[0]:>6,} samples × {pheno_redux.shape[1]:>3,} features")

    # Calculate and log filtering statistics.
    # `rnaseq` is already limited to primary tumors of the chosen cancer type
    # here, so "initial" means the count before stage-based filtering.
    initial_samples = rnaseq.shape[1]
    final_samples = rnaseq_redux.shape[1]
    retention_rate = (final_samples / initial_samples) * 100 if initial_samples > 0 else 0

    logger.info("\n[FILTERING SUMMARY]")
    logger.info(f"  Initial samples:  {initial_samples:>5,}")
    logger.info(f"  Final samples:    {final_samples:>5,}")
    logger.info(f"  Retention rate:   {retention_rate:>5.1f}%")
    logger.info(f"  Samples removed:  {initial_samples - final_samples:>5,}")

    # Save processed data to disk
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    path_rnaseq = output_dir / f"{cancer_type}_rnaseq.csv"
    path_clinical = output_dir / f"{cancer_type}_clinical.csv"
    path_phenotype = output_dir / f"{cancer_type}_phenotype.csv"

    logger.info("\n[SAVING PROCESSED DATA]")
    logger.info(f"  Output directory: {output_dir}")

    rnaseq_redux.to_csv(path_rnaseq)
    logger.info(f"  ✓ Saved RNA-seq:   {path_rnaseq.name}")

    clinical_redux.to_csv(path_clinical)
    logger.info(f"  ✓ Saved clinical:  {path_clinical.name}")

    pheno_redux.to_csv(path_phenotype)
    logger.info(f"  ✓ Saved phenotype: {path_phenotype.name}")

    logger.info("\n" + "="*80)
    logger.info(f"Data processing completed successfully for {cancer_type}")
    logger.info("="*80 + "\n")

    return rnaseq_redux, clinical_redux, pheno_redux

Example Usage:

from renalprog.dataset import process_downloaded_data
from pathlib import Path

# Inputs produced by download_data, processed for KIRC
# (Kidney Renal Clear Cell Carcinoma)
expression_file = Path("data/raw/EB%2B%2BAdjustPANCAN_IlluminaHiSeq_RNASeqV2.geneExp.xena")
survival_file = Path("data/raw/Survival_SupplementalTable_S1_20171025_xena_sp")
phenotype_file = Path("data/raw/TCGA_phenotype_denseDataOnlyDownload.tsv")

rnaseq, clinical, phenotype = process_downloaded_data(
    rnaseq_path=expression_file,
    clinical_path=survival_file,
    phenotype_path=phenotype_file,
    cancer_type="KIRC",
    output_dir=Path("data/raw"),
    early_late=False,
)

# Inspect the size of the processed tables
print(f"RNA-seq shape: {rnaseq.shape}")
print(f"Clinical shape: {clinical.shape}")

load_rnaseq_data

Load RNA-seq expression data from a file.

load_rnaseq_data

load_rnaseq_data(path: Path) -> pd.DataFrame

Load RNA-seq data from CSV file.

Args: path: Path to RNA-seq CSV file (genes as rows, samples as columns)

Returns: DataFrame with RNA-seq data

Source code in renalprog/dataset.py
def load_rnaseq_data(path: Path) -> pd.DataFrame:
    """
    Read an RNA-seq expression table from a CSV file.

    The first column of the file is used as the row index.

    Args:
        path: Location of the RNA-seq CSV file (genes as rows, samples as columns)

    Returns:
        The expression table as a DataFrame
    """
    logger.info(f"Loading RNA-seq data from {path}")
    expression = pd.read_csv(path, index_col=0)
    # Shape is logged so a wrongly parsed file is easy to spot in the logs
    logger.info(f"Loaded RNA-seq data with shape: {expression.shape}")
    return expression

load_clinical_data

Load clinical and survival data from a file.

load_clinical_data

load_clinical_data(
    path: Path,
    stage_column: str = "ajcc_pathologic_tumor_stage",
    early_late=True,
) -> pd.Series

Load clinical metadata.

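Args: path: Path to clinical data CSV file stage_column: Name of column containing stage information early_late: If True (default), map the stage labels to binary early/late categories with map_stages_to_early_late; if False, return the labels as stored in the file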

Returns: Series with clinical stages indexed by sample ID

Source code in renalprog/dataset.py
def load_clinical_data(path: Path, stage_column: str = "ajcc_pathologic_tumor_stage", early_late: bool = True) -> pd.Series:
    """
    Load the tumor-stage labels from a clinical metadata CSV file.

    Args:
        path: Path to clinical data CSV file (first column is used as the index)
        stage_column: Name of column containing stage information
        early_late: If True (default), pass the stage labels through
            map_stages_to_early_late; labels without a mapping become NaN.
            If False, return the labels exactly as stored in the file.

    Returns:
        Series with clinical stages indexed by sample ID

    Raises:
        ValueError: If ``stage_column`` is not a column of the file
    """
    logger.info(f"Loading clinical data from {path}")
    data = pd.read_csv(path, index_col=0)

    if stage_column not in data.columns:
        raise ValueError(f"Stage column '{stage_column}' not found in clinical data")

    stages = data[stage_column]
    logger.info(f"Loaded clinical data with {len(stages)} samples")

    if early_late:
        # NOTE(review): a file written by process_downloaded_data(early_late=True)
        # already holds binary labels; whether the stage mapping accepts those
        # values is not visible here - confirm before mapping such a file again.
        stages = map_stages_to_early_late(stages)
        logger.info("Mapped stages to binary early/late classification")
    logger.info(f"Stage distribution:\n{stages.value_counts()}")

    return stages

create_train_test_split

Create stratified train/test splits of the data.

create_train_test_split

create_train_test_split(
    rnaseq_path: Path,
    clinical_path: Path,
    test_size: float = 0.2,
    seed: int = 2023,
    use_onehot: bool = True,
    output_dir: Optional[Path] = None,
) -> Tuple[
    pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray, pd.DataFrame, pd.Series
]

Create stratified train/test split of KIRC data.

Args: rnaseq_path: Path to RNA-seq CSV file clinical_path: Path to clinical CSV file test_size: Fraction of data to use for testing (default: 0.2) seed: Random seed for reproducibility (default: 2023) use_onehot: Whether to one-hot encode the labels (default: True) output_dir: Optional directory to save split data

Returns: Tuple of (X_train, X_test, y_train, y_test, full_rnaseq, full_clinical)

Source code in renalprog/dataset.py
def create_train_test_split(
    rnaseq_path: Path,
    clinical_path: Path,
    test_size: float = 0.2,
    seed: int = 2023,
    use_onehot: bool = True,
    output_dir: Optional[Path] = None
) -> Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray, pd.DataFrame, pd.Series]:
    """
    Create stratified train/test split of KIRC data.

    Both files are loaded inside this function (via load_rnaseq_data and
    load_clinical_data), so file paths are expected, not DataFrames. Samples
    are restricted to those present in both tables and having a stage label.

    Args:
        rnaseq_path: Path to RNA-seq CSV file
        clinical_path: Path to clinical CSV file
        test_size: Fraction of data to use for testing (default: 0.2)
        seed: Random seed for reproducibility (default: 2023)
        use_onehot: Whether to one-hot encode the labels (default: True)
        output_dir: Optional directory to save split data

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, full_rnaseq, full_clinical).
        X_train/X_test have samples as rows; full_rnaseq keeps the original
        genes × samples orientation. y_train/y_test are arrays (one-hot
        encoded when ``use_onehot`` is True).

    Raises:
        ValueError: If no sample has both expression data and a stage label
    """
    set_seed(seed)

    # Load data
    rnaseq = load_rnaseq_data(rnaseq_path)
    clinical = load_clinical_data(clinical_path)

    # Ensure samples match between RNA-seq and clinical data
    common_samples = rnaseq.columns.intersection(clinical.index)
    if len(common_samples) < len(rnaseq.columns):
        logger.warning(
            f"Only {len(common_samples)} of {len(rnaseq.columns)} samples "
            f"have clinical data. Filtering to common samples."
        )

    rnaseq = rnaseq[common_samples]
    clinical = clinical.loc[common_samples]

    # Stage labels that could not be mapped arrive here as NaN. They cannot be
    # sorted together with string labels nor used for stratification, so the
    # affected samples are dropped explicitly instead of failing further down.
    has_label = clinical.notna()
    if not has_label.all():
        logger.warning(
            f"Dropping {int((~has_label).sum())} samples without a stage label"
        )
        clinical = clinical[has_label]
        rnaseq = rnaseq.loc[:, clinical.index]

    if clinical.empty:
        raise ValueError(
            "No samples with both RNA-seq data and a stage label are available"
        )

    # Transpose to have samples as rows
    rnaseq_t = rnaseq.T

    # Prepare labels for stratification
    categories = None
    if use_onehot:
        # Get unique stages in sorted order for consistent encoding
        categories = sorted(clinical.unique())
        ohe = OneHotEncoder(
            categories=[categories],
            handle_unknown='ignore',
            sparse_output=False,
            dtype=np.int8
        )
        y = ohe.fit_transform(clinical.values.reshape(-1, 1))
        logger.info(f"One-hot encoded labels with {y.shape[1]} classes")
    else:
        y = clinical.values

    # Perform stratified split
    X_train, X_test, y_train, y_test = train_test_split(
        rnaseq_t,
        y,
        test_size=test_size,
        stratify=y if use_onehot else clinical,
        random_state=seed
    )

    logger.info(f"Train set: {X_train.shape[0]} samples")
    logger.info(f"Test set: {X_test.shape[0]} samples")

    # Save split data
    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        X_train.to_csv(output_dir / "X_train.csv")
        X_test.to_csv(output_dir / "X_test.csv")

        # Save labels WITH patient IDs as indices (matching X_train and X_test)
        if use_onehot:
            # Create DataFrame with patient IDs as index
            y_train_df = pd.DataFrame(y_train, index=X_train.index)
            y_test_df = pd.DataFrame(y_test, index=X_test.index)
            y_train_df.to_csv(output_dir / "y_train.csv")
            y_test_df.to_csv(output_dir / "y_test.csv")
        else:
            # Create Series with patient IDs as index
            y_train_series = pd.Series(y_train, index=X_train.index)
            y_test_series = pd.Series(y_test, index=X_test.index)
            y_train_series.to_csv(output_dir / "y_train.csv")
            y_test_series.to_csv(output_dir / "y_test.csv")

        # Save full data for reference
        rnaseq.to_csv(output_dir / "data.csv")
        clinical.to_csv(output_dir / "metadata.csv")

        # Save split statistics (categories is None when labels are not one-hot)
        _save_split_statistics(clinical, y_train, y_test, output_dir, use_onehot, categories)

        logger.info(f"Saved train/test split to {output_dir}")

    return X_train, X_test, y_train, y_test, rnaseq, clinical

Example Usage:

from renalprog.dataset import create_train_test_split
from pathlib import Path

# create_train_test_split loads both CSV files itself, so it takes file paths
# (the CSV files written by process_downloaded_data), not loaded DataFrames
X_train, X_test, y_train, y_test, rnaseq, clinical = create_train_test_split(
    rnaseq_path=Path("data/raw/KIRC_rnaseq.csv"),
    clinical_path=Path("data/raw/KIRC_clinical.csv"),
    test_size=0.2,
    seed=42,
    output_dir=Path("data/interim/my_experiment")
)

load_train_test_split

Load previously saved train/test split data.

load_train_test_split

load_train_test_split(
    split_dir: Path,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]

Load previously saved train/test split.

Args: split_dir: Directory containing saved split files

Returns: Tuple of (X_train, X_test, y_train, y_test)

Source code in renalprog/dataset.py
def load_train_test_split(split_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Read a train/test split back from the CSV files in a directory.

    Args:
        split_dir: Directory holding X_train.csv, X_test.csv, y_train.csv
            and y_test.csv

    Returns:
        Tuple of (X_train, X_test, y_train, y_test), each read with its first
        column as the index
    """
    split_dir = Path(split_dir)

    # Same four files, read in the same order, with the first column as index
    X_train, X_test, y_train, y_test = (
        pd.read_csv(split_dir / csv_name, index_col=0)
        for csv_name in ("X_train.csv", "X_test.csv", "y_train.csv", "y_test.csv")
    )

    logger.info(f"Loaded train/test split from {split_dir}")
    logger.info(f"Train: {X_train.shape}, Test: {X_test.shape}")

    return X_train, X_test, y_train, y_test

Example Usage:

from renalprog.dataset import load_train_test_split
from pathlib import Path

# Read back a split that create_train_test_split saved earlier
split_dir = Path("data/interim/my_experiment")
train_expr, test_expr, train_labels, test_labels = load_train_test_split(split_dir)

# Rows are samples, so len() gives the number of samples in each part
print(f"Training samples: {len(train_expr)}")
print(f"Test samples: {len(test_expr)}")

map_stages_to_early_late

Map cancer stages to binary early/late categories.

map_stages_to_early_late

map_stages_to_early_late(stages: Series) -> pd.Series

Map detailed stages (I, II, III, IV) to binary early/late classification.

Args: stages: Series with stage labels (e.g., "Stage I", "Stage II", etc.)

Returns: Series with mapped stages ("early" or "late")

Source code in renalprog/dataset.py
def map_stages_to_early_late(stages: pd.Series) -> pd.Series:
    """
    Collapse detailed stage labels into the binary early/late grouping.

    Args:
        stages: Series of stage labels (e.g., "Stage I", "Stage II", etc.),
            translated through PreprocessingConfig.STAGE_MAPPING

    Returns:
        Series with the mapped labels ("early" or "late"). Labels that were
        present in the input but came out missing after mapping are logged
        as warnings and left as-is in the result.
    """
    binary_stages = stages.map(PreprocessingConfig.STAGE_MAPPING)

    # A value that was present in the input but is missing after mapping
    # did not get translated.
    no_match = stages.notna() & binary_stages.isna()
    if not no_match.any():
        return binary_stages

    logger.warning(f"Found {no_match.sum()} unmapped stage values:")
    logger.warning(stages[no_match].unique())
    return binary_stages

Data Format Requirements

RNA-seq Expression Data

Expected format:

- Rows: Genes (with gene symbols or Ensembl IDs)
- Columns: Samples (patient IDs)
- Values: Log2-transformed TPM or FPKM expression values

# Example structure
#              TCGA-A1-A0SB  TCGA-A1-A0SD  TCGA-A1-A0SE  ...
# GENE_A       5.234         4.891         6.123         ...
# GENE_B       2.456         2.789         2.634         ...
# GENE_C       8.912         9.234         8.756         ...

Clinical Data

Expected columns:

- sample: Patient ID matching expression columns
- OS: Overall survival status (0=alive, 1=deceased)
- OS.time: Overall survival time (days)

Optional columns:

- age_at_initial_pathologic_diagnosis: Age at diagnosis
- gender: Patient gender
- tumor_stage: Tumor stage

# Example structure
#    sample          OS  OS.time  age  gender  stage
# 0  TCGA-A1-A0SB    1   1825     65   MALE    IV
# 1  TCGA-A1-A0SD    0   2190     58   FEMALE  II

Train/Test Splitting

The module provides stratified splitting to maintain class balance:

from pathlib import Path

import pandas as pd

from renalprog.dataset import (
    load_rnaseq_data,
    load_clinical_data,
    create_train_test_split,
    load_train_test_split,  # used below to read the split back
)

# Load the data
rnaseq = load_rnaseq_data(Path("data/raw/KIRC_rnaseq.tsv"))
clinical = load_clinical_data(Path("data/raw/KIRC_clinical.tsv"))

# Create stratified split preserving early/late survival distribution
create_train_test_split(
    rnaseq=rnaseq,
    clinical=clinical,
    test_size=0.2,  # 20% test set
    random_state=42,  # For reproducibility
    output_dir=Path("data/interim/my_split")
)

# Load the split data (same directory that was passed as output_dir above)
train_expr, test_expr, train_clin, test_clin = load_train_test_split(
    Path("data/interim/my_split")
)

# Check class distribution: relative frequencies in each partition
train_dist = train_clin.value_counts(normalize=True)
test_dist = test_clin.value_counts(normalize=True)

print("Training set distribution:")
print(train_dist)
print("\nTest set distribution:")
print(test_dist)

Data Preprocessing Pipeline

The standard preprocessing pipeline:

  1. Download raw data from TCGA
  2. Filter by cancer type (e.g., KIRC)
  3. Filter low expression genes (see Features API)
  4. Remove outliers using Mahalanobis distance
  5. Create train/test split with stratification
  6. Normalize using MinMaxScaler (0-1 range)
  7. Save preprocessed data for modeling

Complete Example:

from pathlib import Path
from renalprog.dataset import (
    download_data, 
    process_downloaded_data,
    load_rnaseq_data,
    load_clinical_data,
    create_train_test_split,
    load_train_test_split
)
from renalprog.features import filter_low_expression, detect_outliers_mahalanobis

# Step 1: Download
rnaseq_path, clinical_path, phenotype_path = download_data(
    destination=Path("data/raw")
)

# Step 2: Process for KIRC
rnaseq, clinical, _ = process_downloaded_data(
    rnaseq_path=rnaseq_path,
    clinical_path=clinical_path,
    phenotype_path=phenotype_path,
    cancer_type="KIRC",
    output_dir=Path("data/raw")
)

# Step 3: Filter low expression
rnaseq_filtered = filter_low_expression(
    rnaseq,
    mean_threshold=0.5,
    var_threshold=0.5
)

# Step 4: Remove outliers
rnaseq_clean, outliers, _ = detect_outliers_mahalanobis(
    rnaseq_filtered,
    alpha=0.05
)

# Step 5: Create train/test split
rnaseq_clean_path = Path("data/interim/rnaseq_clean.csv")
clinical_path = Path("data/raw/KIRC_clinical.tsv")
# to_csv does not create missing parent directories, so make sure
# data/interim exists before writing the cleaned matrix.
rnaseq_clean_path.parent.mkdir(parents=True, exist_ok=True)
rnaseq_clean.to_csv(rnaseq_clean_path)

# Load and split
rnaseq_final = load_rnaseq_data(rnaseq_clean_path)
clinical_final = load_clinical_data(clinical_path)

split_dir = Path("data/interim/20251218_experiment")
create_train_test_split(
    rnaseq=rnaseq_final,
    clinical=clinical_final,
    test_size=0.2,
    random_state=42,
    output_dir=split_dir
)

# Read the saved split back; order is (X_train, X_test, y_train, y_test)
X_train, X_test, y_train, y_test = load_train_test_split(split_dir)

See Also