Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions dlblas/kernels/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"):
# Create a new module based on that spec
temp_module = importlib.util.module_from_spec(spec)
# Execute the code in the module's namespace
# 如果生成的 triton 代码中带有 device 信息, 最好和将要运行的目标平台保持一致, 否则这里的执行会出问题
# 比如目标平台不兼容 cuda, 但 triton 中的 device = 'cuda', 典型报错为 "init with exception: Torch not compiled with CUDA enabled"
spec.loader.exec_module(temp_module)

ModelNew = getattr(temp_module, entry_point)
Expand Down Expand Up @@ -140,7 +142,7 @@ def _move_to_device(obj, device):
• 其它 -> 原样返回
"""
if isinstance(obj, torch.Tensor):
return obj.item() if obj.numel() == 1 else obj.to(device, non_blocking=True)
return obj.to(device, non_blocking=True) #obj.item() if obj.numel() == 1 else

if isinstance(obj, (list, tuple)):
return type(obj)(_move_to_device(x, device) for x in obj)
Expand Down Expand Up @@ -242,7 +244,7 @@ def main():
# defined here
device = 'cuda'
root_path = f"/datapool/zmz/04kernelagent/caizheng/DLBlas-add-kernelbench-triton-gpt5high/dlblas/kernels"
output_file = f"/datapool/zmz/04kernelagent/caizheng/DLBlas-add-kernelbench-triton-gpt5high/dlblas/kernels/output_{device}.json"
output_file = f"/datapool/zmz/04kernelagent/caizheng/DLBlas-add-kernelbench-triton-gpt5high/dlblas/kernels/output_{device}.json


# init
Expand Down Expand Up @@ -278,10 +280,16 @@ def main():
set_seed(seed_num) # set seed for reproducible input
# ---------- 解析 get_init_inputs ----------
raw_init_inputs = get_init_inputs() if get_init_inputs else []
init_args, init_kwargs = _parse_init_inputs(raw_init_inputs)
init_args_ori, init_kwargs_ori = _parse_init_inputs(raw_init_inputs)
init_args = init_args_ori
init_kwargs = init_kwargs_ori
# 防止初始化参数中带有的device信息(用来init original model的参数)和目标平台device不一致
init_args=[device if i == 'cpu' else i for i in init_args]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

通过带 _ori 后缀的参数保留原始输入并传入原始模型；不带 _ori 后缀的参数传入 triton 算子。如果初始化参数信息中带有 device 信息，将传给 triton 算子的 device 替换为目标平台的 device 类型。


# 把 tensor 放到指定 device
init_args = _move_to_device(init_args, device)
init_kwargs = _move_to_device(init_kwargs, device)

except Exception as e:
print(f"{item['uid']} init with exception: {e}", flush=True)
correctness = False
Expand All @@ -290,18 +298,21 @@ def main():
try:
with torch.no_grad():
set_seed(seed_num) # set seed for reproducible weights
original_model = Model(*init_args, **init_kwargs)
original_model = Model(*init_args_ori, **init_kwargs_ori)
assert hasattr(original_model, "forward")
original_model=original_model.to(device)
# original_model=original_model.to(device)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原始模型跑在cpu上。

with torch.no_grad():
set_seed(seed_num) # set seed for reproducible weights
custom_model = ModelNew(*init_args, **init_kwargs)
assert hasattr(custom_model, "forward")
custom_model=custom_model.to(device)
inputs = get_inputs()
inputs_ori = inputs
inputs = _move_to_device(inputs, device)
output = original_model(*inputs)
output = original_model(*inputs_ori)
output = _move_to_device(output, device)
output_new = custom_model(*inputs)

outputs = (output,) if not isinstance(output, tuple) else output
outputs_new = (output_new,) if not isinstance(output_new, tuple) else output_new
if len(outputs) != len(outputs_new):
Expand Down Expand Up @@ -335,4 +346,4 @@ def main():
print(f"保存 JSON 失败: {e}", flush=True)

if __name__ == "__main__":
main()
main()