该教程介绍如果在PaddleSlim提供的接口基础上快速自定义Filters
剪裁策略。
在PaddleSlim中,所有剪裁Filters
的Pruner
继承自基类FilterPruner
。FilterPruner
中自定义了一系列通用方法,用户只需要重载实现FilterPruner
的cal_mask
接口,cal_mask
接口定义如下:
def cal_mask(self, var_name, pruned_ratio, group):
raise NotImplemented()
cal_mask
接口接受的参数说明如下:
- var_name: 要剪裁的目标变量,一般为卷积层的权重参数的名称。在Paddle中,卷积层的权重参数格式为
[output_channel, input_channel, kernel_size, kernel_size]
,其中,output_channel
为当前卷积层的输出通道数,input_channel
为当前卷积层的输入通道数,kernel_size
为卷积核大小。 - pruned_ratio: 对名称为
var_name
的变量的剪裁率。 - group: 与待裁目标变量相关的所有变量的信息。
如图1-1所示,在给定模型中有两个卷积层,第一个卷积层有3个filters
,第二个卷积层有2个filters
。如果删除第一个卷积绿色的filter
,第一个卷积的输出特征图的通道数也会减1,同时需要删掉第二个卷积层绿色的kernels
。如上所述的两个卷积共同组成一个group,表示如下:
group = {
"conv_1.weight":{
"pruned_dims": [0],
"layer": conv_layer_1,
"var": var_instance_1,
"value": var_value_1,
},
"conv_2.weight":{
"pruned_dims": [1],
"layer": conv_layer_2,
"var": var_instance_2,
"value": var_value_2,
}
}
在上述表示group
的数据结构示例中,conv_1.weight
为第一个卷积权重参数的名称,其对应的value也是一个dict实例,存放了当前参数的一些信息,包括:
- pruned_dims: 类型为
list<int>
,表示当前参数在哪些维度上被裁。 - layer: 类型为paddle.nn.Layer, 表示当前参数所在
Layer
。 - var: 类型为paddle.Tensor, 表示当前参数对应的实例。
- value: 类型为numpy.array类型,待裁参数所存的具体数值,方便开发者使用。
图1-2为更复杂的情况,其中,Add
操作的所有输入的通道数需要保持一致,Concat
操作的输出通道数的调整可能会影响到所有输入的通道数,因此group
中可能包含多个卷积的参数或变量,可以是:卷积权重、卷积bias、batch norm
相关参数等。
import paddle
from paddle.vision.models import mobilenet_v1
net = mobilenet_v1(pretrained=False)
paddle.summary(net, (1, 3, 32, 32))
该小节参考L1NormFilterPruner
实现L2NormFilterPruner
,方式为集成FIlterPruner
并重载cal_mask
接口。代码如下所示:
import numpy as np
from paddleslim.dygraph import FilterPruner
class L2NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None):
super(L2NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt)
def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
if _detail.axis == 1:
_groups = _detail.op.attr('groups')
if _groups is not None and _groups > 1:
groups = _groups
break
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
if groups > 1:
scores = scores.reshape([groups, -1])
scores = np.mean(scores, axis=1)
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
return mask.reshape(mask_shape)
如上述代码所示,我们重载了FilterPruner
基类的cal_mask
方法,并在L1NormFilterPruner
代码基础上,修改了计算通道重要性的语句,将其修改为了计算L2Norm的逻辑:
scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
接下来定义一个L2NormFilterPruner
对象,并调用prune_var
方法对单个卷积层进行剪裁,prune_var
方法继承自FilterPruner
,开发者不用再重载实现。
按以下代码调用prune_var
方法后,参数名称为conv2d_0.w_0
的卷积层会被裁掉50%的filters
,与之相关关联的后续卷积和BatchNorm
相关的参数也会被剪裁。prune_var
不仅会对待裁模型进行inplace
的裁剪,还会返回保存裁剪详细信息的PruningPlan
对象,用户可以直接打印PruningPlan
对象内容。
最后,可以通过调用Pruner
的restore
方法,将已被裁剪的模型恢复到初始状态。
pruner = L2NormFilterPruner(net, [1, 3, 32, 32])
plan = pruner.prune_var("conv2d_0.w_0", 0, 0.5)
print(plan)
pruner.restore()
参考:Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration
如图4-1所示,传统基于Norm统计方法的filter重要性评估方式的有效性取决于卷积层权重数值的分布,比较理想的分布式要满足两个条件:
- 偏差(deviation)要大
- 最小值要小(图4-1中v1)
满足上述条件后,我们才能裁掉更多Norm统计值较小的参数,如图4-1中红色部分所示。
图 4-1而现实中的模型的权重分布如图4-2中绿色分布所示,总是有较小的偏差或较大的最小值。
图 4-2考虑到上述传统方法的缺点,FPGM则用filter之间的几何距离来表示重要性,其遵循的原则就是:几何距离比较近的filters,作用也相近。 如图4-3所示,有3个filters,将各个filter展开为向量,并两两计算几何距离。其中,绿色filter的重要性得分就是它到其它两个filter的距离和,即0.7071+0.5831=1.2902。同理算出另外两个filters的得分,绿色filter得分最高,其重要性最高。
图 4-3以下代码通过继承FilterPruner
并重载cal_mask
实现了FPGMFilterPruner
,其中,get_distance_sum
用于计算第out_idx
个filter的重要性。
import numpy as np
from paddleslim.dygraph import FilterPruner
class FPGMFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None):
super(FPGMFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt)
def cal_mask(self, pruned_ratio, collection):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
if _detail.axis == 1:
_groups = _detail.op.attr('groups')
if _groups is not None and _groups > 1:
groups = _groups
break
dist_sum_list = []
for out_i in range(value.shape[0]):
dist_sum = self.get_distance_sum(value, out_i)
dist_sum_list.append(dist_sum)
scores = np.array(dist_sum_list)
if groups > 1:
scores = scores.reshape([groups, -1])
scores = np.mean(scores, axis=1)
sorted_idx = scores.argsort()
pruned_num = int(round(len(sorted_idx) * pruned_ratio))
pruned_idx = sorted_idx[:pruned_num]
mask_shape = [value.shape[pruned_axis]]
mask = np.ones(mask_shape, dtype="int32")
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
return mask.reshape(mask_shape)
def get_distance_sum(self, value, out_idx):
w = value.view()
w.shape = value.shape[0], np.product(value.shape[1:])
selected_filter = np.tile(w[out_idx], (w.shape[0], 1))
x = w - selected_filter
x = np.sqrt(np.sum(x * x, -1))
return x.sum()
接下来声明一个FPGMFilterPruner对象进行验证:
pruner = FPGMFilterPruner(net, [1, 3, 32, 32])
plan = pruner.prune_var("conv2d_0.w_0", 0, 0.5)
print(plan)
pruner.restore()
在第3节和第4节,开发者自定义实现的L2NormFilterPruner
和FPGMFilterPruner
也继承了FilterPruner
的敏感度计算方法sensitive
和剪裁方法sensitive_prune
。
import paddle.vision.transforms as T
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.Cifar10(mode="train", backend="cv2",transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(mode="test", backend="cv2",transform=transform)
from paddle.static import InputSpec as Input
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
inputs = [Input([None, 3, 32, 32], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
net = mobilenet_v1(pretrained=False)
model = paddle.Model(net, inputs, labels)
model.prepare(
optimizer,
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy(topk=(1, 5)))
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(result)
pruner = FPGMFilterPruner(net, [1, 3, 32, 32], opt=optimizer)
def eval_fn():
result = model.evaluate(
val_dataset,
batch_size=128)
return result['acc_top1']
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./fpgm_sen.pickle")
print(sen)
from paddleslim.analysis import dygraph_flops
flops = dygraph_flops(net, [1, 3, 32, 32])
print(f"FLOPs before pruning: {flops}")
plan = pruner.sensitive_prune(0.4, skip_vars=["conv2d_26.w_0"])
flops = dygraph_flops(net, [1, 3, 32, 32])
print(f"FLOPs after pruning: {flops}")
print(f"Pruned FLOPs: {round(plan.pruned_flops*100, 2)}%")
result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(f"before fine-tuning: {result}")
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
result = model.evaluate(val_dataset,batch_size=128, log_freq=10)
print(f"after fine-tuning: {result}")