Fix workspace sizing for SplitKV in standalone PA accuracy/bench tests #683

chenshengxin2026 wants to merge 3 commits into hw-native-sys:main
Conversation
The `workspace_sizes()` function used a per-core heuristic for `l_gm` and `o_core_tmp_gm` that under-allocated for large-batch SplitKV cases. For the new `b256_h16_kv1_s8192_bs128` case (tiling_key=16, kvCN=22), the old formula gave ~3.7 MB for `o_core_tmp` while the kernel needs ~46 MB, causing out-of-bounds writes.

Changes:

- `make_pa_nd_decode_tiling()` now returns `kv_split_core_num` so callers can size workspace precisely.
- `workspace_sizes()` takes `kv_cn` and sizes `l`/`o_core_tmp` by `batch * kv_cn * num_heads` instead of `block_dim * SPLITKV_RATIO`.
- Update callers in `test_pa_accuracy.py` and `bench_pa_performance.py`.
- Enable the `b256_h16_kv1_s8192_bs128` benchmark case.
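The new sizing rule can be sketched in a few lines. This is a minimal illustration, not the repo's actual `workspace_sizes()` implementation; the function name, the float32 element size, and `head_dim=128` for this case are assumptions consistent with the description above.

```python
def o_core_tmp_bytes(batch: int, kv_cn: int, num_heads: int, head_dim: int) -> int:
    """Size o_core_tmp for one float32 partial per (batch, kv split, head, dim),
    i.e. the batch * kv_cn * num_heads rule from this PR (illustrative sketch)."""
    return max(16, batch * kv_cn * num_heads * head_dim * 4)

# The failing case from the description: b256_h16_kv1_s8192_bs128,
# tiling_key=16, kvCN=22 (head_dim assumed 128).
size = o_core_tmp_bytes(batch=256, kv_cn=22, num_heads=16, head_dim=128)
print(f"{size / 1e6:.1f} MB")  # 46.1 MB, vs ~3.7 MB from the old per-core heuristic
```

The ~46 MB figure in the description matches this arithmetic exactly, which is how the under-allocation was quantified.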
Code Review
This pull request introduces a standalone paged-attention kernel implementation for NPU, encompassing the CCE kernel code, a C++ wrapper, and a Python port of the tiling logic. It also includes scripts for compilation, performance benchmarking, and accuracy verification. The review feedback identifies critical issues in the tiling logic, specifically incorrect workspace offset and size calculations for SplitKV scenarios, which could lead to under-allocation and memory corruption. Additionally, the reviewer recommends re-enabling the benchmark and accuracy test cases that were commented out in this version.
```python
tiling[base + 15] = _hi32(addr_ofd)
tiling[base + 16] = _lo32(addr_ofd)
addr_l += kvCN * num_heads * q_seqlen
addr_ofd += num_heads * head_dim * q_seqlen  # embeddingSize for oFd
```
The addr_ofd element offset increment is missing the kvCN factor. For SplitKV cases (kvCN > 1), the intermediate workspace o_core_tmp_gm must store partial results for each split. Without this factor, the workspace regions for different batch items will overlap, leading to memory corruption and incorrect results. Additionally, since o_core_tmp_gm stores partials of the output tensor, it should use head_dim_v to ensure correct sizing when head_dim != head_dim_v.
```diff
- addr_ofd += num_heads * head_dim * q_seqlen  # embeddingSize for oFd
+ addr_ofd += kvCN * num_heads * head_dim_v * q_seqlen  # embeddingSize for oFd
```
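The overlap the reviewer describes is plain arithmetic. A small sketch, using made-up shape values (only `kvCN > 1` matters for the bug):

```python
# Illustrative shapes, not from the repo; decode has q_seqlen = 1.
kvCN, num_heads, head_dim_v, q_seqlen = 4, 16, 128, 1

# Each batch item writes one partial result per KV split:
needed_per_batch = kvCN * num_heads * head_dim_v * q_seqlen

# The buggy increment advances the base pointer without the kvCN factor,
# so batch b+1 starts inside batch b's region whenever kvCN > 1:
buggy_stride = num_heads * head_dim_v * q_seqlen

print(buggy_stride < needed_per_batch)  # True -> adjacent batch regions overlap
```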
```python
basic_float = block_dim * WORKSPACE_BLOCK_SIZE_DB * 4  # basicWorkSpaceFloat
# SplitKV paths need global workspace sized by (batch * kv_cn * num_heads).
# The old per-core estimate under-allocates for large-batch split-KV cases.
o_core = max(16, batch * kv_cn * num_heads * head_dim * 4)
```
The o_core workspace size should be calculated using head_dim_v instead of head_dim. The o_core_tmp_gm buffer stores intermediate partial sums of the attention output, which has a hidden dimension of head_dim_v. Using head_dim will result in an under-allocation if head_dim_v > head_dim.
```diff
- o_core = max(16, batch * kv_cn * num_heads * head_dim * 4)
+ o_core = max(16, batch * kv_cn * num_heads * head_dim_v * 4)
```
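A quick check of the under-allocation this comment describes, with made-up dimensions where `head_dim_v > head_dim` (the asymmetric case the reviewer is guarding against):

```python
# Assumed illustrative dimensions, not from the repo's test cases.
batch, kv_cn, num_heads = 8, 4, 16
head_dim, head_dim_v = 64, 128  # asymmetric Q/V head dims

allocated = max(16, batch * kv_cn * num_heads * head_dim * 4)    # old formula
needed    = max(16, batch * kv_cn * num_heads * head_dim_v * 4)  # fixed formula

print(needed - allocated)  # bytes the kernel would write past the buffer
```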
```python
# ("Qwen3-0.6B b1 h16/kv8 kv2048", 1, 16, 8, 128, 2048, 128),
# ("Qwen3-1.7B b1 h16/kv8 kv4096", 1, 16, 8, 128, 4096, 128),
# ("Qwen3-4B b1 h32/kv8 kv2048", 1, 32, 8, 128, 2048, 128),
# ("Qwen3-8B b1 h32/kv8 kv4096", 1, 32, 8, 128, 4096, 128),
# ("Qwen3-8B b1 h32/kv8 kv8192", 1, 32, 8, 128, 8192, 128),
# ("Qwen3-14B b1 h40/kv8 kv2048", 1, 40, 8, 128, 2048, 128),
# ("Qwen3-32B b1 h64/kv8 kv2048", 1, 64, 8, 128, 2048, 128),
# ("MHA b1 h32/kv32 kv2048", 1, 32, 32, 128, 2048, 128),
# ("Qwen3-8B b4 h32/kv8 kv2048", 4, 32, 8, 128, 2048, 128),
# ("Qwen3-8B b8 h32/kv8 kv2048", 8, 32, 8, 128, 2048, 128),
# ("Qwen3-8B b16 h32/kv8 kv2048", 16, 32, 8, 128, 2048, 128),
# ("Qwen3-8B b32 h32/kv8 kv2048", 32, 32, 8, 128, 2048, 128),
# ("Qwen3-8B b64 h32/kv8 kv2048", 64, 32, 8, 128, 2048, 128),
```
```python
# {"name": "b1_h32_kv8_s128_bs128", "batch": 1, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "kv_seq": 128, "block_size": 128},
# # Multiple batches
# {"name": "b4_h32_kv8_s512_bs128", "batch": 4, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "kv_seq": 512, "block_size": 128},
# # MHA (nq == nkv) — uses split-KV path (tiling_key=16), isolated in subprocess
# {"name": "b2_h8_kv8_s256_bs128", "batch": 2, "num_heads": 8, "num_kv_heads": 8, "head_dim": 128, "kv_seq": 256, "block_size": 128},
# # Larger GQA
# {"name": "b8_h32_kv8_s1024_bs128", "batch": 8, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "kv_seq": 1024, "block_size": 128},
# # Qwen3 shapes
# {"name": "b1_h32_kv8_s2048_bs128", "batch": 1, "num_heads": 32, "num_kv_heads": 8, "head_dim": 128, "kv_seq": 2048, "block_size": 128},
# {"name": "b4_h64_kv8_s1024_bs128", "batch": 4, "num_heads": 64, "num_kv_heads": 8, "head_dim": 128, "kv_seq": 1024, "block_size": 128},
```
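One case above is marked "isolated in subprocess". A generic sketch of that pattern, assuming nothing about the repo's actual runner (the function name and case schema here are hypothetical):

```python
import json
import subprocess
import sys

def run_case_isolated(case: dict) -> int:
    """Run one benchmark case in a fresh interpreter so a kernel crash or
    device-state corruption cannot poison the cases that follow it."""
    proc = subprocess.run(
        [sys.executable, "-c",
         # Stand-in for the real per-case entry point: just echo the case name.
         "import json, sys; case = json.loads(sys.argv[1]); print(case['name'])",
         json.dumps(case)],
        capture_output=True, text=True)
    return proc.returncode

case = {"name": "b2_h8_kv8_s256_bs128", "batch": 2, "num_heads": 8,
        "num_kv_heads": 8, "head_dim": 128, "kv_seq": 256, "block_size": 128}
print(run_case_isolated(case))  # 0 on success
```

A nonzero return code (or a crash) in the child then fails only that case instead of aborting the whole suite.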
Summary

Fix workspace sizing for SplitKV in standalone PA accuracy/bench tests.

closes #677