Inform when weights loading promotes types #126000

Open
isentropic opened this issue May 11, 2024 · 1 comment
Labels
module: nn Related to torch.nn triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

isentropic commented May 11, 2024

🚀 The feature, motivation and pitch

This issue is motivated by https://discuss.pytorch.org/t/state-dict-and-loading-of-buffers/202015/2
Here is the problem that arises when loading weights:

import torch
import torch.nn as nn

class Module(nn.Module):
    def __init__(self, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        self.register_buffer('running_max', torch.tensor(1))  # the buffer is int64
    def forward(self, x):
        xmax = torch.max(torch.abs(x))
        # reassigning the buffer replaces it; the result is promoted to float32
        self.running_max = (1 - self.momentum) * self.running_max + self.momentum * xmax
        return x

In [1]: net = Module()

In [21]: state_dict['running_max'] # the serialized dict contains float32
Out[21]: tensor(2.8699, device='cuda:1')

In [23]: net.load_state_dict(state_dict) # It looks like everything loaded properly
Out[23]: <All keys matched successfully>

In [22]: net.running_max # In reality the buffer did not load correctly: the float32 value was cast to the int64 buffer
Out[22]: tensor(2)
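For context on why no error appears: by default, `load_state_dict` copies each saved tensor into the module's existing tensor in place (via `Tensor.copy_`), and an in-place copy keeps the destination's dtype. The minimal sketch below reproduces just that cast without any module; the variable names are illustrative only.

```python
import torch

# Plays the role of the freshly constructed int64 buffer in a new Module().
running_max = torch.tensor(1)
print(running_max.dtype)  # torch.int64

# Plays the role of the value saved after forward() promoted it to float32.
saved = torch.tensor(2.8699)

# An in-place copy keeps the destination dtype, so the float32 value is cast
# to int64 (the fractional part is dropped) without any warning.
with torch.no_grad():
    running_max.copy_(saved)

print(running_max)  # tensor(2)
```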

This should not happen silently: it is very hard to diagnose why the loaded model now behaves differently. The message could read something like:

In [23]: net.load_state_dict(state_dict) # It looks like everything loaded properly
Out[23]: <All keys matched successfully, but some keys ('running_max') needed type promotion (demotion)> # or something like this

Alternatives

There should perhaps be some sort of warning or notification instead of "All keys matched successfully", which is clearly not the case here.
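Until something like this exists in PyTorch, a user-side check can compare dtypes before loading. The sketch below uses two hypothetical helpers, `find_dtype_mismatches` and `load_state_dict_checked` (neither is a PyTorch API), plus a hypothetical `Demo` module mirroring the report. It only reads `Module.state_dict()` and then calls the regular `load_state_dict`, so the loading behavior itself is unchanged.

```python
import warnings

import torch
import torch.nn as nn


def find_dtype_mismatches(module: nn.Module, state_dict: dict) -> dict:
    """Hypothetical helper: {key: (module_dtype, saved_dtype)} for differing tensors."""
    mismatches = {}
    for key, current in module.state_dict().items():
        saved = state_dict.get(key)
        # Skip non-tensor entries (e.g. extra state) and keys absent from the dict.
        if isinstance(current, torch.Tensor) and isinstance(saved, torch.Tensor):
            if saved.dtype != current.dtype:
                mismatches[key] = (current.dtype, saved.dtype)
    return mismatches


def load_state_dict_checked(module: nn.Module, state_dict: dict, **kwargs):
    """Hypothetical wrapper: warn about dtype conversions, then load as usual."""
    for key, (mod_dtype, saved_dtype) in find_dtype_mismatches(module, state_dict).items():
        warnings.warn(f"'{key}': saved {saved_dtype} will be cast to the module's {mod_dtype}")
    return module.load_state_dict(state_dict, **kwargs)


class Demo(nn.Module):
    """Hypothetical module with the same int64 buffer as in the report."""

    def __init__(self):
        super().__init__()
        self.register_buffer("running_max", torch.tensor(1))


net = Demo()
saved_state = {"running_max": torch.tensor(2.8699)}  # float32, as after forward()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = load_state_dict_checked(net, saved_state)
for w in caught:
    print(w.message)
```

Comparing against `module.state_dict()` covers the parameters and persistent buffers of nested submodules in one pass, which matters for deeply nested models. The check only reports conversions; it does not try to decide whether a given one (say float32 to float64) is harmless.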

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

isentropic commented May 11, 2024

I'm sure this has burned someone before: it is not easy to debug (the model just silently performs badly), especially when the model is deeply nested and you forget that you changed the type of some of the buffers or weights.

I'd like to try sending a PR for this if you think it is helpful and can point me to the right places to modify the code, assuming it is simple enough to contribute.
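As a pointer for where such a check could live: `load_state_dict` calls the private `nn.Module._load_from_state_dict` once per submodule, and that method is where each module's own parameters and buffers are copied from the state dict. The sketch below overrides it in a hypothetical `DtypeWarningMixin` (and a hypothetical `RunningMax` module) to warn on dtype differences before delegating to the original method. How an upstream change should surface this (a warning, or a new field in the returned result) is a design question for the maintainers; this only illustrates the location and is not a proposed patch.

```python
import warnings

import torch
import torch.nn as nn


class DtypeWarningMixin:
    """Hypothetical mixin: warn when a loaded tensor's dtype differs from the module's."""

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Only this module's direct parameters and buffers; submodules are
        # handled by their own _load_from_state_dict calls.
        local_tensors = {**self._parameters, **self._buffers}
        for name, current in local_tensors.items():
            key = prefix + name
            saved = state_dict.get(key)
            if current is not None and isinstance(saved, torch.Tensor) and saved.dtype != current.dtype:
                warnings.warn(f"'{key}': loading {saved.dtype} into {current.dtype}")
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


class RunningMax(DtypeWarningMixin, nn.Module):
    """Hypothetical module with the int64 buffer from the report."""

    def __init__(self):
        super().__init__()
        self.register_buffer("running_max", torch.tensor(1))


net = RunningMax()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    net.load_state_dict({"running_max": torch.tensor(2.8699)})
for w in caught:
    print(w.message)
```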

mikaylagawarecki added the module: nn and triaged labels May 14, 2024