【pytorch源码剖析系列】模型搭建
创始人
2025-05-31 00:14:57
前言:本次内容不只是让你学会如何快速搭建网络骨架,还要深入理解pytorch的相关组件,让你能够实现自定义网络算子并将他们组合在一起。

思考:为什么不用numpy自定义模型并训练呢?
没有办法搭建复杂的网络结构
梯度需要自己计算,网络复杂时,梯度计算并进行反向传播,参数更新是非常复杂的
计算没有GPU的加持,速度会慢

问题:为什么要在每次反向传播后,对梯度进行清零

根据pytorch中的backward()函数的计算,当网络参量进行反馈时,梯度是被积累的而不是被替换掉;但是在每一个batch时毫无疑问并不需要将两个batch的梯度混合起来累积,因此这里就需要每个batch设置一遍zero_grad 了。

其实这里还可以补充的一点是,如果不是每一个batch就清除掉原有的梯度,而是比如说两个batch再清除掉梯度,这是一种变相提高batch_size的方法,对于计算机硬件不行,但是batch_size可能需要设高的领域比较适合,比如目标检测模型的训练。

pytorch模型构建的四种方法

基本的模型网络结构定义框架

首先明确的一点是,自定义的网络结构一定是一个类,并且继承于nn.Module。
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()passdef forward(self, x):pass
  • __init__(self):这个初始化函数,是所有类中都会出现的一个初始化方法,在所有类实例化时最先执行的一个方法,一般会定义一些变量和实例化一些对象,比如卷积,池化等算子,方便后续调用。

  • forward(self, x):这个函数执行的是输入一个张量x,然后自定义顺序调用__init__方法中自定义的实例化对象,得到输出,为什么实例化后模型类后,在对象后加上括号,并输入参数x后就可以执行forward里的内容呢?原因是我们在类里继承了nn.module类,nn.module类里包含__call__方法,在之前的文章里提到过python的这个magic method,__call__()方法的作用其实是把一个类的实例化对象变成了可调用对象,也就是说把一个类的实例化对象变成了可调用对象,只要类里实现了__call__()方法就行。

注意:super(Net, self).__init__()表示初始化父类的初始化方法,此动作发生在实例化自定义类的时候。
model = Net() #实例化模型对象
ouput = model(x) #使得实例化对象变成了可调用对象,主要是因为自定义类中继承的nn.module类中实现了__call__方法,
#在该方法中又调用了这个forward()方法。

代码展示(nn.module中__call__方法调用了self.forward()方法)

  • 第一种方法:朴素

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net1(torch.nn.Module):def __init__(self):super(Net1, self).__init__()#Conv2d也是一个类,并继承于nn.module(类中实现了__call__方法,该方法中调用了forward()方法),#所以不必惊讶为何在forward()中可以直接以实例化对象名+()的形式调用self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)#同上self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)self.dense2 = torch.nn.Linear(128, 10)def forward(self, x):#(2,3,7,7) --> (2,32,3,3)x = F.max_pool2d(F.relu(self.conv1(x)), 2)#(2,32,3,3) --> (2, 32*3*3)x = x.view(x.size(0), -1)#(2, 32*3*3) --> (2,128)x = F.relu(self.dense1(x))#(2, 128) --> (2,10)x = self.dense2(x)return xprint("Method 1:")
model1 = Net1()
print(model1)
#在pytorch中张量的定义顺序为(batch_szie, channel, height, width)
x = torch.rand((2,3,7,7))
output = model1(x)
print(output.size())

输出结果:

  • 第二种方法:利用nn.Sequential

首先我们肯定要了解nn.Sequential()是个啥,是个类,同样继承了nn.Module类,自定义的多个算子可以以tuple的形式传入nn.Sequential()中,通过枚举这个tuple,将序号和算子加入到OrderedDict中,OrderedDict是一个有序字典,可以将字典进行排序。

nn.module中的add_module()方法

self._modules是一个有序字典

知识补充:OrderedDict
Python中的字典对象可以以“键:值”的方式存取数据。OrderedDict是它的一个子类,实现了对字典对象中元素的排序。不过现在默认的字典也是有序的,至于为什么pytorch还用着OrderedDict,因为要和旧版本保持耦合。

只使用nn.Sequential

import torchclass Net2(torch.nn.Module):def __init__(self):super(Net2, self).__init__()self.conv = torch.nn.Sequential(torch.nn.Conv2d(3, 32, 3, 1, 1),torch.nn.ReLU(),torch.nn.MaxPool2d(2))self.dense = torch.nn.Sequential(torch.nn.Linear(32 * 3 * 3, 128),torch.nn.ReLU(),torch.nn.Linear(128, 10))def forward(self, x):conv_out = self.conv(x)res = conv_out.view(conv_out.size(0), -1)out = self.dense(res)return outprint("Method 2:")
model1 = Net2()
print(model1)
x = torch.rand((2, 3, 7, 7))
output = model1(x)
print(output.size())

结果展示

使用nn.Sequential和OrderedDict结合的方式

优点:可以方便的为每一层算子命名,而不只是简单的序号

import torch
from collections import OrderedDictclass Net2(torch.nn.Module):def __init__(self):super(Net2, self).__init__()self.conv = torch.nn.Sequential(OrderedDict([('conv1',torch.nn.Conv2d(3, 32, 3, 1, 1)),('relu1',torch.nn.ReLU()),('pooling1', torch.nn.MaxPool2d(2))]))self.dense = torch.nn.Sequential(OrderedDict([('linear1',torch.nn.Linear(32 * 3 * 3, 128)),('relu1',torch.nn.ReLU()),('linear2',torch.nn.Linear(128, 10))]))def forward(self, x):conv_out = self.conv(x)res = conv_out.view(conv_out.size(0), -1)out = self.dense(res)return outmodel = Net2()
print(model)
x = torch.rand((2,3,7,7))
output = model(x)
print(output.size())

结果展示

  • 第三种方法:使用ModuleList

class net2(nn.Module):def __init__(self):super(net2, self).__init__()self.modlist = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])#这里若按照这种写法则会报NotImplementedError错#def forward(self, x):#    return self.modlist(x)#注意:只能按照下面利用for循环的方式def forward(self, x):for m in self.modlist:x = m(x)return xinput = torch.randn(16, 1, 20, 20)
net2 = net2()
print(net2(input).shape)
#torch.Size([16, 64, 12, 12])
  • 第四种方法:使用ModuleDict

ModuleDict和ModuleList的作用类似,只是ModuleDict能够更方便地为神经网络的层添加名称。

et = nn.ModuleDict({'linear': nn.Linear(784, 256),'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)#结果
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict((act): ReLU()(linear): Linear(in_features=784, out_features=256, bias=True)(output): Linear(in_features=256, out_features=10, bias=True)
)

几种方法的适用场景

  1. Sequential适用于快速验证结果,因为已经明确了要用哪些层,直接写一下就好了,不需要同时写__init__和forward;

  1. ModuleList和ModuleDict在某个完全相同的层需要重复出现多次时,非常方便实现,可以”一行顶多行“;当我们需要之前层的信息的时候,比如 ResNets 中的残差计算,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList/ModuleDict 比较方便。

nn.Sequential与nn.ModuleList的区别

  1. nn.Sequential内部实现了forward函数,因此可以不用写forward函数。而nn.ModuleList则没有实现内部forward函数。一般情况下 nn.Sequential 的用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化。

  1. nn.Sequential可以使用OrderedDict对每层进行命名。

  1. nn.Sequential里面的模块按照顺序进行排列的,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。而nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。网络的执行顺序是根据 forward 函数来决定的。若将forward函数中几行代码互换,使输入输出之间的大小不一致,则程序会报错。

class net3(nn.Module):def __init__(self):super(net3, self).__init__()self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])def forward(self, x):x = self.linears[2](x)x = self.linears[0](x)x = self.linears[1](x)return xnet3 = net3()
print(net3)
#net3(
#  (linears): ModuleList(
#    (0): Linear(in_features=10, out_features=20, bias=True)
#    (1): Linear(in_features=20, out_features=30, bias=True)
#    (2): Linear(in_features=5, out_features=10, bias=True)
#  )
#)input = torch.randn(32, 5)
print(net3(input).shape)
#torch.Size([32, 30])
  1. 有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们。

pytorch构建模型的四种方式

参考链接:
pytorch模型搭建的四种方式
为什么每次反向传播后都要对梯度清零
pytorch模型定义的方式
PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景
python OrderedDict用法

相关内容

热门资讯

04-centos7的安装和系... 按tab 开始调整网卡名称,从net.if....开始 之后按enter进入安装 ...
男子撞豪车转头炫耀(撞豪车案例... 本篇文章极速百科给大家谈谈男子撞豪车转头炫耀,以及撞豪车案例对应的知识点,希望对各位有所帮助,不要忘...
跌停什么意思(经常集合竞价跌停... 今天给各位分享跌停什么意思的知识,其中也会对经常集合竞价跌停什么意思进行解释,如果能碰巧解决你现在面...
宝马3系一年养车费用大概是多少... 本篇文章极速百科给大家谈谈宝马3系一年养车费用大概是多少?,以及宝马3系一年的养车费用对应的知识点,...
虚词有哪些虚词包括哪些类别(虚... 本篇文章极速百科给大家谈谈虚词有哪些虚词包括哪些类别,以及虚词有哪些虚词包括哪些类别的词语对应的知识...
随机过程 Poisson 过程 文章目录随机过程 Poisson 过程基本概念与 Poisson 过程相联系的若干分布XnX_nXn...
【C++初阶】六、模板初阶(函... 文章目录泛型编程函数模板函数模板的概念函数模板的格式函数模板的原理函数模板的实例化模板参数的匹配原则...
每个开发人员都需要掌握的10 ... SQL 是一种非常常见但功能强大的工具,它可以帮助从任何数据库中提取、转换和加载数据。...
每年的918都有哪些地方会拉响... 本篇文章极速百科给大家谈谈每年的918都有哪些地方会拉响防空警报?,以及918那些地方有拉防空警报对...
慢直播是什么意思(慢直播怎么赚... 本篇文章极速百科给大家谈谈慢直播是什么意思,以及慢直播怎么赚钱对应的知识点,希望对各位有所帮助,不要...
09款福特嘉年华自动挡1.5的... 本篇文章极速百科给大家谈谈09款福特嘉年华自动挡1.5的换一套机脚垫要多少钱啊,以及对应的知识点,希...
什么是有机奶(什么是有机奶和纯... 本篇文章极速百科给大家谈谈什么是有机奶,以及什么是有机奶和纯牛奶的区别对应的知识点,希望对各位有所帮...
[图神经网络]图卷积神经网络-... 一、消息传递         由于图具有“变换不变性”(即图的空间结构改变不会影响图的性状)...
Learning C++ No... 引言: 北京时间:2023/3/18/21:47,周末&#...
数据结构第一二章笔记 仅仅是自己学习记的一些笔记。1.一些零碎的知识时间复杂度:找出哪一条语句执行的次数最多...
cmake-下载和安装 1.下载和安装 cmake:https://cmake.org/download/ (...
JDBC教程下篇 二、SQL注入 2.1 什么是SQL注入 用户输入的数据中有SQL关键词,导致在执行SQL语句时出...
钛动科技斩获 2022 Tik... 近日,第三方平台FastData研究院正式发布了行业报告《2022年度TikTok生态发展白皮书》,...
头部险企如何打造低代码数据集市... 保险业的金融科技建设正在按下快进键,从最新发布的“2022 保险科技创新指数报告”来看...
刷题之-剑指 Offer II... 最近很久没刷题了,面试官给了这么一道题,只给10分钟时间,...
Nginx代理后获取客户端真实... 1、场景 在项目实际应用中,我们可能会需要获取到用户也就是客户端的真实IP地址...
第十三届蓝桥杯省赛 pytho... 文章目录前言主要内容🦞试题 A:排列字母思路代码🦞试题...
阿里春招-2023.3.15-... 极差三元组计数 Problem Description 给定一个数组,请你计算有多少个...
电压放大器在钢筋剥离损伤识别试...   实验名称:钢筋剥离损伤识别试验  研究方向:无损检测  测试目的&#...
MOCO论文前几段精读 MoCo MoCo是CVPR 2020的最佳论文提名,算是视觉领域里,使...
【lua初级篇】基础知识和开发... 文章介绍 文章介绍 简述 工具安装配置和下载 快速看基础知识 一些常用的关键字一览 数据类型 tab...
Yuv422、Nv12转C#B... 1.1、Nv12转Bitmapint w = 1920;int h = 1080;i...
Linux互斥量和信号量的区别... 互斥量和信号量的区别 1.互斥量用于线程的互斥: 互斥:加锁解锁,是指某...
Git 和 GitHub 超入... 1.解决行结束符问题 需要在你的仓库中添加一个.gitattributes文件,标记正...
基于C++的AI五子棋游戏项目... 项目资源下载 基于C++的AI五子棋游戏项目源码压缩包下载地址基于C+...