Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/stamp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def _run_cli(args: argparse.Namespace) -> None:
feature_dir=config.deployment.feature_dir,
patient_label=config.deployment.patient_label,
filename_label=config.deployment.filename_label,
drop_patients_with_missing_ground_truth=(
config.deployment.drop_patients_with_missing_ground_truth
),
num_workers=config.deployment.num_workers,
accelerator=config.deployment.accelerator,
ground_truth_label=config.deployment.ground_truth_label,
Expand Down
6 changes: 6 additions & 0 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ crossval:
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# drop_patients_with_missing_ground_truth: true

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
# time_label: "time"
Expand Down Expand Up @@ -135,6 +137,8 @@ training:
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# drop_patients_with_missing_ground_truth: true

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
# time_label: "time"
Expand Down Expand Up @@ -179,6 +183,8 @@ deployment:
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# drop_patients_with_missing_ground_truth: true

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
# time_label: "time"
Expand Down
13 changes: 13 additions & 0 deletions src/stamp/modeling/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from stamp.modeling.registry import ModelName
from stamp.types import Category, PandasLabel, Task

_DROP_PATIENTS_WITH_MISSING_GROUND_TRUTH_DESCRIPTION = (
"If true, only patients present in the clinical table are included. "
"Set to false to keep patients without ground truth when the task supports it."
)


class TrainConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
Expand Down Expand Up @@ -36,6 +41,10 @@ class TrainConfig(BaseModel):
default=None,
description="Column in the clinical table indicating follow-up or survival time (e.g. days).",
)
drop_patients_with_missing_ground_truth: bool = Field(
default=True,
description=_DROP_PATIENTS_WITH_MISSING_GROUND_TRUTH_DESCRIPTION,
)

patient_label: PandasLabel = "PATIENT"
filename_label: PandasLabel = "FILENAME"
Expand Down Expand Up @@ -71,6 +80,10 @@ class DeploymentConfig(BaseModel):
# For survival prediction
status_label: PandasLabel | None = None
time_label: PandasLabel | None = None
drop_patients_with_missing_ground_truth: bool = Field(
default=True,
description=_DROP_PATIENTS_WITH_MISSING_GROUND_TRUTH_DESCRIPTION,
)

num_workers: int = min(os.cpu_count() or 1, 16)
accelerator: str = "gpu" if torch.cuda.is_available() else "cpu"
Expand Down
4 changes: 3 additions & 1 deletion src/stamp/modeling/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def categorical_crossval_(
status_label=config.status_label,
patient_label=config.patient_label,
filename_label=config.filename_label,
drop_patients_with_missing_ground_truth=True,
drop_patients_with_missing_ground_truth=(
config.drop_patients_with_missing_ground_truth
),
)
_logger.info(f"Detected feature type: {feature_type}")

Expand Down
3 changes: 2 additions & 1 deletion src/stamp/modeling/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def deploy_categorical_model_(
filename_label: PandasLabel,
num_workers: int,
accelerator: str | Accelerator,
drop_patients_with_missing_ground_truth: bool = True,
) -> None:
"""Deploy categorical model(s) and save predictions.

Expand Down Expand Up @@ -230,7 +231,7 @@ def deploy_categorical_model_(
patient_to_ground_truth,
),
slide_to_patient=slide_to_patient,
drop_patients_with_missing_ground_truth=False,
drop_patients_with_missing_ground_truth=drop_patients_with_missing_ground_truth,
)

patient_ids = list(patient_to_data.keys())
Expand Down
4 changes: 3 additions & 1 deletion src/stamp/modeling/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def train_categorical_model_(
status_label=config.status_label,
patient_label=config.patient_label,
filename_label=config.filename_label,
drop_patients_with_missing_ground_truth=True,
drop_patients_with_missing_ground_truth=(
config.drop_patients_with_missing_ground_truth
),
)
_logger.info(f"Detected feature type: {feature_type}")

Expand Down
30 changes: 19 additions & 11 deletions src/stamp/statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,27 @@ def compute_stats_(
)

preds_dfs = [
_read_table(
p,
usecols=[
ground_truth_label,
f"{ground_truth_label}_{true_class}",
],
dtype={
ground_truth_label: str,
f"{ground_truth_label}_{true_class}": float,
},
)
df
for p in pred_csvs
if len(
df := _read_table(
p,
usecols=[
ground_truth_label,
f"{ground_truth_label}_{true_class}",
],
dtype={
ground_truth_label: str,
f"{ground_truth_label}_{true_class}": float,
},
).dropna(subset=[ground_truth_label])
)
> 0
]
if not preds_dfs:
raise ValueError(
"No classification rows with ground truth available for plotting."
)

y_trues = [
np.array(df[ground_truth_label] == true_class) for df in preds_dfs
Expand Down
16 changes: 10 additions & 6 deletions src/stamp/statistics/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,17 @@ def categorical_aggregated_(
calculate the mean and 95% confidence interval for all the scores as
well as sum the total instane count for each class.
"""
preds_dfs = {
Path(p).parent.name: _categorical(
pd.read_csv(p, dtype=str).dropna(subset=[ground_truth_label]),
ground_truth_label,
preds_dfs = {}
for p in preds_csvs:
df = pd.read_csv(p, dtype=str).dropna(subset=[ground_truth_label])
if len(df) > 0:
preds_dfs[Path(p).parent.name] = _categorical(df, ground_truth_label)

if not preds_dfs:
raise ValueError(
"No classification rows with ground truth available for statistics."
)
for p in preds_csvs
}

preds_df = pd.concat(preds_dfs).sort_index()
preds_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_individual.csv")
stats_df = _aggregate_categorical_stats(preds_df.reset_index())
Expand Down
Loading