Reimplement a custom model with new modules implemented by yourself #11690

Open
BruceLxw opened this issue May 8, 2024 · 2 comments
Labels
reimplementation Issues in model reimplementation

Comments

BruceLxw commented May 8, 2024

I want to add an attention module to ResNet50. How should I adapt my attention module (source code) and add it to ResNet50 in mmdetection? Also, my previous models all used "open-mmlab://detectron2/resnet50_caffe" as the pre-trained weights. After adding the attention module, how do I set up its initialization so that the other ResNet50 parts are still initialized from that pre-trained model? Thank you very much for your answer!

BruceLxw added the reimplementation label May 8, 2024
BruceLxw (Author) commented:

To be specific: I have added an attention module after the second convolution of BasicBlock and after the third convolution of Bottleneck, while the rest is the original mmdetection ResNet50. If I want to keep using the pre-trained model mentioned above for the original ResNet50 parts, and initialize the newly added attention module with a simple method such as Kaiming, how should I implement that? Thank you.

BruceLxw (Author) commented:

# mmdet/models/backbones/mynet.py
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule  # for mmdet 2.x: from mmcv.runner import BaseModule

class MyA(BaseModule):
    def __init__(self, channels, factor=32,
                 norm_cfg=None,
                 conv_cfg=None,
                 init_cfg= [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ],):
        super(MyA, self).__init__(init_cfg)
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        # GroupNorm over the per-group channels; num_groups must divide
        # channels // self.groups, so use one group per channel here.
        norm_cfg = dict(type='GN', num_groups=channels // self.groups)
        self.conv1x1 = build_conv_layer(
            conv_cfg,
            channels // self.groups,
            channels // self.groups,
            1,
            stride=1,
            padding=0,
            dilation=1,
            bias=False)
        
        self.conv3x3 = build_conv_layer(
            conv_cfg,
            channels // self.groups,
            channels // self.groups,
            3,
            stride=1,
            padding=1,
            dilation=1,
            bias=False)

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, channels // self.groups, postfix=1)
        self.add_module(self.norm1_name, norm1)
    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)
    
    def forward(self, x):
        b, c, h, w = x.size()

        group_x = x.reshape(b * self.groups, -1, h, w)
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)

        x1 = self.norm1(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)

        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)
        y1 = torch.matmul(x11, x12)

        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)
        y2 = torch.matmul(x21, x22)

        weights = (y1+y2).reshape(b * self.groups, 1, h, w) 
        weights_ =  weights.sigmoid()
        out = (group_x * weights_).reshape(b, c, h, w)
        return out
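
# Quick standalone sanity check for MyA (a minimal sketch, not part of the
# original listing; it assumes only the imports above). MyA must preserve the
# feature-map shape, which is what lets it be inserted after conv2 of
# BasicBlock and after conv3 of Bottleneck without touching any other layer.
def _mya_shape_check():
    x = torch.randn(2, 256, 32, 32)          # e.g. a stage-1 Bottleneck output
    attn = MyA(channels=256, factor=32)
    assert attn(x).shape == x.shape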

#...

class MyABasicBlock(BaseModule):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        super(MyABasicBlock, self).__init__(init_cfg)
        # ...
        self.mya = MyA(channels=planes)
        
    #...
    
    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            # apply the MyA attention module after conv2
            out = self.mya(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out

#...
class MyABottleneck(BaseModule):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
        it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(MyABottleneck, self).__init__(init_cfg)
        # ...
        self.mya = MyA(planes * self.expansion)

    #...
    
    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x
            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            # apply the MyA attention module after conv3
            out = self.mya(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
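
For completeness, here is a minimal sketch of how a backbone built from these blocks could be registered so a config can reference it. The class name MyAResNet is an assumption for illustration, and the registration call assumes mmdet 3.x's MODELS registry; the stage construction logic is inherited unchanged from mmdet's ResNet, only the residual block classes are swapped:

from mmdet.registry import MODELS
from mmdet.models.backbones import ResNet

@MODELS.register_module()
class MyAResNet(ResNet):
    # same construction logic as ResNet; only the residual block classes differ
    arch_settings = {
        18: (MyABasicBlock, (2, 2, 2, 2)),
        34: (MyABasicBlock, (3, 4, 6, 3)),
        50: (MyABottleneck, (3, 4, 6, 3)),
        101: (MyABottleneck, (3, 4, 23, 3)),
        152: (MyABottleneck, (3, 8, 36, 3)),
    }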

I did not modify the rest of the ResNet class. I have already passed the Kaiming init_cfg into MyA (as above), and I still want to use my previous pre-training configuration:

backbone=dict(
    norm_cfg=dict(requires_grad=False),
    norm_eval=True,
    style="caffe",
    init_cfg=dict(
        type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"
    ),
),

Is it okay for me to do this? Or how can I use this pre-trained model for the original ResNet parts while initializing the MyA module with Kaiming and GN?
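
For reference, a minimal sketch of the same idea in plain PyTorch, outside of mmdetection's init_cfg machinery (the MyAResNet class from the sketch above and the local checkpoint file name are assumptions for illustration): load the Caffe ResNet50 weights non-strictly so every original parameter is filled from the checkpoint, then apply Kaiming/constant initialization only to the new MyA parameters, which are exactly the keys reported as missing.

import torch
import torch.nn as nn

backbone = MyAResNet(depth=50)
ckpt = torch.load('resnet50_caffe.pth', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)

# strict=False: original ResNet parameters are loaded from the checkpoint;
# the new MyA parameters are reported as missing and keep their own init.
missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
print('missing (new MyA params):', missing)
print('unexpected (ignored):', unexpected)

# Kaiming for the new convolutions, constants for the GroupNorm layers,
# applied only to submodules that belong to a MyA block.
for name, m in backbone.named_modules():
    if 'mya' not in name:
        continue
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.GroupNorm):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)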
