
Llama3 like weight init #435

Merged

le1nux merged 8 commits into improve_data_writeout_perf from llama3_like_weight_init
Mar 6, 2026

Conversation

le1nux (Member) commented Mar 4, 2026:

What does this PR do?

This PR ..

General Changes

  • ..

Breaking Changes

  • ..

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

@le1nux le1nux marked this pull request as ready for review March 4, 2026 17:55
@le1nux le1nux requested a review from AbasKhan March 4, 2026 18:52
@le1nux le1nux changed the base branch from main to improve_data_writeout_perf March 4, 2026 19:22
        match_count += 1
        hits[weight_regex] += 1
if match_count == 0:
    logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
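The matching loop around this warning can be sketched in plain Python. This is a minimal sketch, not the PR's actual code: `initialize_parameters` and `record_init` are hypothetical names, and the recording stub stands in for the real torch init calls applied to parameter tensors.

```python
import logging
import re
from functools import partial

logger = logging.getLogger(__name__)


def record_init(parameter_name, std, log):
    # Stand-in for a torch init call: just record which std would be applied.
    log.append((parameter_name, std))


def initialize_parameters(parameter_names, regex_to_init):
    """Match every parameter name against the regex table; warn on misses."""
    hits = {regex: 0 for regex in regex_to_init}
    for parameter_name in parameter_names:
        match_count = 0
        for weight_regex, init_fn in regex_to_init.items():
            if re.fullmatch(weight_regex, parameter_name):
                init_fn(parameter_name)
                match_count += 1
                hits[weight_regex] += 1
        if match_count == 0:
            logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
    return hits
```

A parameter such as `lm_head.weight` that matches no regex only triggers the warning; matched parameters are counted per regex in `hits`, which makes it easy to spot patterns that never fire.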
Member commented:
should we add a flag which turns this into an error?

le1nux (Member Author) replied:

Since the norms are initialized within the model factory via reset_parameters, this would always throw an error.

        b=2,
    ),
    # final attention projection in attention block
    r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial(
Collaborator commented:

This corresponds to the following, right? But there you can see that for the out projection std=init_std, which can be initialized differently and defaults to depth_init, because here we pass weight_init_std, which defaults to depth_init in titan. If we don't want depth init, then it matches titan's scaled out_projection logic when depth_init is False.

le1nux (Member Author) replied:

I implemented depth_init to be fully compliant.
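The depth-scaled std discussed above can be sketched as a small helper. This is a sketch under the assumption that the PR follows titan's convention (per-layer scaling when depth_init is on, uniform scaling by total depth otherwise); `depth_init_std` and the base std of 0.02 are hypothetical stand-ins, not names from the PR.

```python
def depth_init_std(layer_id: int, num_layers: int, depth_init: bool, base_std: float = 0.02) -> float:
    """Std for output projections: shrinks with layer index if depth_init,
    otherwise scaled uniformly by the total number of layers."""
    if depth_init:
        return base_std / (2 * (layer_id + 1)) ** 0.5
    return base_std / (2 * num_layers) ** 0.5
```

With depth_init disabled, every layer receives the same std, matching the "scaled out_projection" behavior the reviewer describes.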

def __init__(self, num_layers: int, n_embd: int, bias: bool) -> None:
    super().__init__()

    self.regex_to_init = {
Collaborator commented:
We also need regex patterns for attention_norm, ffn_norm, and the final lm_head_norm, right? Something like:

r"transformer\.h\.\d+\.(attention_norm|ffn_norm)\.weight": nn.init.ones_,
r"transformer\.lm_head_norm\.weight": nn.init.ones_,

le1nux (Member Author) replied Mar 5, 2026:

module.reset_parameters()

We already call this here, and due to recursion it is also called for the RMSNorm:
https://github.com/pytorch/pytorch/blob/65762ca85745d786ab6b20e9cb060242b51e872d/torch/nn/modules/normalization.py#L407
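Why the norms need no regex entry can be sketched without torch. The classes below are hypothetical stand-ins mimicking nn.Module's recursive traversal; the one-fill in reset_parameters mirrors what PyTorch's RMSNorm.reset_parameters does (torch.nn.init.ones_ on its weight).

```python
class SketchModule:
    """Minimal stand-in for nn.Module: children are module-typed attributes."""

    def child_modules(self):
        return [v for v in vars(self).values() if isinstance(v, SketchModule)]


class RMSNormSketch(SketchModule):
    def __init__(self, dim):
        self.weight = [0.0] * dim

    def reset_parameters(self):
        # mirrors torch.nn.init.ones_(self.weight) in PyTorch's RMSNorm
        self.weight = [1.0] * len(self.weight)


class BlockSketch(SketchModule):
    def __init__(self, dim):
        self.attention_norm = RMSNormSketch(dim)
        self.ffn_norm = RMSNormSketch(dim)


def reset_recursively(module):
    """Mimic a recursive module.apply(...): visit children first, then reset
    any module that defines reset_parameters()."""
    for child in module.child_modules():
        reset_recursively(child)
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
```

Because the recursion reaches every submodule, both norm weights end up filled with ones without any explicit regex entry for them.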

@le1nux le1nux requested review from AbasKhan and BlueCrescent March 5, 2026 21:47
if re.fullmatch(weight_regex, parameter_name):
    init_fn, arg_dict = regex_to_init[weight_regex]
    if arg_dict["std"] is not None and callable(arg_dict["std"]):
        if not depth_init:
Collaborator commented:

Isn't this dead code now? std becomes a callable only when depth_init is True, right? So this check is not needed.

le1nux (Member Author) replied:

I added it as a safety check, but you're right, it's kind of redundant. I'll remove it!

@le1nux le1nux merged commit 6a17097 into improve_data_writeout_perf Mar 6, 2026
3 checks passed
@le1nux le1nux deleted the llama3_like_weight_init branch March 6, 2026 09:42

3 participants