
Make adding buffers more like adding parameters to modules. #35735

Closed
josh-gleason opened this issue Mar 31, 2020 · 7 comments · May be fixed by #125971
Labels
actionable · enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix) · module: nn (Related to torch.nn) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@josh-gleason

josh-gleason commented Mar 31, 2020

🚀 Feature

Add an nn.Buffer type to mirror the behavior of nn.Parameter, without the need to explicitly call nn.Module.register_buffer.

Motivation

It's currently intuitive and easy to add a parameter to an nn.Module by wrapping it in an nn.Parameter. To the best of my knowledge, a buffer is very similar to a parameter from an end user's perspective, except that it doesn't get returned by nn.Module.parameters().

It would therefore make sense to have a similar mechanism for adding buffers to modules. Currently, you have to explicitly call nn.Module.register_buffer, which, in my opinion, is not very elegant, as it requires you to provide the name of the member variable you want to create as a string.

# currently
class Foo(nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.weights = nn.Parameter(torch.zeros(10, 10))
        self.register_buffer('my_buffer', torch.zeros(10, 10))

# proposed
class Foo(nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.weights = nn.Parameter(torch.zeros(10, 10))
        self.my_buffer = nn.Buffer(torch.zeros(10, 10))
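
For reference, the register_buffer behavior the proposed nn.Buffer would wrap can be checked directly. This is a small sketch (assumes PyTorch is installed) showing the key distinction: buffers are excluded from parameters() but still appear in the state dict.

```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    """Module with one parameter and one buffer registered the current way."""
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(10, 10))
        self.register_buffer('my_buffer', torch.zeros(10, 10))

m = Foo()
# Buffers are not returned by parameters()/named_parameters() ...
print([name for name, _ in m.named_parameters()])  # ['weights']
# ... but they are part of the state dict, so they are saved/loaded
# and moved to the right device along with the module.
print(sorted(m.state_dict().keys()))               # ['my_buffer', 'weights']
```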

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are

@vincentqb vincentqb added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: nn Related to torch.nn labels Mar 31, 2020
@ezyang ezyang added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Apr 1, 2020
@Balandat
Contributor

Balandat commented Jul 3, 2020

Is anyone working on this? Would be nice to have this to streamline dealing with buffers.

@asanakoy
Contributor

I agree. It would be more consistent to have nn.Buffer alongside nn.Parameter.

@r-barnes
Contributor

r-barnes commented Aug 6, 2021

This would be a great thing to have.

@osmalpkoras

Is this coming anytime soon?

@albanD
Collaborator

albanD commented Feb 28, 2023

I'm afraid no one is working on this at the moment, no.
But I would be happy to review a PR adding this if anyone wants to do it!

@ekamiti
Contributor

ekamiti commented Jul 1, 2023

> I'm afraid no one is working on this at the moment, no. But I would be happy to review a PR adding this if anyone wants to do it!

@albanD Just posted a PR for this. Mostly test changes; the actual product code changes are fairly small. Let me know what you think and whether I've missed any important considerations. Thanks!

@fzimmermann89
Contributor

I believe this was, in the end, not merged, right?

Is there any chance that the PR will be merged or a different implementation will be considered?

ekamiti added a commit to ekamiti/pytorch that referenced this issue May 10, 2024
Add semantics for creating a buffer object similar to those for creating a parameter. This is done by introducing a new Buffer class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same, as the register_buffer method has not been changed. The persistent parameter in the Buffer type indicates whether a buffer object should be persistent or not.

The other non-test changes have to do with getting the new Buffer type recognized by inductor and dynamo. The remaining changes are test changes to make sure that the Buffer type can be used as a drop-in replacement for register_buffer, as it just leads to register_buffer being called. The new functionality still allows normal tensors to be used as buffers, so these changes are intended to be backwards compatible.

Fixes pytorch#35735
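
The mechanism the commit describes (a marker type that the module's __setattr__ recognizes and routes through register_buffer) can be sketched without any PyTorch internals. Everything below (this Buffer class, TinyModule) is illustrative pseudocode of the idea, not PyTorch's actual implementation:

```python
class Buffer:
    """Marker type: wraps a value plus a `persistent` flag."""
    def __init__(self, data, persistent=True):
        self.data = data
        self.persistent = persistent

class TinyModule:
    """Toy module illustrating type-based dispatch in __setattr__."""
    def __init__(self):
        # Bypass our own __setattr__ for internal storage.
        object.__setattr__(self, '_buffers', {})
        object.__setattr__(self, '_non_persistent', set())

    def register_buffer(self, name, value, persistent=True):
        self._buffers[name] = value
        if not persistent:
            self._non_persistent.add(name)

    def __setattr__(self, name, value):
        # Type disambiguation: Buffer instances go through
        # register_buffer; everything else is stored normally.
        if isinstance(value, Buffer):
            self.register_buffer(name, value.data, value.persistent)
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so buffers resolve here.
        buffers = object.__getattribute__(self, '_buffers')
        if name in buffers:
            return buffers[name]
        raise AttributeError(name)

    def state_dict(self):
        # Persistent buffers only, mirroring register_buffer semantics.
        return {k: v for k, v in self._buffers.items()
                if k not in self._non_persistent}

m = TinyModule()
m.running_mean = Buffer([0.0] * 4)                 # registered as a buffer
m.scratch = Buffer([0.0] * 4, persistent=False)    # excluded from state_dict
```

Since assignment simply forwards to register_buffer, plain tensors assigned as attributes keep their old behavior, which is how the change stays backwards compatible.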