
[BUG] Zero3: Post backward hook is not triggered for submodules whose inputs have .requires_grad=False #5524

Open · deepcharm opened this issue May 12, 2024 · 1 comment
Labels: bug (Something isn't working), training

deepcharm commented May 12, 2024

Describe the bug
The mechanism of pre-backward and post-backward hooks works by attaching a custom autograd function to tensors that are either inputs to the module (for post-backward) or outputs of the module (for pre-backward).

When the forward method of the post-backward function is invoked, it saves the module and counts the number of input tensors.

Consequently, when its backward method is invoked, the counter decreases for each tensor, and once it reaches zero, the actual post-backward processing routine is invoked. The main purpose of that routine is to release the previously materialized module parameters.
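To make the counting scheme concrete, here is a minimal sketch (illustrative only, not DeepSpeed's actual implementation; wrap_inputs_with_post_backward, release_fn and _grads_remaining are made-up names):

import torch

def wrap_inputs_with_post_backward(module, inputs, release_fn):
    # One autograd node per input tensor, a shared counter on the module,
    # and release_fn fired once the last backward call arrives.

    class PostBackward(torch.autograd.Function):

        @staticmethod
        def forward(ctx, tensor):
            ctx.module = module
            module._grads_remaining = getattr(module, "_grads_remaining", 0) + 1
            return tensor.clone()  # sketch only; the real hook avoids the copy

        @staticmethod
        def backward(ctx, grad):
            ctx.module._grads_remaining -= 1
            if ctx.module._grads_remaining == 0:
                release_fn(ctx.module)  # e.g. release the materialized ZeRO-3 params
            return grad

    return tuple(PostBackward.apply(t) for t in inputs)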

The above mechanism works for all modules in a model, except for those whose inputs have .requires_grad set to False. Typically, these are the very first modules in the model.

Since no gradient calculation is required for such inputs, the backward method of the above custom autograd function is NOT called.
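This is easy to verify independently of DeepSpeed. In the sketch below (illustrative only; Probe is a made-up stand-in for the post-backward autograd function), backward is skipped whenever the wrapped input does not require gradients:

import torch

class Probe(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad):
        print("Probe.backward called")
        return grad

w = torch.randn(4, requires_grad=True)

# Dataloader-style input with requires_grad=False: Probe.backward is never invoked.
x = torch.randn(4)
(Probe.apply(x) * w).sum().backward()

# Input that requires grad: Probe.backward is invoked as expected.
x = torch.randn(4, requires_grad=True)
(Probe.apply(x) * w).sum().backward()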


As a result, release_submodule is not called for those modules, so their memory is not released (and the params state is potentially not cleaned up correctly).

For example, the BERT model has 3 Embedding modules of significant size (> 1 GB of memory) that receive their inputs directly from a dataloader. In the current design, release_submodule will not be called for these modules, causing a memory peak.

The same would happen for ANY module whose inputs have .requires_grad set to False, not necessarily only the very first modules.

To Reproduce
This can easily be reproduced on any model, such as the one below. The submodules linear0_0 and linear0_1 of MyModel receive their inputs directly, while the last submodule linear1 receives its inputs from the first two layers.

class MyModel(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    super().__init__()
    self.linear0_0 = torch.nn.Linear(D_in, H)
    self.linear0_1 = torch.nn.Linear(D_in, H)  
    self.linear1 = torch.nn.Linear(H, D_out)

  def forward(self, x):
    y = torch.add(self.linear0_0(x), self.linear0_1(x)).clamp(min=0)
    y = self.linear1(y)
    return y

One can observe (by adding appropriate debug prints) that in the backward pass release_submodule is not invoked for the submodules linear0_0 and linear0_1, while it is invoked as expected for the submodule linear1.
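For completeness, a minimal driver sketch is given below (the ZeRO config, tensor shapes and optimizer choice are illustrative and not taken from the original report; run it under the deepspeed launcher). The inputs come straight from a dataloader-style tensor with requires_grad=False:

import torch
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3},   # ZeRO-3 enables the pre/post backward hooks
}

model = MyModel(D_in=1024, H=2048, D_out=512)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Dataloader-style batch: requires_grad defaults to False.
x = torch.randn(8, 1024, device=engine.device)
loss = engine(x).pow(2).mean()
engine.backward(loss)   # release_submodule fires for linear1 only
engine.step()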

deepcharm (Contributor, Author) commented:

A brute-force solution is to force .requires_grad to True on the model input tensors:

class PostBackwardFunctionModule(torch.autograd.Function):

    @staticmethod
    def forward(ctx, output):
        ctx.module = module  # `module` is captured from the enclosing hook scope

        if not output.requires_grad:
            # Force a grad requirement so that backward() will be invoked,
            # and remember that the flag was changed here.
            output.requires_grad_(requires_grad=True)
            output.mark_as_no_grad = True
        # ... rest of the original forward continues unchanged
The .requires_grad flag can then be restored to its original value in PostBackwardFunctionModule.backward.
This method works, but seems hacky and may introduce unexpected changes in the torch autograd mechanism.
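For illustration, the restoring backward could look roughly like the sketch below (the _run_post_backward helper and the ctx.saved_output reference are hypothetical placeholders, not the actual DeepSpeed code):

    @staticmethod
    def backward(ctx, *grads):
        # Usual post-backward processing, i.e. releasing the submodule's params.
        _run_post_backward(ctx.module)                   # hypothetical helper

        output = ctx.saved_output                        # hypothetical reference stored in forward
        if getattr(output, "mark_as_no_grad", False):
            output.requires_grad_(requires_grad=False)   # restore the original flag
            del output.mark_as_no_grad

        return grads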
