Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

### Added

- Add deterministic execution mode for atomic operations via `wp.config.deterministic = True`.
Floating-point atomic accumulations use a scatter-sort-reduce strategy for bit-exact
reproducibility across runs. Counter/allocator atomics (where the return value is used)
use automatic two-pass execution with prefix-sum-based slot assignment. Configurable at
the global, module, and kernel level.
Comment on lines +7 to +11
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Changelog entry should include a GH reference and stay API-level.

Lines 7-11 describe internal mechanics but do not include an issue/PR reference.

📝 Suggested rewrite
-- Add deterministic execution mode for atomic operations via `wp.config.deterministic = True`.
-  Floating-point atomic accumulations use a scatter-sort-reduce strategy for bit-exact
-  reproducibility across runs. Counter/allocator atomics (where the return value is used)
-  use automatic two-pass execution with prefix-sum-based slot assignment. Configurable at
-  the global, module, and kernel level.
+- Add deterministic atomic execution mode via `wp.config.deterministic = True`, with global, module, and kernel-level control for reproducible results across CUDA runs ([GH-1355](https://github.com/NVIDIA/warp/pull/1355)).

As per coding guidelines: "If a change modifies user-facing behavior, append an entry ... include issue refs ... and avoid internal implementation details."

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- Add deterministic execution mode for atomic operations via `wp.config.deterministic = True`.
Floating-point atomic accumulations use a scatter-sort-reduce strategy for bit-exact
reproducibility across runs. Counter/allocator atomics (where the return value is used)
use automatic two-pass execution with prefix-sum-based slot assignment. Configurable at
the global, module, and kernel level.
- Add deterministic atomic execution mode via `wp.config.deterministic = True`, with global, module, and kernel-level control for reproducible results across CUDA runs ([GH-1355](https://github.com/NVIDIA/warp/pull/1355)).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@CHANGELOG.md` around lines 7 - 11, Update the CHANGELOG entry (mentioning
wp.config.deterministic) to be API-level and include a GitHub issue/PR
reference: remove internal implementation details like "scatter-sort-reduce" and
"two-pass execution with prefix-sum-based slot assignment", instead describe the
user-visible change (e.g., "Added a deterministic execution mode for atomic
operations via wp.config.deterministic = True that makes atomic accumulations
reproducible across runs"), append a short note about scope
(global/module/kernel) and add the GH issue/PR number (e.g., "See `#1234`") and
affected version/release tag.

- Add double-precision (`wp.float64`) support to `warp.fem`.
Precision is selected via the geometry (e.g. `scalar_type=wp.float64` on grid constructors)
and propagated automatically to function spaces, quadrature, fields, and integration kernels
Expand Down
220 changes: 219 additions & 1 deletion asv/benchmarks/atomics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmarks for atomic operations under high thread contention.
"""Benchmarks for atomic operations and deterministic mode overhead.

All threads write to a single output location (index 0) to maximize contention
and measure worst-case atomic operation performance.
Expand All @@ -25,13 +25,18 @@

import warp as wp

wp.set_module_options({"enable_backward": False})

# Map string parameter names to warp dtypes
# (ASV parameters must be plain strings; kernels need the real wp types).
DTYPE_MAP = {
    "float32": wp.float32,
    "int32": wp.int32,
}

# Problem sizes (element counts) used by the benchmark classes.
NUM_ELEMENTS = 32 * 1024 * 1024
# NOTE(review): the two constants below are not referenced in this chunk —
# presumably consumed by benchmark classes outside this view; confirm.
DETERMINISTIC_NUM_ELEMENTS = 1 * 1024 * 1024
COUNTER_NUM_ELEMENTS = 4 * 1024 * 1024
# Size sweep for the deterministic-overhead benchmarks defined below.
DETERMINISTIC_BENCHMARK_SIZES = [64 * 1024, 256 * 1024, 1024 * 1024]


@wp.kernel
Expand All @@ -54,6 +59,60 @@ def min_kernel(
wp.atomic_min(out, 0, val) # All threads contend on out[0]


@wp.kernel
def scatter_add_kernel(
    vals: wp.array(dtype=wp.float32),
    indices: wp.array(dtype=wp.int32),
    out: wp.array(dtype=wp.float32),
):
    # Baseline scatter-add: each thread accumulates vals[tid] into
    # out[indices[tid]] with a hardware atomic; float accumulation order
    # is unspecified, so this path is the non-deterministic reference.
    tid = wp.tid()
    wp.atomic_add(out, indices[tid], vals[tid])


# deterministic=True opts this kernel into the deterministic atomic execution
# mode (scatter-sort-reduce accumulation, per the CHANGELOG entry in this PR).
# deterministic_max_records presumably caps per-kernel atomic record storage —
# confirm against the Warp documentation.
@wp.kernel(deterministic=True, deterministic_max_records=1)
def scatter_add_kernel_deterministic(
    vals: wp.array(dtype=wp.float32),
    indices: wp.array(dtype=wp.int32),
    out: wp.array(dtype=wp.float32),
):
    # Same body as scatter_add_kernel so the benchmark isolates the cost of
    # the deterministic execution strategy alone.
    tid = wp.tid()
    wp.atomic_add(out, indices[tid], vals[tid])


@wp.kernel
def counter_kernel(
    vals: wp.array(dtype=wp.float32),
    counter: wp.array(dtype=wp.int32),
    out: wp.array(dtype=wp.float32),
):
    # Allocator-style atomic: the return value of atomic_add on counter[0]
    # hands each thread a unique output slot. Output is compacted, but slot
    # assignment order is whatever the hardware atomics produce.
    tid = wp.tid()
    slot = wp.atomic_add(counter, 0, 1)
    out[slot] = vals[tid]


# Deterministic counterpart of counter_kernel: because the atomic's return
# value is used, this exercises the two-pass prefix-sum slot assignment path
# described in the CHANGELOG entry of this PR.
@wp.kernel(deterministic=True, deterministic_max_records=1)
def counter_kernel_deterministic(
    vals: wp.array(dtype=wp.float32),
    counter: wp.array(dtype=wp.int32),
    out: wp.array(dtype=wp.float32),
):
    tid = wp.tid()
    slot = wp.atomic_add(counter, 0, 1)
    out[slot] = vals[tid]


@wp.kernel
def zero_float_array_kernel(out: wp.array(dtype=wp.float32)):
    # Device-side reset of a float32 array; used inside graph capture so the
    # timed graph includes output re-initialization.
    tid = wp.tid()
    out[tid] = 0.0


@wp.kernel
def zero_int_array_kernel(out: wp.array(dtype=wp.int32)):
    # Device-side reset of an int32 array (e.g. the counter) inside the
    # captured graph.
    tid = wp.tid()
    out[tid] = 0


class AtomicMax:
"""Benchmark wp.atomic_max() with high thread contention.

Expand Down Expand Up @@ -166,3 +225,162 @@ def time_cuda(self, vals_np_dict, dtype_str):
self.out.zero_()
self.cmd.launch()
wp.synchronize_device(self.device)


class AtomicAddDeterminismOverhead:
    """Benchmark the overhead of deterministic accumulation atomics.

    The benchmark compares the normal atomic-add path against deterministic
    scatter-sort-reduce for the same kernel using CUDA graph replay. A small
    size sweep exposes where deterministic execution crosses over. Two
    destination counts are used:

    - ``1``: worst-case contention, where every thread targets the same output.
    - ``65536``: lower contention, closer to a scatter workload.
    """

    # Tuples, not lists: mutable class attributes are shared across instances
    # and flagged by Ruff RUF012; ASV accepts any sequence here.
    params = (("normal", "deterministic"), (1, 65536), tuple(DETERMINISTIC_BENCHMARK_SIZES))
    param_names = ("mode", "num_outputs", "num_elements")

    repeat = 10
    number = 5

    def setup_cache(self):
        # Fixed seed keeps the workload identical across benchmark runs.
        rng = np.random.default_rng(123)
        vals_np = {n: rng.random(n, dtype=np.float32) for n in DETERMINISTIC_BENCHMARK_SIZES}
        indices_np = {}
        for n in DETERMINISTIC_BENCHMARK_SIZES:
            indices_np[n] = {
                1: np.zeros(n, dtype=np.int32),  # every thread hits out[0]
                65536: rng.integers(0, 65536, size=n, dtype=np.int32),
            }
        return vals_np, indices_np

    def setup(self, cache, mode, num_outputs, num_elements):
        wp.init()
        self.device = wp.get_device("cuda:0")

        vals_np, indices_np = cache
        self.vals = wp.array(vals_np[num_elements], dtype=wp.float32, device=self.device)
        self.indices = wp.array(indices_np[num_elements][num_outputs], dtype=wp.int32, device=self.device)
        self.out = wp.zeros(shape=(num_outputs,), dtype=wp.float32, device=self.device)

        self.kernel = scatter_add_kernel_deterministic if mode == "deterministic" else scatter_add_kernel
        # Warm-up launches: compile the module and prime internal state before
        # graph capture so capture does not include one-time costs.
        wp.launch(
            zero_float_array_kernel,
            dim=num_outputs,
            inputs=[self.out],
            device=self.device,
        )
        wp.launch(
            self.kernel,
            (num_elements,),
            inputs=[self.vals, self.indices],
            outputs=[self.out],
            device=self.device,
        )
        wp.synchronize_device(self.device)

        # Capture zero + scatter-add so the timed replay includes resetting
        # the output, isolating pure device work from host overhead.
        with wp.ScopedCapture(device=self.device, force_module_load=False) as capture:
            wp.launch(
                zero_float_array_kernel,
                dim=num_outputs,
                inputs=[self.out],
                device=self.device,
            )
            wp.launch(
                self.kernel,
                (num_elements,),
                inputs=[self.vals, self.indices],
                outputs=[self.out],
                device=self.device,
            )

        self.graph = capture.graph

        # Replay a few times to warm caches/clocks before timing begins.
        for _ in range(5):
            wp.capture_launch(self.graph)
        wp.synchronize_device(self.device)

    def time_cuda(self, cache, mode, num_outputs, num_elements):
        wp.capture_launch(self.graph)
        wp.synchronize_device(self.device)


class AtomicCounterDeterminismOverhead:
    """Benchmark the overhead of deterministic counter/allocator atomics.

    The timed path uses CUDA graph replay and includes resetting the output
    state inside the captured graph so the benchmark isolates device work.
    """

    # Tuples, not lists: mutable class attributes are shared across instances
    # and flagged by Ruff RUF012; ASV accepts any sequence here.
    params = (("normal", "deterministic"), tuple(DETERMINISTIC_BENCHMARK_SIZES))
    param_names = ("mode", "num_elements")

    repeat = 10
    number = 5

    def setup_cache(self):
        # Fixed seed keeps the workload identical across benchmark runs.
        rng = np.random.default_rng(321)
        return {n: rng.random(n, dtype=np.float32) for n in DETERMINISTIC_BENCHMARK_SIZES}

    def setup(self, vals_np, mode, num_elements):
        wp.init()
        self.device = wp.get_device("cuda:0")

        self.vals = wp.array(vals_np[num_elements], dtype=wp.float32, device=self.device)
        self.counter = wp.zeros(shape=(1,), dtype=wp.int32, device=self.device)
        self.out = wp.zeros(shape=(num_elements,), dtype=wp.float32, device=self.device)

        self.kernel = counter_kernel_deterministic if mode == "deterministic" else counter_kernel
        # Warm-up launches: compile the module and prime internal state before
        # graph capture so capture does not include one-time costs.
        wp.launch(
            zero_int_array_kernel,
            dim=1,
            inputs=[self.counter],
            device=self.device,
        )
        wp.launch(
            zero_float_array_kernel,
            dim=num_elements,
            inputs=[self.out],
            device=self.device,
        )
        wp.launch(
            self.kernel,
            (num_elements,),
            inputs=[self.vals, self.counter],
            outputs=[self.out],
            device=self.device,
        )
        wp.synchronize_device(self.device)

        # Capture counter reset + output reset + kernel so replay is
        # self-contained and repeatable.
        with wp.ScopedCapture(device=self.device, force_module_load=False) as capture:
            wp.launch(
                zero_int_array_kernel,
                dim=1,
                inputs=[self.counter],
                device=self.device,
            )
            wp.launch(
                zero_float_array_kernel,
                dim=num_elements,
                inputs=[self.out],
                device=self.device,
            )
            wp.launch(
                self.kernel,
                (num_elements,),
                inputs=[self.vals, self.counter],
                outputs=[self.out],
                device=self.device,
            )

        self.graph = capture.graph

        # Replay a few times to warm caches/clocks before timing begins.
        for _ in range(5):
            wp.capture_launch(self.graph)
        wp.synchronize_device(self.device)

    def time_cuda(self, vals_np, mode, num_elements):
        wp.capture_launch(self.graph)
        wp.synchronize_device(self.device)
2 changes: 2 additions & 0 deletions build_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def main(argv: list[str] | None = None) -> int:
"native/texture.cpp",
"native/mathdx.cpp",
"native/coloring.cpp",
"native/deterministic.cpp",
"native/fastcall.cpp",
]
warp_cpp_paths = [os.path.join(build_path, cpp) for cpp in cpp_sources]
Expand All @@ -533,6 +534,7 @@ def main(argv: list[str] | None = None) -> int:
else:
cuda_sources = [
"native/bvh.cu",
"native/deterministic.cu",
"native/mesh.cu",
"native/sort.cu",
"native/hashgrid.cu",
Expand Down
Loading