diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 0b252d6f..2e1e1a8c 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -183,6 +183,9 @@ def _run_cli(args: argparse.Namespace) -> None: feature_dir=config.deployment.feature_dir, patient_label=config.deployment.patient_label, filename_label=config.deployment.filename_label, + drop_patients_with_missing_ground_truth=( + config.deployment.drop_patients_with_missing_ground_truth + ), num_workers=config.deployment.num_workers, accelerator=config.deployment.accelerator, ground_truth_label=config.deployment.ground_truth_label, diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 7ebdf9d6..3e50b986 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -79,6 +79,8 @@ crossval: # For multi-target classification you may specify a list of columns, # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] + # drop_patients_with_missing_ground_truth: true + # For survival (should be status and follow-up days columns in clini table) # status_label: "event" # time_label: "time" @@ -135,6 +137,8 @@ training: # For multi-target classification you may specify a list of columns, # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] + # drop_patients_with_missing_ground_truth: true + # For survival (should be status and follow-up days columns in clini table) # status_label: "event" # time_label: "time" @@ -179,6 +183,8 @@ deployment: # For multi-target classification you may specify a list of columns, # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] + # drop_patients_with_missing_ground_truth: true + # For survival (should be status and follow-up days columns in clini table) # status_label: "event" # time_label: "time" diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 5b9a6bcc..e10f1899 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -8,6 +8,11 @@ from stamp.modeling.registry import ModelName from stamp.types import Category, PandasLabel, Task +_DROP_PATIENTS_WITH_MISSING_GROUND_TRUTH_DESCRIPTION = ( + "If true, only patients present in the clinical table are included. " + "Set to false to keep patients without ground truth when the task supports it." +) + class TrainConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -36,6 +41,10 @@ class TrainConfig(BaseModel): default=None, description="Column in the clinical table indicating follow-up or survival time (e.g. days).", ) + drop_patients_with_missing_ground_truth: bool = Field( + default=True, + description=_DROP_PATIENTS_WITH_MISSING_GROUND_TRUTH_DESCRIPTION, + ) patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" @@ -71,6 +80,10 @@ class DeploymentConfig(BaseModel): # For survival prediction status_label: PandasLabel | None = None time_label: PandasLabel | None = None + drop_patients_with_missing_ground_truth: bool = Field( + default=True, + description=_DROP_PATIENTS_WITH_MISSING_GROUND_TRUTH_DESCRIPTION, + ) num_workers: int = min(os.cpu_count() or 1, 16) accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 4ee71563..36770fa4 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -64,7 +64,9 @@ def categorical_crossval_( status_label=config.status_label, patient_label=config.patient_label, filename_label=config.filename_label, - drop_patients_with_missing_ground_truth=True, + drop_patients_with_missing_ground_truth=( + config.drop_patients_with_missing_ground_truth + ), ) _logger.info(f"Detected feature type: {feature_type}") diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 6d81a8fc..8f7dcfe1 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -72,6 +72,7 @@ def deploy_categorical_model_( filename_label: PandasLabel, num_workers: int, accelerator: str | Accelerator, + drop_patients_with_missing_ground_truth: bool = True, ) -> None: """Deploy categorical model(s) and save predictions. @@ -230,7 +231,7 @@ def deploy_categorical_model_( patient_to_ground_truth, ), slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=False, + drop_patients_with_missing_ground_truth=drop_patients_with_missing_ground_truth, ) patient_ids = list(patient_to_data.keys()) diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index 61944624..81287908 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -63,7 +63,9 @@ def train_categorical_model_( status_label=config.status_label, patient_label=config.patient_label, filename_label=config.filename_label, - drop_patients_with_missing_ground_truth=True, + drop_patients_with_missing_ground_truth=( + config.drop_patients_with_missing_ground_truth + ), ) _logger.info(f"Detected feature type: {feature_type}") diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index b3243ecc..0a7eedef 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -229,19 +229,27 @@ def compute_stats_( ) preds_dfs = [ - _read_table( - p, - usecols=[ - ground_truth_label, - f"{ground_truth_label}_{true_class}", - ], - dtype={ - ground_truth_label: str, - f"{ground_truth_label}_{true_class}": float, - }, - ) + df for p in pred_csvs + if len( + df := _read_table( + p, + usecols=[ + ground_truth_label, + f"{ground_truth_label}_{true_class}", + ], + dtype={ + ground_truth_label: str, + f"{ground_truth_label}_{true_class}": float, + }, + ).dropna(subset=[ground_truth_label]) + ) + > 0 ] + if not preds_dfs: + raise ValueError( + "No classification rows with ground truth available for plotting." + ) y_trues = [ np.array(df[ground_truth_label] == true_class) for df in preds_dfs diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index e19b1659..a267f5ca 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -125,13 +125,17 @@ def categorical_aggregated_( calculate the mean and 95% confidence interval for all the scores as well as sum the total instane count for each class. """ - preds_dfs = { - Path(p).parent.name: _categorical( - pd.read_csv(p, dtype=str).dropna(subset=[ground_truth_label]), - ground_truth_label, + preds_dfs = {} + for p in preds_csvs: + df = pd.read_csv(p, dtype=str).dropna(subset=[ground_truth_label]) + if len(df) > 0: + preds_dfs[Path(p).parent.name] = _categorical(df, ground_truth_label) + + if not preds_dfs: + raise ValueError( + "No classification rows with ground truth available for statistics." ) - for p in preds_csvs - } + preds_df = pd.concat(preds_dfs).sort_index() preds_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_individual.csv") stats_df = _aggregate_categorical_stats(preds_df.reset_index())