Skip to content

change validate.py#109

Open
junelotus wants to merge 1 commit intoDeepLink-org:add-kernelbench-triton-gpt5highfrom
junelotus:add-kernelbench-triton-gpt5high
Open

change validate.py#109
junelotus wants to merge 1 commit intoDeepLink-org:add-kernelbench-triton-gpt5highfrom
junelotus:add-kernelbench-triton-gpt5high

Conversation

@junelotus
Copy link
Copy Markdown
Collaborator

No description provided.

@CLAassistant
Copy link
Copy Markdown

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

init_args = init_args_ori
init_kwargs = init_kwargs_ori
# 防止初始化参数中带有的device信息(用来init original model的参数)和目标平台device不一致
init_args=[device if i == 'cpu' else i for i in init_args]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

通过_ori参数保持原始模型的参数将传入原始模型,不带_ori后缀的参数是传入给triton算子的;如果初始化的参数信息中带有device信息,将传入给triton算子的device 替换为目标平台device类型。

original_model = Model(*init_args_ori, **init_kwargs_ori)
assert hasattr(original_model, "forward")
original_model=original_model.to(device)
# original_model=original_model.to(device)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原始模型跑在cpu上。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants