Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
from .utils import (
DiffusionPrepareBatch,
GradientAccumulation,
IterationEvents,
PrepareBatch,
PrepareBatchDefault,
Expand Down
120 changes: 120 additions & 0 deletions monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"default_make_latent",
"engine_apply_transform",
"default_metric_cmp_fn",
"GradientAccumulation",
]


Expand Down Expand Up @@ -360,3 +361,122 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:

"""
return current_metric > prev_best


def _noop(*args: Any, **kwargs: Any) -> None:
"""No-op callable used to suppress optimizer/scaler methods during gradient accumulation."""


class GradientAccumulation:
    """
    Callable class implementing gradient accumulation for use with ``SupervisedTrainer``.

    Gradients are accumulated over ``accumulation_steps`` mini-batches before calling
    ``optimizer.step()``, simulating a larger effective batch size on memory-constrained
    hardware.

    Pass an instance as ``iteration_update`` when constructing ``SupervisedTrainer``::

        trainer = SupervisedTrainer(
            ...,
            iteration_update=GradientAccumulation(accumulation_steps=4),
        )

    All ``IterationEvents`` (``FORWARD_COMPLETED``, ``LOSS_COMPLETED``,
    ``BACKWARD_COMPLETED``, ``MODEL_COMPLETED``) still fire on every mini-batch, so
    existing handlers (checkpoint savers, metric loggers, etc.) are unaffected.

    When ``epoch_length`` is known, the optimizer is flushed at the end of each epoch
    even if ``epoch_length % accumulation_steps != 0``, so no gradients are silently
    discarded. For iterable datasets (``epoch_length=None``) this flush does not apply.

    The loss stored in ``engine.state.output[CommonKeys.LOSS]`` is the **unscaled**
    original loss value, so metrics and loggers report the true loss. Internally
    the loss is divided by ``accumulation_steps`` for the backward pass only.

    Args:
        accumulation_steps: number of mini-batches to accumulate before updating
            weights. Must be a positive integer (booleans are rejected). Default: 2.

    Raises:
        ValueError: when ``accumulation_steps`` is not a positive integer.
    """

    def __init__(self, accumulation_steps: int = 2) -> None:
        # `bool` is a subclass of `int`, so `True` would otherwise pass validation and
        # silently behave as accumulation_steps=1 — reject booleans explicitly.
        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
            raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
        self.accumulation_steps = accumulation_steps

    def __repr__(self) -> str:
        return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"

    def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict:
        """
        Execute one iteration with gradient accumulation.

        Args:
            engine: the Ignite engine (usually ``SupervisedTrainer``).
            batchdata: batch data for this iteration.

        Returns:
            the output dict from ``engine._iteration()``.
        """
        acc = self.accumulation_steps

        result: dict

        # Fast path: no accumulation requested, run the standard iteration untouched.
        if acc == 1:
            result = engine._iteration(engine, batchdata)
            return result

        # engine.state.iteration is 1-indexed and already incremented before __call__
        epoch_length = engine.state.epoch_length  # None for iterable datasets
        if epoch_length is not None:
            local_iter = (engine.state.iteration - 1) % epoch_length  # 0-indexed within epoch
            should_zero_grad = local_iter % acc == 0
            # Flush at epoch end so a partial final cycle is not silently discarded.
            should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
        else:
            local_iter = engine.state.iteration - 1  # 0-indexed global
            should_zero_grad = local_iter % acc == 0
            should_step = (local_iter + 1) % acc == 0

        # Save and conditionally suppress zero_grad. Only clear gradients at the start of an accumulation cycle.
        original_zero_grad = engine.optimizer.zero_grad
        if not should_zero_grad:
            engine.optimizer.zero_grad = _noop

        # Save and wrap loss_function to scale by 1/accumulation_steps. This ensures the per-mini-batch
        # gradient contribution is correct: the scaled loss will be backpropagated, and accumulated gradients
        # will average to the same value they would with the full batch.
        original_loss_fn = engine.loss_function
        engine.loss_function = lambda *args, **kwargs: original_loss_fn(*args, **kwargs) / acc

        # Save and conditionally suppress optimizer.step. Only update weights at the end of an accumulation cycle.
        # Also patch GradScaler.step and GradScaler.update when step is suppressed, for mixed-precision training.
        original_step = engine.optimizer.step
        original_scaler_step = None
        original_scaler_update = None
        if not should_step:
            engine.optimizer.step = _noop
            if hasattr(engine, "scaler") and engine.scaler is not None:
                original_scaler_step = engine.scaler.step
                original_scaler_update = engine.scaler.update
                engine.scaler.step = _noop
                engine.scaler.update = _noop

        try:
            result = engine._iteration(engine, batchdata)
        finally:
            # Always restore the patched methods, even if the iteration raised.
            engine.optimizer.zero_grad = original_zero_grad
            engine.loss_function = original_loss_fn
            engine.optimizer.step = original_step
            if original_scaler_step is not None:
                engine.scaler.step = original_scaler_step
                engine.scaler.update = original_scaler_update

        # Restore the unscaled loss for logging and metrics. The backward pass
        # already used the scaled value, so this only affects what handlers see.
        if CommonKeys.LOSS in result:
            result[CommonKeys.LOSS] = result[CommonKeys.LOSS] * acc

        return result
Loading
Loading