深度学习模型剪枝的实现
创始人
2025-05-28 21:15:13

    对于深度学习来说,比较复杂的模型往往有着不错的识别效果,但是复杂的模型往往对算力要求也比较高,在一些对于实时性要求比较高或者算力比较小的应用场景中,这时复杂的模型往往不能很好达到预期效果,这时候就要进行模型的剪枝,提高模型的运算速度。 剪枝也就是将这个参数置为0,消除这些节点与后面的联系,从而降低运算量,本文主要基于对于模型剪枝的实战展开。

本文参考:深度学习之模型压缩(剪枝、量化)_深度学习模型压缩_CV算法恩仇录的博客-CSDN博客

目录

模型构造

必要的函数解释

module.named_parameters()

module.named_buffers()

model.state_dict().keys()

module._forward_pre_hooks

单层剪枝

连续单层剪枝

全局剪枝

自定义剪枝 

模型构造

    本部分主要是先说明下面示例会用到的模型,就是我们大名鼎鼎的LeNet模型,当然其实其他模型也可以,只要是一个有着基本构造的网络都是可以的。

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸self.conv1 = nn.Conv2d(1, 6, 3)# self.conv1 = nn.Conv2d(2, 3, 3)self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device=device)

必要的函数解释

    首先需要简单介绍一下下面可能会频繁出现的一些函数,如果不解释,可能就会看了很迷,我自己去搜索也没有知找到非常直观的解释,所以就按照我的理解线捋一捋。也可以先跳过去,之后遇到了再回来看。解释的都是基于卷积层,其他层类似。

 module.named_parameters()

    在上面的的模型构造中已经说明module就是表示模型的第一层卷积层(当然任意一层都是可以的),针对于卷积层,这个函数得到的就是卷积核参数的情况以及一些信息。

    这里我一开始比较迷惑的就是为什么Conv2d(1, 6, 3)的卷积层有许多3*3的卷积核。假设卷积层如上所示是Conv2d(1, 6, 3),代表输入1通道,输出6通道,卷积核大小3,那么其中的参数最小元素就是3*3的卷积核,因为要输出6个通道,那么每个输出通道都需要和1个输入通道卷积,得到一个通道输出,所以就有6*1个3*3的卷积核;假如输入是2通道,那么每个输出通道都需要和2个输入通道卷积,每个输入通道都需要一个卷积核,然后得到一个输出,这时候就会有6*2个3*3的卷积核。

module.named_buffers()

    这个函数表示掩码缓冲区。因为后面剪枝是针对于卷积核参数,需要标记哪些位置的参数是要被删除的,而这个缓冲区的数字是和卷积核参数一一对应的,要是这个参数被剪了,那么这个位置标记为0,否则是1,最后和参数矩阵相乘,被剪掉的位置参数就变为了0。

model.state_dict().keys()

    这个输出的是当前的状态列表,可能是需要用到的一些参数,在剪枝之前这个里面就是单独的参数,在剪枝之后就变成了参数备份和掩码矩阵,具体的作用也不是很懂。

module._forward_pre_hooks

这个参数是一个列表,里面就记录了堆某个层使用的算法记录,比如L1正则化这种。

下面介绍四种剪枝方式

1.单层剪枝(对于特定的卷积层或某进行剪枝)

2.连续单层剪枝(对多层进行单层剪枝)

3.全局剪枝(对全局进行剪枝)

4.自定义剪枝(自定义剪枝规则)

单层剪枝

    首先是对于某一个特定的层进行剪枝,利用prune.random_unstructured()函数,里面写入参数剪枝模型的特定层,比如卷积层,修剪对象也就是对权重weight还是偏置bias修剪,还有修剪比例,然后就会按照你的要求剪枝。假设修剪的是权重weight,缓冲区buffer里面放的就是掩码,标记了哪些位置的参数要被剪除,这些位置为0,否则为1。

module = model.conv1
print("---修剪前的状态字典")
print(model.state_dict().keys())  # 打印修剪前的状态字典,发现有weight
print("---修剪前的参数")
print(list(module.named_parameters()))
print("---修剪前的缓冲区")
print(list(module.named_buffers()))
prune.random_unstructured(module, name="weight", amount=0.3)  # 对参数修剪
print("*" * 50)
print("---修剪后的状态字典")
print(model.state_dict().keys())  # 打印修剪前的状态字典,发现多出了 orig 和 mask
print("---修剪后的参数")
print(list(module.named_parameters()))  # 实际上还没有变,下面会解释
print("---修剪后的缓冲区")
print(list(module.named_buffers()))  # 这个就是掩码
print("---修剪算法")
print(module._forward_pre_hooks)  # 这里里面存放的每个元素是一个算法

    可以从状态列表里面看出修建前后的区别,就是weight变为了weight_orig和weight_mask,mask实际上标记了哪些位置是要剪除的,这个数据就是在缓冲区,所以一开始剪枝前是空而剪枝后有了数据。weight_orig就是做了个备份,还是原来的weight,到时候掩码和weight_orig两个相乘就是剪枝后的结果。可以从修剪后的参数中看出,实际上和修剪前是一样的,那么如何将参数变为修剪后,就要用到remove函数,remove也就类似于确定修剪的按钮,执行了之后就会把缓冲区的mask删掉,并且将要剪除的参数变为0,这个过程不可逆,执行之后,要是没有额外备份,那么参数就会被永久改变。(连在上面代码后面)

prune.remove(module, 'weight')
print("---执行remove后的参数")
print(list(module.named_parameters()))  # 此时参数变化

连续单层剪枝

    连续单层剪枝其实类似于单层剪枝,唯一的不同就是上面对一个卷积层剪枝,现在我们可以利用一个循环,将所有卷积层和全连接层进行剪枝操作,单个循环内其实还是单层剪枝。 

print(dict(model.named_buffers()).keys())  # 打印缓冲区
print(model.state_dict().keys())  # 打印初始模型的所有状态字典
print(dict(model.named_buffers()).keys())  # 打印初始模型的mask buffers张量字典名称,发现此时为空(因为还没剪枝)
for name, module in model.named_modules():# 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝if isinstance(module, torch.nn.Conv2d): # 比较第一个是不是第二个表示的类,这里就是判断是不是卷积层prune.l1_unstructured(module, name="weight", amount=0.2)# 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝elif isinstance(module, torch.nn.Linear):prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)# 打印多参数模块剪枝后的mask buffers张量字典名称
print(dict(model.named_buffers()).keys()) # 打印缓冲区
print(model.state_dict().keys())  # 打印多参数模块剪枝后模型的所有状态字典名称

    可以发现缓冲区内多出了每个层weight的mask,状态字典的weight也变成了weight_orig和mask。

全局剪枝

上面两种都是对于特定层剪枝,而全局剪枝则是面向整个模型,在整个模型中剪除多少比例的参数,从而缩减模型。(此处代码基本是搬运的,感谢文首提到的大佬)

model = LeNet().to(device=device)
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),(model.fc2, 'weight'),(model.fc3, 'weight'))
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)
# 统计每个层被剪枝的数量百分比(也就是统计等于0的数字占总数的比例)
print("Sparsity in conv1.weight: {:.2f}%".format(100. * float(torch.sum(model.conv1.weight == 0))/ float(model.conv1.weight.nelement())))print("Sparsity in conv2.weight: {:.2f}%".format(100. * float(torch.sum(model.conv2.weight == 0))/ float(model.conv2.weight.nelement())))print("Sparsity in fc1.weight: {:.2f}%".format(100. * float(torch.sum(model.fc1.weight == 0))/ float(model.fc1.weight.nelement())))print("Sparsity in fc2.weight: {:.2f}%".format(100. * float(torch.sum(model.fc2.weight == 0))/ float(model.fc2.weight.nelement())))print("Sparsity in fc3.weight: {:.2f}%".format(100. * float(torch.sum(model.fc3.weight == 0))/ float(model.fc3.weight.nelement())))print("Global sparsity: {:.2f}%".format(100. * float(torch.sum(model.conv1.weight == 0)+ torch.sum(model.conv2.weight == 0)+ torch.sum(model.fc1.weight == 0)+ torch.sum(model.fc2.weight == 0)+ torch.sum(model.fc3.weight == 0))/ float(model.conv1.weight.nelement()+ model.conv2.weight.nelement()+ model.fc1.weight.nelement()+ model.fc2.weight.nelement()+ model.fc3.weight.nelement())))

    运行之后就可以发现每一层都不同程度被剪除了参数,计算方式就是计算mask层中0所占的比例。

自定义剪枝

    自定义剪枝的自定义主要是体现在剪枝方法上面,比如参数接近于0或者相对很小那么可能贡献很小,那么这时候就可以考虑剪除,对模型也不会造成很大影响。下面的示例采用隔位剪枝的方式,也就是隔一个剪一个,当然这是可以改掉的。(因为参考的是这么写的) 

class myself_pruning_method(prune.BasePruningMethod):PRUNING_TYPE = "unstructured"# 内部实现compute_mask函数, 完成程序员自己定义的剪枝规则, 本质上就是如何去mask掉权重参数def compute_mask(self, t, default_mask):mask = default_mask.clone()# 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数量的50%被mask掉# 当然可以自己定义mask.view(-1)[::2] = 0return mask# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def myself_unstructured_pruning(module, name):myself_pruning_method.apply(module, name)return module# 下面开始剪枝
# 实例化模型类
model = LeNet().to(device=device)start = time.time()  # 计时
# 调用自定义剪枝方法的函数, 对model中的第三个全连接层fc3中的偏置bias执行自定义剪枝
myself_unstructured_pruning(model.fc3, name="bias")# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc3.bias_mask)# 打印一下自定义剪枝的耗时
duration = time.time() - start
print(duration * 1000, 'ms')

相关内容

热门资讯

联动云租一天多少钱(联动云租一... 本篇文章极速百科给大家谈谈联动云租一天多少钱,以及联动云租一天怎么划算对应的知识点,希望对各位有所帮...
飞机托运收费(飞机托运收费多少... 本篇文章极速百科给大家谈谈飞机托运收费,以及飞机托运收费多少钱一公斤对应的知识点,希望对各位有所帮助...
挡泥板(挡泥板是什么意思) 挡... 本篇文章极速百科给大家谈谈挡泥板,以及挡泥板是什么意思对应的知识点,希望对各位有所帮助,不要忘了收藏...
滴滴专车官网(滴滴专车司机网站... 今天给各位分享滴滴专车官网的知识,其中也会对滴滴专车司机网站进行解释,如果能碰巧解决你现在面临的问题...
路特斯跑车(路特斯跑车多少钱一... 今天给各位分享路特斯跑车的知识,其中也会对路特斯跑车多少钱一辆2023款进行解释,如果能碰巧解决你现...
丰田致享新车报价(丰田致享20... 今天给各位分享丰田致享新车报价的知识,其中也会对丰田致享2021款报价进行解释,如果能碰巧解决你现在...
聊城到潍坊的汽车(聊城到潍坊的... 本篇文章极速百科给大家谈谈聊城到潍坊的汽车,以及聊城到潍坊的汽车票价多少对应的知识点,希望对各位有所...
没有身份证怎么买票(没有身份证... 今天给各位分享没有身份证怎么买票的知识,其中也会对没有身份证怎么买票进行解释,如果能碰巧解决你现在面...
2018科目三灯光详细表(20... 本篇文章极速百科给大家谈谈2018科目三灯光详细表,以及2018科目三最新模拟灯光考试20组不重复完...
五菱之光v(五菱之光v和五菱之... 今天给各位分享五菱之光v的知识,其中也会对五菱之光v和五菱之光有什么区别进行解释,如果能碰巧解决你现...
摩托车怠速(摩托车怠速多少转正... 今天给各位分享摩托车怠速的知识,其中也会对摩托车怠速多少转正常进行解释,如果能碰巧解决你现在面临的问...
武汉到西安(武汉到西安火车时刻... 今天给各位分享武汉到西安的知识,其中也会对武汉到西安火车时刻表查询进行解释,如果能碰巧解决你现在面临...
五菱之光v图片(五菱之光v新车... 今天给各位分享五菱之光v图片的知识,其中也会对五菱之光v新车报价进行解释,如果能碰巧解决你现在面临的...
郑州到重庆火车(郑州到重庆火车... 本篇文章极速百科给大家谈谈郑州到重庆火车,以及郑州到重庆火车多少钱一张对应的知识点,希望对各位有所帮...
学生证优惠区间(学生证优惠区间... 今天给各位分享学生证优惠区间的知识,其中也会对学生证优惠区间没有盖章进行解释,如果能碰巧解决你现在面...
武汉到合肥(武汉到合肥多少公里... 今天给各位分享武汉到合肥的知识,其中也会对武汉到合肥多少公里进行解释,如果能碰巧解决你现在面临的问题...
软座座位分布图(k8412软座... 本篇文章极速百科给大家谈谈软座座位分布图,以及k8412软座座位分布图对应的知识点,希望对各位有所帮...
长安逸动dt(长安逸动dt空调... 本篇文章极速百科给大家谈谈长安逸动dt,以及长安逸动dt空调滤芯拆卸教程对应的知识点,希望对各位有所...
西安到达州(西安到达州火车时刻... 本篇文章极速百科给大家谈谈西安到达州,以及西安到达州火车时刻表查询对应的知识点,希望对各位有所帮助,...
野马蝰蛇(野马蝰蛇gt500图... 本篇文章极速百科给大家谈谈野马蝰蛇,以及野马蝰蛇gt500图片对应的知识点,希望对各位有所帮助,不要...
高速obu是什么意思(收费站o... 今天给各位分享高速obu是什么意思的知识,其中也会对收费站obu是什么意思进行解释,如果能碰巧解决你...
西安北站在哪(西安北站在哪进站... 今天给各位分享西安北站在哪的知识,其中也会对西安北站在哪进站进行解释,如果能碰巧解决你现在面临的问题...
汽车搭电一次多少钱(汽车搭电大... 今天给各位分享汽车搭电一次多少钱的知识,其中也会对汽车搭电大概多少钱进行解释,如果能碰巧解决你现在面...
宝马跑车敞篷价格(宝马跑车敞篷... 本篇文章极速百科给大家谈谈宝马跑车敞篷价格,以及宝马跑车敞篷价格图片对应的知识点,希望对各位有所帮助...
cbr650r(cbr650r... 本篇文章极速百科给大家谈谈cbr650r,以及cbr650r座高对应的知识点,希望对各位有所帮助,不...
在哪买机票最便宜(在哪买机票最... 今天给各位分享在哪买机票最便宜的知识,其中也会对在哪买机票最便宜票进行解释,如果能碰巧解决你现在面临...
etc办理点(etc办理点节假... 今天给各位分享etc办理点的知识,其中也会对etc办理点节假日休息吗进行解释,如果能碰巧解决你现在面...
宝马1181报价及图片(宝马1... 今天给各位分享宝马1181报价及图片的知识,其中也会对宝马1181报价及图片及价格进行解释,如果能碰...
限行处罚扣分吗(限行被扣分吗)... 本篇文章极速百科给大家谈谈限行处罚扣分吗,以及限行被扣分吗对应的知识点,希望对各位有所帮助,不要忘了...
车车(车车车念什么) 车车 车... 今天给各位分享车车的知识,其中也会对车车车念什么进行解释,如果能碰巧解决你现在面临的问题,别忘了关注...