change validate.py #109
Status: Open. junelotus wants to merge 1 commit into DeepLink-org:add-kernelbench-triton-gpt5high from junelotus:add-kernelbench-triton-gpt5high.
validate.py

```diff
@@ -67,6 +67,8 @@ def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"):
     # Create a new module based on that spec
     temp_module = importlib.util.module_from_spec(spec)
     # Execute the code in the module's namespace
+    # If the generated Triton code carries device info, it should match the target platform it will run on; otherwise execution fails here
+    # e.g. if the target platform does not support CUDA but the Triton code sets device = 'cuda', the error is "init with exception: Torch not compiled with CUDA enabled"
     spec.loader.exec_module(temp_module)

     ModelNew = getattr(temp_module, entry_point)
```
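The loader annotated above can be sketched end to end. This is a hypothetical reconstruction from the visible context lines (the real validate.py may differ); a plain class stands in for a torch module so the sketch runs anywhere, but the failure mode the new comments describe is the same: `exec_module` executes the generated source, so any device referenced inside it must exist on the host.

```python
import importlib.util
import tempfile

# Hypothetical sketch of the loader the diff annotates: write the generated
# source to a temp file, build a module spec from it, and execute it.
# exec_module raises immediately if the source references an unavailable
# device (e.g. "Torch not compiled with CUDA enabled").
def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"):
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(model_custom_src)
        path = f.name
    spec = importlib.util.spec_from_file_location("temp_module", path)
    temp_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(temp_module)  # executes the generated code
    return getattr(temp_module, entry_point)

# Toy generated source (no torch, so it loads on any host)
src = "class ModelNew:\n    def forward(self, x):\n        return x * 2\n"
ModelNew = load_custom_model_with_tempfile(src)
print(ModelNew().forward(3))  # → 6
```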
```diff
@@ -140,7 +142,7 @@ def _move_to_device(obj, device):
    • others -> returned as-is
    """
    if isinstance(obj, torch.Tensor):
-       return obj.item() if obj.numel() == 1 else obj.to(device, non_blocking=True)
+       return obj.to(device, non_blocking=True)  # was: obj.item() if obj.numel() == 1 else ...

    if isinstance(obj, (list, tuple)):
        return type(obj)(_move_to_device(x, device) for x in obj)
```
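The behavior change in this hunk is that single-element tensors are now moved to the device like any other tensor instead of being unwrapped to a Python scalar with `.item()`. The recursion over containers can be illustrated with a stand-in tensor class (an assumption, used only so the sketch runs without torch installed):

```python
# Structural sketch of _move_to_device as it reads after the diff, with a
# FakeTensor stand-in (assumption) replacing torch.Tensor.
class FakeTensor:
    def __init__(self, device="cpu"):
        self.device = device
    def to(self, device, non_blocking=False):
        return FakeTensor(device)

def _move_to_device(obj, device):
    if isinstance(obj, FakeTensor):
        # the PR removed the .item() shortcut, so 0-d tensors also move
        return obj.to(device, non_blocking=True)
    if isinstance(obj, (list, tuple)):
        # rebuild the same container type with moved elements
        return type(obj)(_move_to_device(x, device) for x in obj)
    if isinstance(obj, dict):
        return {k: _move_to_device(v, device) for k, v in obj.items()}
    return obj  # others -> returned as-is

moved = _move_to_device([FakeTensor(), {"w": FakeTensor()}, 7], "cuda")
print(moved[0].device, moved[1]["w"].device, moved[2])  # → cuda cuda 7
```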
```diff
@@ -242,7 +244,7 @@ def main():
    # defined here
    device = 'cuda'
    root_path = f"/datapool/zmz/04kernelagent/caizheng/DLBlas-add-kernelbench-triton-gpt5high/dlblas/kernels"
-   output_file = f"/datapool/zmz/04kernelagent/caizheng/DLBlas-add-kernelbench-triton-gpt5high/dlblas/kernels/output_{device}.json"
+   output_file = f"/datapool/zmz/04kernelagent/caizheng/DLBlas-add-kernelbench-triton-gpt5high/dlblas/kernels/output_{device}.json"

    # init
```
```diff
@@ -278,10 +280,16 @@ def main():
            set_seed(seed_num)  # set seed for reproducible input
            # ---------- parse get_init_inputs ----------
            raw_init_inputs = get_init_inputs() if get_init_inputs else []
-           init_args, init_kwargs = _parse_init_inputs(raw_init_inputs)
+           init_args_ori, init_kwargs_ori = _parse_init_inputs(raw_init_inputs)
+           init_args = init_args_ori
+           init_kwargs = init_kwargs_ori
+           # prevent device info in the init args (used to init the original
+           # model) from mismatching the target platform's device
+           init_args = [device if i == 'cpu' else i for i in init_args]

            # move tensors to the specified device
            init_args = _move_to_device(init_args, device)
            init_kwargs = _move_to_device(init_kwargs, device)

        except Exception as e:
            print(f"{item['uid']} init with exception: {e}", flush=True)
            correctness = False
```
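The substitution the new list comprehension performs can be shown in isolation: any literal `'cpu'` among the positional init args is swapped for the target device before the args reach the Triton-backed model, while the `_ori` copy keeps the untouched values for the reference model. The argument values here are hypothetical:

```python
# Isolated sketch of the device-substitution step added by the diff.
device = "cuda"  # assumed target platform

init_args_ori = [16, "cpu", 0.1]  # hypothetical parsed init inputs
# args for the Triton model: replace any literal 'cpu' with the target device
init_args = [device if i == "cpu" else i for i in init_args_ori]

print(init_args)      # → [16, 'cuda', 0.1]
print(init_args_ori)  # unchanged, kept for the reference model → [16, 'cpu', 0.1]
```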
```diff
@@ -290,18 +298,21 @@ def main():
        try:
            with torch.no_grad():
                set_seed(seed_num)  # set seed for reproducible weights
-               original_model = Model(*init_args, **init_kwargs)
+               original_model = Model(*init_args_ori, **init_kwargs_ori)
                assert hasattr(original_model, "forward")
-               original_model = original_model.to(device)
+               # original_model = original_model.to(device)
```

junelotus (Collaborator, Author) commented on the line above: The original model runs on the CPU.

```diff
            with torch.no_grad():
                set_seed(seed_num)  # set seed for reproducible weights
                custom_model = ModelNew(*init_args, **init_kwargs)
                assert hasattr(custom_model, "forward")
                custom_model = custom_model.to(device)
            inputs = get_inputs()
+           inputs_ori = inputs
            inputs = _move_to_device(inputs, device)
-           output = original_model(*inputs)
+           output = original_model(*inputs_ori)
+           output = _move_to_device(output, device)
            output_new = custom_model(*inputs)

            outputs = (output,) if not isinstance(output, tuple) else output
            outputs_new = (output_new,) if not isinstance(output_new, tuple) else output_new
            if len(outputs) != len(outputs_new):
```
```diff
@@ -335,4 +346,4 @@ def main():
        print(f"Failed to save JSON: {e}", flush=True)


if __name__ == "__main__":
-   main()
+   main()
```
Author's note on the diff: Parameters with the `_ori` suffix preserve the original arguments and are passed to the original model; parameters without the `_ori` suffix are passed to the Triton kernel. If the init parameters carry device info, the device passed to the Triton kernel is replaced with the target platform's device type.
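The validation flow this note describes can be sketched with stub models: build the reference model from the untouched `_ori` args and run it on the original (CPU) inputs, build the kernel-backed model from the device-substituted args, then compare outputs. `Model` and `ModelNew` here are toy stand-ins (assumptions), not the benchmark's real classes, and the device-moving steps are elided so the sketch runs without torch:

```python
# Stub models standing in for the benchmark's Model / ModelNew (assumption).
class Model:                       # reference model, stays on CPU
    def __init__(self, scale):
        self.scale = scale
    def forward(self, x):
        return [v * self.scale for v in x]

class ModelNew(Model):             # stand-in for the Triton-backed model
    pass

init_args_ori = [3]                # untouched args -> reference model
init_args = init_args_ori          # device substitution would happen here
inputs_ori = [[1, 2]]              # original inputs -> reference model

original_model = Model(*init_args_ori)
custom_model = ModelNew(*init_args)

out = original_model.forward(*inputs_ori)      # reference output
out_new = custom_model.forward(*inputs_ori)    # kernel output
print(out == out_new)  # → True
```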