[Compat] Make hybrid ep branch compatible with PaddlePaddle#12
[Compat] Make hybrid ep branch compatible with PaddlePaddle#12SigureMo wants to merge 2 commits intohybrid-ep-paddlefrom
Conversation
7febc6e to
a484936
Compare
91af98a to
c9fc08d
Compare
ShigureNyako
left a comment
There was a problem hiding this comment.
这版兼容改造我先不建议合入,主要有几点是“保留了 PyTorch 接口外形,但语义已经变了”,其中既有兼容性问题,也有明确的性能退化风险:
-
deep_ep.Buffer的comm路径已经失效。deep_ep/buffer.py:74-98, 128-133定义了all_gather_objecthelper,但后续同步阶段全部直接调用dist.all_gather_object(..., group),而且构造 runtime 时无条件取group.id(deep_ep/buffer.py:92-93)。- 这意味着文档里仍然声称支持
comm,但实际一走comm分支就会在group.id/dist.all_gather_object(..., group)处崩掉。 - 建议:要么把
comm路径完整接回(包括context_ring_id/ object gather),要么直接删掉该 API/文档并显式报错,避免半迁移状态。
-
current stream 语义已经和 PyTorch 版本不一致。
- 原实现这里用的是
at::cuda::getCurrentCUDAStream()/setCurrentCUDAStream(...),现在csrc/deep_ep.cpp多处改成固定读calc_ctx->stream(),并通过SetAllocatorStreamForGPUContext(...)切 allocator。 - 但
EventHandle仍然用at::cuda::getCurrentCUDAStream()录制/等待(csrc/event.hpp:13-29)。 - 结果就是“计算流”来源被拆成了 Paddle
GPUContext和 ATen current stream 两套定义。上层如果在自定义torch.cuda.Stream()上排 kernel,previous_event为空时这里不再自动等待调用者当前流,容易出现缺依赖或者额外串行化。 - 建议:至少保证 wait / allocate / record_stream 读取的是同一个 stream 源;如果必须走 Paddle stream,也需要把 current-stream 相关接口一起切过去,而不是混用。
- 原实现这里用的是
-
enable_custom_allgather参数被静默忽略,热路径会直接退化。- Python 侧还保留了
enable_custom_allgather参数(deep_ep/hybrid_ep_buffer.py:67),但创建 runtime 时直接硬编码成False(deep_ep/hybrid_ep_buffer.py:157-162)。 - 这样
Executor::allgather_routing_map会永久走paddle.distributed.stream.all_gather(...)分支,而不是原来的 custom intra-node allgather(csrc/hybrid_ep/executor/executor.cu:27-34)。 - 这既是接口语义变化,也是明确的性能风险。调用方即使显式传
True也拿不到原本的快路径。 - 建议:要么把参数删掉/显式报错,要么把 custom 路径按兼容层约束修好后继续透传,不要 silent fallback。
- Python 侧还保留了
-
HybridEP 的拓扑检测/校验被删掉后,默认“全员同一 NVLink 域”风险太高。
- 原来会用
ExtendedMemoryAllocator.detect_accessible_ranks()做检测;现在deep_ep/hybrid_ep_buffer.py:79-87直接默认num_of_hybrid_ep_ranks_per_nvlink_domain = world_size,deep_ep/buffer.py:65-66的 NVLink 检查也被注释掉了。 - 这在多节点、PCIe、或者局部不可达拓扑下不再 fail fast,而是可能直接走错路径。
- 建议:至少在自动检测不可用时强制要求显式传/设
NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN,不要默认world_size。
- 原来会用
附带一个建议尽快修的小点:csrc/hybrid_ep/hybrid_ep.cu:15-20 现在用重复的 group.id 生成 comm_id,而 csrc/hybrid_ep/jit/compiler.cuh:57 明确它的语义是 hash(all ranks in the process_group)。这会让 JIT cache key 丢失成员信息,最好还是按 group member 列表生成,避免后续 cache/句柄复用出现撞车。
| # Disable custom allgather by default because its data layout is incompatible with scan kernel | ||
| # The custom allgather kernel produces token-interleaved layout, but scan kernel expects | ||
| # the standard allgather layout (rank-blocked layout) | ||
| enable_custom_allgather = False # Always use standard allgather for correctness |
There was a problem hiding this comment.
这里把 enable_custom_allgather 直接硬编码成 False,会让 Python 暴露的参数失效,并且把原来的 intra-node fast path 永久降级到通用 all_gather。如果当前兼容层暂时不支持 custom allgather,建议至少在 Python 侧显式报错/移除该参数,而不是 silent fallback。
| self.explicitly_destroy = explicitly_destroy | ||
| self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, | ||
| self.disable_nvlink_for_normal_mode,explicitly_destroy, use_fabric) | ||
| self.disable_nvlink_for_normal_mode,explicitly_destroy, use_fabric, group.id) |
There was a problem hiding this comment.
这里已经把 group.id 当成必选输入传进 C++ runtime 了,但这个类的文档和分支逻辑仍然保留了 comm 兜底路径。后面同步阶段也直接写死成 dist.all_gather_object(..., group),不再走上面定义的 all_gather_object helper。这样一来 comm 分支实际上已经不可用了,属于典型的半迁移状态。建议要么把 comm 路径完整接回,要么在接口层显式去掉/报错,不要保留一个看起来可用但实际会崩的 API。
| else: | ||
| self.num_of_hybrid_ep_ranks_per_nvlink_domain = detected_ranks | ||
| # Default: assume all ranks are in the same NVLink domain (single node) | ||
| self.num_of_hybrid_ep_ranks_per_nvlink_domain = self.group_size |
There was a problem hiding this comment.
这里直接把默认值设成 group_size,等价于假设“所有 rank 都在同一个 NVLink 域”。原实现会先探测 accessible ranks,这里删掉以后,多节点 / PCIe / 局部不可达拓扑都不再 fail fast,而是可能静默走错路径。建议至少在自动探测不可用时要求显式设置 NUM_OF_HYBRID_EP_RANKS_PER_NVLINK_DOMAIN,不要默认全员可达。
| // Allocate all tensors on comm stream if set | ||
| // NOTES: do not allocate tensors upfront! | ||
| auto compute_stream = at::cuda::getCurrentCUDAStream(); | ||
| auto compute_stream = calc_ctx->stream(); |
There was a problem hiding this comment.
这里把 compute_stream 从 at::cuda::getCurrentCUDAStream() 改成了固定的 calc_ctx->stream(),但 EventHandle 仍然在 at::cuda::getCurrentCUDAStream() 上 record/wait。这样 current-stream 语义已经分裂成两套来源:上层如果在自定义 torch.cuda.Stream() 上排 kernel,previous_event 为空时这里不再自动等待调用者当前流,可能出现缺依赖或额外串行化。建议把 wait / allocate / record_stream 统一到同一个 stream 源。
No description provided.