Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 70 additions & 59 deletions tests/pytorch/distributed/run_fsdp2_fp8_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,43 +23,48 @@
from pathlib import Path

class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size, use_fsdp2=False):
def __init__(self, input_size, hidden_size, output_size, use_fsdp2=False, linear_only=False):
super(SimpleNet, self).__init__()

# LayerNormLinear: fuses LayerNorm + Linear
self.ln_linear = te.LayerNormLinear(
in_features=input_size,
out_features=hidden_size,
eps=1e-5,
use_fsdp2=use_fsdp2,
keep_fp8_weight_transpose_cache=False
)

# LayerNormMLP: fuses LayerNorm + FC1 + Activation + FC2
self.ln_mlp = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=hidden_size * 4, # Typical 4x expansion
use_fsdp2=use_fsdp2,
keep_fp8_weight_transpose_cache=False
)

self.linear_only = linear_only
if not linear_only:
# LayerNormLinear: fuses LayerNorm + Linear
self.ln_linear = te.LayerNormLinear(
in_features=input_size,
out_features=hidden_size,
eps=1e-5,
use_fsdp2=use_fsdp2,
keep_fp8_weight_transpose_cache=False
)

# LayerNormMLP: fuses LayerNorm + FC1 + Activation + FC2
self.ln_mlp = te.LayerNormMLP(
hidden_size=hidden_size,
ffn_hidden_size=hidden_size * 4, # Typical 4x expansion
use_fsdp2=use_fsdp2,
keep_fp8_weight_transpose_cache=False
)

# Regular Linear for final projection
self.fc_out = te.Linear(
hidden_size,
output_size,
use_fsdp2=use_fsdp2,
keep_fp8_weight_transpose_cache=False
keep_fp8_weight_transpose_cache=False,
bias=False
)

def forward(self, x):
# LayerNormLinear: applies LayerNorm then Linear
x = self.ln_linear(x)

# LayerNormMLP: applies LayerNorm + FC1 + GELU + FC2
x = self.ln_mlp(x)

# Final Linear projection
x = self.fc_out(x)
if self.linear_only:
return self.fc_out(x)
else:
# LayerNormLinear: applies LayerNorm then Linear
x = self.ln_linear(x)

# LayerNormMLP: applies LayerNorm + FC1 + GELU + FC2
x = self.ln_mlp(x)

# Final Linear projection
x = self.fc_out(x)

return x

Expand All @@ -86,8 +91,12 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--linear-only", action="store_true", default=False, help="Only use Linear layer")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
"--quantized-init", action="store_true", default=False, help="Initialize primary weights in FP8 via quantized_model_init."
)
parser.add_argument(
"--autocast", action="store_true", default=False, help="Enable te.autocast for FP8 compute."
)
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
Expand Down Expand Up @@ -169,15 +178,30 @@ def _train(args):

if args.memory_profile:
torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all')
if args.fp8_init:
# Build the model with the specified context
with quantized_model_init(enabled=True):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
else:
model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
# Move the model to the correct device
if not args.memory_profile:
model.load_state_dict(torch.load('fsdp_model.pth'))

prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
):
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()

# Build the model with the specified context
with quantized_model_init(enabled=args.quantized_init, recipe=fp8_recipe):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2, linear_only=args.linear_only)
model.to(device)

# Creating a DeviceMesh for fully_shard
Expand Down Expand Up @@ -215,7 +239,10 @@ def _train(args):
else:
model = DDP(model, device_ids=[LOCAL_RANK])

optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)
if args.quantized_init:
optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3, master_weights=True)
else:
optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)

input_path = Path("shared_input.pt")
if input_path.exists():
Expand All @@ -226,25 +253,6 @@ def _train(args):
print("Generated and saved shared input tensor.")

out_tensors = []
prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
):
prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()
for iteration in range(args.iter):
if LOCAL_RANK == 0:
print(f"Starting iteration...{iteration}")
Expand All @@ -253,7 +261,7 @@ def _train(args):

# Zero the parameter gradients
optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does `with te.fp8_autocast(enabled=args.fp8_autocast, ...)` do the same?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does the same, but since TE v2.10 replaces `te.fp8_autocast` with `te.autocast`, I've switched to `te.autocast` to be consistent.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So will `with te.autocast(enabled=args.fp8_autocast, recipe=...)` do the same as the if/else?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should. I'll make the changes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

with te.autocast(enabled=args.autocast, recipe=fp8_recipe):
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
Expand Down Expand Up @@ -286,6 +294,9 @@ def _train(args):
torch.save(out_tensors, args.gradients_save_file)

if args.memory_profile:
with open('memory_summary.txt', 'w') as f:
f.write(torch.cuda.memory_summary(device=None, abbreviated=False))

snapshot = torch.cuda.memory._snapshot()
import pickle
with open('memory_snapshot.pickle', 'wb') as f:
Expand Down
110 changes: 80 additions & 30 deletions tests/pytorch/distributed/test_torch_fsdp2_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,126 @@
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
import torch
from run_fsdp2_fp8_model import SimpleNet

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

NUM_PROCS: int = torch.cuda.device_count()

def assertEqual(
l1: List[torch.Tensor], l2: List[torch.Tensor]) -> bool:
"""Ensures two lists are exactly equal."""
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
tols = dict(atol=atol)
tols["rtol"] = rtol if rtol is not None else 0
for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=0, rtol=0)
tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
exceed_mask = diff > 0
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
if diff.dim() == 0:
max_diff = diff
max_location = []
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
f"Outputs not close enough in scalar tensor at idx={i}. "
f"Difference: {max_diff.item()}."
)
else:
exceed_mask = diff > tol

if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg)

def _run_test(fp_init, recipe):
def _run_test(quantized_init, autocast, recipe):
test_dir = Path(__file__).parent.resolve()
fsdp_script = test_dir / "run_fsdp2_fp8_model.py"

test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)]

if fp_init:
test_cmd += ["--fp8-init"]
test_cmd += ["--recipe", recipe]
if quantized_init:
test_cmd += ["--quantized-init"]
if autocast:
test_cmd += ["--autocast"]
if autocast or quantized_init:
test_cmd += ["--recipe", recipe]
if quantized_init:
test_cmd += ["--linear-only"]

subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True)
subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True)

# Load outputs
output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu")
output_dp = torch.load("all_iters_dp.pt", map_location="cpu")
atol = 0
rtol = 0
# Use relaxed tolerance when FSDP2 and DDP are not guaranteed to be bit-identical:
#
# - No FP8 (quantized_init=False, autocast=False): gradient reduction order differs
# (all-reduce vs reduce-scatter), so float non-associativity produces last-bit
# differences in the reduced gradients and updated weights.
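#
# The FP8 configurations described below keep atol = rtol = 0, i.e. FSDP2 and DDP
# outputs are expected to match exactly:
#
# quantized_init=True + autocast=True uses a Linear-only model (--linear-only)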
# with bias=False to ensure all parameters are Float8Tensors. This avoids a
# known issue where PyTorch DDP's _broadcast_coalesced concatenates FP8 and
# FP32 parameters via aten::cat, triggering Float8Tensor dequantization (since
# Float8Tensor doesn't natively handle aten::cat). The subsequent aten::copy_
# back into the Float8Tensor re-quantizes from the dequantized values, which
# recomputes amax/scale with FP8 round-trip error, causing divergence from
# FSDP2. With bias=False and all params in FP8, aten::cat operates on
# homogeneous Float8Tensors and no dequantization occurs.
#
# autocast=True alone (quantized_init=False) works without any modifications
# because weights are initialized as regular FP32 tensors. DDP broadcasts
# FP32 parameters natively without any FP8 dequantize/re-quantize path, and
# quantization to FP8 only happens dynamically during the forward pass inside
# te.autocast, which is identical for both DDP and FSDP2.
if (not quantized_init and not autocast):
atol = 1e-6
rtol = 5e-5
Comment thread
wangye805 marked this conversation as resolved.

for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)):

print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...")
assertEqual(te_output_no_cache[1], te_output_cache[1]) # expects exact match
assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=atol, rtol=rtol)
print(f"Tensor at index {idx} passed comparison.")


@pytest.fixture
def cleanup_artifacts():
yield # run the test first
for fname in ["all_iters_fsdp2.pt", "all_iters_dp.pt", "fsdp_model.pth", "shared_input.pt"]:
for fname in ["all_iters_fsdp2.pt", "all_iters_dp.pt", "shared_input.pt"]:
if os.path.exists(fname):
os.remove(fname)

# Define test cases explicitly
test_cases = []
for quantized_init in [True, False]:
for autocast in [True, False]:
if quantized_init and not autocast:
continue
if quantized_init or autocast:
for recipe in ["delayed", "current", "mxfp8"]:
test_cases.append((quantized_init, autocast, recipe))
test_cases.append((False, False, "delayed"))


@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("fp8_init", ([False]))
@pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"]))
@pytest.mark.parametrize("quantized_init, autocast, recipe", test_cases)
@pytest.mark.usefixtures("cleanup_artifacts")
def test_distributed(fp8_init, recipe):
def test_distributed(quantized_init, autocast, recipe):

batch_size = 2048
input_size = 2048
Expand All @@ -90,18 +143,15 @@ def test_distributed(fp8_init, recipe):
torch.save(input_data.cpu(), input_path)
print("Generated and saved shared input tensor.")

model = SimpleNet(input_size, 2048, 2048)
torch.save(model.state_dict(), 'fsdp_model.pth')

if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs")

if fp8_init and not fp8_available:
if quantized_init and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

_run_test(fp8_init, recipe)
_run_test(quantized_init, autocast, recipe)


def test_dummy() -> None:
Expand Down
7 changes: 3 additions & 4 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from torch.distributed.tensor import DTensor

import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
Expand Down Expand Up @@ -1053,7 +1054,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2:
if IS_HIP_EXTENSION and not self.primary_weights_in_fp8 and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2:
FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True

if self.fp8_parameters or fp8_enabled:
Expand Down Expand Up @@ -1358,9 +1359,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
self.keep_fp8_weight_transpose_cache = False
param = FSDPAGTensor(
param,
module=self,
fp8_meta_index=fp8_meta_index,
keep_fp8_weight_transpose_cache=self.keep_fp8_weight_transpose_cache
fp8_meta_index=fp8_meta_index,
)

# Redo parameter wrap in case we broke it above
Expand Down
Loading