
Slimming DenseNet #3

Open
haithanhp opened this issue Dec 23, 2017 · 7 comments

@haithanhp

Hi @liuzhuang13,

Thank you for the great work. I see that you leverage the scaling factors of batch normalization to prune the incoming and outgoing weights of conv layers. However, in DenseNet, after a basic block (1x1 + 3x3) the previous features are concatenated with the current ones, so the dimension of the scaling factors does not match that of the previous convolutional layer for pruning. How can you prune weights in this case?

By the way, when sparsity training of DenseNet finishes with lambda = 1e-5, I notice that many scaling factors are not small enough for pruning. Does this affect the performance of the compressed network?

Thanks,
Hai

@liuzhuang13
Owner

Thanks for your interest. We prune channels according to the BN scaling factors: after sparsity training we set the small factors (and the corresponding biases) to 0, and then we see which channels we can prune without affecting the network. This is applied to all network structures. In DenseNet the dimension of the scaling factors actually does match the dimension of the convolution, because of the "pre-activation" structure.
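
For concreteness, here is a minimal sketch of this thresholding step, assuming a plain PyTorch model whose prunable layers are nn.BatchNorm2d; the function name, the use of a global percentile to define "small", and the default 40% ratio are illustrative assumptions, not the released training script:

import torch
import torch.nn as nn

def zero_small_bn_factors(model, prune_ratio=0.4):
    # Collect the absolute value of every BN scaling factor in the network.
    gammas = torch.cat([m.weight.data.abs().view(-1)
                        for m in model.modules()
                        if isinstance(m, nn.BatchNorm2d)])
    # A global threshold: prune_ratio of all factors fall below it.
    sorted_gammas, _ = torch.sort(gammas)
    threshold = sorted_gammas[int(gammas.numel() * prune_ratio)]
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            mask = m.weight.data.abs().gt(threshold).float()
            m.weight.data.mul_(mask)  # zero out the small scaling factors
            m.bias.data.mul_(mask)    # and the corresponding biases
    return threshold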

The lambda parameter needs tuning for different datasets and hyperparameters (e.g. learning rate), so you may need to check the final performance to pick it.

@haithanhp
Author

haithanhp commented Dec 28, 2017

Thanks for your answer. Here is an example from one part of DenseNet-40 (k=12):

module.features.init_conv.weight : torch.Size([24, 3, 3, 3])
module.features.denseblock_1.dense_basicblock_1.conv_33.norm.weight : torch.Size([24])
module.features.denseblock_1.dense_basicblock_1.conv_33.norm.bias : torch.Size([24])
module.features.denseblock_1.dense_basicblock_1.conv_33.norm.running_mean : torch.Size([24])
module.features.denseblock_1.dense_basicblock_1.conv_33.norm.running_var : torch.Size([24])
module.features.denseblock_1.dense_basicblock_1.conv_33.conv.weight : torch.Size([12, 24, 3, 3])
module.features.denseblock_1.dense_basicblock_2.conv_33.norm.weight : torch.Size([36])
module.features.denseblock_1.dense_basicblock_2.conv_33.norm.bias : torch.Size([36])
module.features.denseblock_1.dense_basicblock_2.conv_33.norm.running_mean : torch.Size([36])
module.features.denseblock_1.dense_basicblock_2.conv_33.norm.running_var : torch.Size([36])
module.features.denseblock_1.dense_basicblock_2.conv_33.conv.weight : torch.Size([12, 36, 3, 3])
module.features.denseblock_1.dense_basicblock_3.conv_33.norm.weight : torch.Size([48])
module.features.denseblock_1.dense_basicblock_3.conv_33.norm.bias : torch.Size([48])
module.features.denseblock_1.dense_basicblock_3.conv_33.norm.running_mean : torch.Size([48])
module.features.denseblock_1.dense_basicblock_3.conv_33.norm.running_var : torch.Size([48])
module.features.denseblock_1.dense_basicblock_3.conv_33.conv.weight : torch.Size([12, 48, 3, 3])
module.features.denseblock_1.dense_basicblock_4.conv_33.norm.weight : torch.Size([60])
module.features.denseblock_1.dense_basicblock_4.conv_33.norm.bias : torch.Size([60])
module.features.denseblock_1.dense_basicblock_4.conv_33.norm.running_mean : torch.Size([60])
module.features.denseblock_1.dense_basicblock_4.conv_33.norm.running_var : torch.Size([60])
module.features.denseblock_1.dense_basicblock_4.conv_33.conv.weight : torch.Size([12, 60, 3, 3])

[N, C, K, K]: [#filters, #channels, kernel_size, kernel_size]

"norm.weight" here is the scaling factor in batch normalization. For me, each norm.weight layer I try to prune 40% #channels of batch normalization coresponding to #filters of previous conv.weight and #channels of latter conv.weight. How can you prune incoming and outgoing in this case? Please correct me if I make mistakes in pruning.

By the way, when the parameters of a layer are pruned, how does it affect the performance of the network? Is there any way to track how the performance changes?

Thanks.

@liuzhuang13
Owner

  1. In this basic DenseNet you can only prune the outgoing weights. For example, if you set 10 of the 36 weights and biases in
    module.features.denseblock_1.dense_basicblock_2.conv_33.norm.weight : torch.Size([36])
    module.features.denseblock_1.dense_basicblock_2.conv_33.norm.bias : torch.Size([36])
    to zero, you can prune away the corresponding weights (along the second dimension) in
    module.features.denseblock_1.dense_basicblock_2.conv_33.conv.weight : torch.Size([12, 36, 3, 3])
    (a sketch of this step follows after the list).

  2. Maybe you could visualize the scaling factors as in Fig. 4 of the paper, or monitor the performance on a validation set. In my experience it is not very hard to pick the value.
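
Concretely, a minimal sketch of the step in item 1, assuming plain PyTorch modules (the function name prune_outgoing_conv is illustrative, not part of the repository):

import torch

def prune_outgoing_conv(bn, conv):
    # Channels whose scaling factor was not zeroed out are kept.
    keep = torch.nonzero(bn.weight.data.abs().gt(0)).squeeze(1)
    # Drop the matching input channels (second dimension) of the conv weight,
    # e.g. torch.Size([12, 36, 3, 3]) -> torch.Size([12, 26, 3, 3]) when 10
    # of the 36 scaling factors were set to zero.
    return conv.weight.data.index_select(1, keep)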

@haithanhp
Author

  1. When the second dimension of conv.weight is pruned to 26 (pruning away 10 channels), the input activation still has 36 channels, so the dimensions won't match. How can you perform the convolution in this case?

  2. Thank you. I also tried visualizing the values with a lasso lambda of 1e-5 and 1e-4, and there are many values near zero.

@liuzhuang13
Owner

  1. I wrote a channel selection layer and placed it before the batch normalization layer. This layer selects channels using the indices of the selected channels as its parameter. In my implementation it is very slow to run, maybe because of the memory copy involved. I'm not sure whether there is a solution for fast channel selection.
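
For illustration, a minimal sketch of such a channel selection layer (the code released later in this thread contains its own implementation; the class name and the buffer-based bookkeeping here are only assumptions):

import torch
import torch.nn as nn

class ChannelSelection(nn.Module):
    def __init__(self, num_channels):
        super(ChannelSelection, self).__init__()
        # 1 = keep the channel, 0 = drop it; updated when the network is pruned.
        self.register_buffer('indexes', torch.ones(num_channels))

    def forward(self, x):
        keep = torch.nonzero(self.indexes).squeeze(1)
        if keep.numel() == self.indexes.numel():
            return x                    # nothing pruned yet, pass through
        return x.index_select(1, keep)  # copy out only the selected channels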

@haithanhp
Author

Yes, I see. Also, have you released the code for the DenseNet and ResNet experiments? I need to reproduce all your experiments for evaluation. Thanks.

@liuzhuang13
Owner

In case you're still interested, we've released our PyTorch implementation here: https://github.com/Eric-mingjie/network-slimming. It supports ResNet and DenseNet.
