DETR网络模型构建
创始人
2025-05-31 09:35:42

在这里插入图片描述

这篇文章主要为记录DETR模型的构建过程
首先明确DETR模型的搭建顺序:首先是backbone的搭建,使用的是resnet50,随后是Transformer模型的构建,包含编码器的构建与解码器的构建,完成后则是整个DETR模型的构建
构建代码在detr.py文件中

	# 搭建主干网络backbone = build_backbone(args)# 搭建transfoemertransformer = build_transformer(args)# 搭建DETR模型model = DETR(backbone,transformer,num_classes=num_classes,num_queries=args.num_queries,aux_loss=args.aux_loss,)

我们来沿着这个代码逐步了解其构造

骨干网络构建

进入 backbone.py 的 build_backbone() 方法

def build_backbone(args):#搭建位置编码器position_embedding = build_position_encoding(args)#args.lr_backbone默认为1e-5,则train_backbone默认为true,通过设置backbone的lr来设置是否训练网络时# 接收backbone的梯度从而让backbone也训练。train_backbone = args.lr_backbone > 0#是否需要记录backbone的每层输出return_interm_layers = args.masks#构建骨干网络 args.backbone默认为resnet50,args.dilatyion默认为false。backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)# 将backbone和位置编码器集合在一起放在一个model里model = Joiner(backbone, position_embedding)# 设置model的输出通道数model.num_channels = backbone.num_channelsreturn model

位置编码器构建

首先是位置编码器构造,进入position_encoding.py文件(注意位置编码最后是加上去的,故其前后维度不会发生变化。

def build_position_encoding(args):#args.hidden_dim    transformer的输入张量的channel数,位置编码和backbone的featuremap结合后需要输入到transformer中N_steps = args.hidden_dim // 2# 余弦编码方式,文章说采用正余弦函数,是根据归纳偏置和经验做出的选择if args.position_embedding in ('v2', 'sine'):# TODO find a better way of exposing other arguments#PositionEmbeddingSine(N_steps, normalize=True):正余弦编码方式,这种方式是将各个位置的各个维度映射到角度上,# 因此有个scale,默认是2pi。position_embedding = PositionEmbeddingSine(N_steps, normalize=True)# 可学习的编码方式elif args.position_embedding in ('v3', 'learned'):position_embedding = PositionEmbeddingLearned(N_steps)else:raise ValueError(f"not supported {args.position_embedding}")return position_embedding

返回的位置编码信息如下:

在这里插入图片描述

resnet50网络构建

随后使用pytorch的model库进行resnet50网络的构建

def __init__(self, name: str,train_backbone: bool,return_interm_layers: bool,dilation: bool):# getattr(obj,name)获取obj中命名为name的组成。可以理解为获取obj.name# 获取torchvision.models中实现的resnet50网络结构# replace_stride_with_dilation 决定是否使用膨胀卷积;pretrained 是否使用预训练模型;# norm_layer 使用FrozenBatchNorm2d归一化方式backbone = getattr(torchvision.models, name)(replace_stride_with_dilation=[False, False, dilation],pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)#获得resnet50网络结构,并设置输出channels为2048,所以我们的backbone的输出则是# (batch, 2048, H // 32,W // 32),在父类BackboneBase.__init__中进行初始化。num_channels = 512 if name in ('resnet18', 'resnet34') else 2048super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

此时通过pytorch的model库获得的backbone为:其中已进行了相关参数的初始化

在这里插入图片描述
在最后其进入其父类进行进一步设置,逐层设置是否进行梯度更新,其构建的resnet模型中已是预训练模型,有参数。
通过下图可以看到backbone的主要结构,通道数

在这里插入图片描述
在这里插入图片描述

随后通过model = Joiner(backbone, position_embedding)将位置编码器与backbone构建到一个model中。并通过model.num_channels = backbone.num_channels设置新构成的model的通道数与backbone的通道数一致为2048。
Joiner构造后的结构如下:
在这里插入图片描述

Joiner((0): Backbone((layer1): Sequential( )(layer2): Sequential( )(layer3): Sequential( )(layer4): Sequential( ))(1): PositionEmbeddingSine()
)

详细结构如下:

Joiner((0): Backbone((body): IntermediateLayerGetter((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): Bottleneck((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))(layer2): Sequential((0): Bottleneck((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))(layer3): Sequential((0): Bottleneck((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(4): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(5): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))(layer4): Sequential((0): Bottleneck((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))))(1): PositionEmbeddingSine()
)

最终完成backbone的构建,随后则跳转回detr.py中继续transformer模型的构建

Transformer模型构建

跳转到 transformer.py 中执行 build_transformer() 方法

def build_transformer(args):return Transformer(d_model=args.hidden_dim,dropout=args.dropout,nhead=args.nheads,dim_feedforward=args.dim_feedforward,num_encoder_layers=args.enc_layers,num_decoder_layers=args.dec_layers,normalize_before=args.pre_norm,return_intermediate_dec=True,)

其对应参数值为:

在这里插入图片描述
在Transformer模型的构建中,分别包含编码器层与解码器层的构建,而由上可知分别设计6层编码器与6层解码器,其结构都是相同的,只需要构造成一个随后再复制即可。

![

Transformer结构图

在这里插入图片描述

  1. 在第一层,transformer = embedding + 位置编码(Positional Encoding) + encoder +
    decoder ;
  2. 在第二层,encoder = 多个EncoderLayer = 多个(Multi-Head-Attention + LayerNorm
    Residual连接 + FeedForwardNet);decoder = 多个DecoderLayer = 多个(Masked Multi-Head-Attention + encoder-decoder Multi-Head-Attention + LayerNorm + Residual连接 + FeedForwardNet)。其中,LayerNormalization和BatchNorm不同,BatchNorm是在一个batch里所有sample的同一个维度上计算mean std(均差与方差),而LayerNorm不考虑batch,是在一个sample的不同维度上计算mean std;
  3. 在第三层,Multi-Head-Attention = linear + scaled dot-product attention(单头注意力) + concat + linear;
  4. 在第四层, scaled dot-product attention = MatMul + Scale + Mask + SoftMax+MatMul。

编码器层的构造

encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)

在这里插入图片描述

多头注意力构建

编码器中最重要的便是多头注意力的构建了,使用的是pytorch中封装好的方法构建。

self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

参数介绍

# 先决定参数
dims = 256 * 10 # 所有头总共需要的输入维度
heads = 10    # 单注意力头的总共个数
dropout_pro = 0.0 # 单注意力头  dropout_p: dropout的概率,当其为非零时执行dropout# 传入参数得到我们需要的多注意力头
layer = torch.nn.MultiheadAttention(embed_dim = dims, num_heads = heads, dropout = dropout_pro)

在这里,输入维度为256,可以理解为想要获得的特征维度为256,在内部计算时,由于有8个头,则会将其均分为256/8=32,随后在输出是会将其再拼接回256。

多头注意力已经在pytorch中封装好了,我们直接调用即可:

def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:factory_kwargs = {'device': device, 'dtype': dtype}super(MultiheadAttention, self).__init__()self.embed_dim = embed_dimself.kdim = kdim if kdim is not None else embed_dimself.vdim = vdim if vdim is not None else embed_dimself._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dimself.num_heads = num_headsself.dropout = dropoutself.batch_first = batch_firstself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"if not self._qkv_same_embed_dim:self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))self.register_parameter('in_proj_weight', None)else:self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))self.register_parameter('q_proj_weight', None)self.register_parameter('k_proj_weight', None)self.register_parameter('v_proj_weight', None)if bias:self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))else:self.register_parameter('in_proj_bias', None)self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)if add_bias_kv:self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))else:self.bias_k = self.bias_v = Noneself.add_zero_attn = add_zero_attnself._reset_parameters()

Q,K,V 等参数也是在此构造的,其维度也都为256。(注意按照源码来看,QKV可以存在维度不同的状态)

torch.empty是按照所给的形状形成对应的tensor,特点是填充的值还未初始化,类比torch.randn(标准正态分布),这就是一种初始化的方式。在PyTorch中,变量类型是tensor的话是无法修改值的,而Parameter()函数可以看作为一种类型转变函数,将不可改值的tensor转换为可训练可修改的模型参数,即与model.parameters绑定在一起,register_parameter的意思是是否将这个参数放到model.parameters,None的意思是没有这个参数。

随后进行权重值初始化,分别为Weights (2563,256),Bias为2563
这里实际上Weights是由Wq,Wk,Wv按序排列组合的,在进行运算时再分开。

多头注意力介绍

关于多头注意力机制,就是有多个单头注意力组成的,如下图单头注意力。

在这里插入图片描述

在这里插入图片描述

整体称为一个单注意力头,因为运算结束后只对每个输入产生一个输出结果,一般在网络中,输出可以被称为网络提取的特征,那我们肯定希望提取多种特征,[ 比如说我输入是一个修狗图片的向量序列,我肯定希望网络提取到特征有形状、颜色、纹理等等,所以单头注意肯定是不够的 ]
于是最简单的思路,最优雅的方式就是将多个头横向拼接在一起,每次运算我同时提到多个特征,所以多头的样子如下:

在这里插入图片描述
因为是拼接而成的,所以每个单注意力头其实是各自输出各自的,所以会得到h个特征,把h个特征拼接起来,就成为了多注意力的输出特征。
那么想要在训练的时候使用,我们就需要给它喂入数据,也就是调用forward函数,完成前向传播这一动作。

编码器其他组件构建

在这里插入图片描述

Positional Encoding

位置编码,按照字面意思理解就是给输入的位置做个标记,简单理解比如你就给一个字在句子中的位置编码1,2,3,4这样下去,高级点的比如作者用的正余弦函数
在这里插入图片描述
其中pos表示字在句子中的位置,i指的词向量的维度。经过位置编码,相当于能够得到一个和输入维度完全一致的编码数组 Xpos ,当它叠加到原来的词嵌入上得到新的词嵌入。
在这里插入图片描述
此时的维度为:一个批次的句子数 X 一个句子的词数 X 一个词的嵌入维度

Add & Norm

这里主要做了两个操作

一个是残差连接(或者叫做短路连接),说得直白点就是把上一层的输入 X 和上一层的输出 SubLayer(X) 加起来 ,即 X+SubLayer(X)
,举例说明,比如在注意力机制前后的残差连接:
在这里插入图片描述
一个是LayerNormalization(作用是把神经网络中隐藏层归一为标准正态分布,加速收敛),具体操作是将每一行的每一个元素减去这行的均值, 再除以这行的标准差, 从而得到归一化后的数值。

Feedforward + Add & Norm

前馈网络也就是简单的两层线性映射再经过激活函数一下,比如
在这里插入图片描述
残差操作和层归一化同步骤

此时便构建出 Transformer中的一个encoder模块,经过1,2,3,4后得到的就是encode后的隐藏层表示,可以发现它的维度其实和输入是一致的!即:一个批次中句子数 X 一个句子的字数 X 字嵌入的维度。
构建的单层编码器层整体代码如下:

在这里插入图片描述

构建后的单层编码器层相关参数如下:

在这里插入图片描述
最终构建多层编码器:

class TransformerEncoder(nn.Module):def __init__(self, encoder_layer, num_layers, norm=None):super().__init__()self.layers = _get_clones(encoder_layer, num_layers)#克隆多个encoder_layerself.num_layers = num_layersself.norm = normdef forward(self, src,mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):output = srcfor layer in self.layers:output = layer(output, src_mask=mask,src_key_padding_mask=src_key_padding_mask, pos=pos)if self.norm is not None:output = self.norm(output)return output

克隆多个层代码

def _get_clones(module, N):return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

解码器层的构造

与编码器不同,解码器中的最开始时需要使用多头注意力进行计算self-attention。

def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False):super().__init__()self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)# Implementation of Feedforward modelself.linear1 = nn.Linear(d_model, dim_feedforward)self.dropout = nn.Dropout(dropout)self.linear2 = nn.Linear(dim_feedforward, d_model)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)self.dropout3 = nn.Dropout(dropout)self.activation = _get_activation_fn(activation)self.normalize_before = normalize_before

将上面初始化的组件连接起来:

def forward_post(self, tgt, memory,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):q = k = self.with_pos_embed(tgt, query_pos)#加入位置编码tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0]#自注意力计算tgt = tgt + self.dropout1(tgt2)#残差连接tgt = self.norm1(tgt)#层归一化tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]tgt = tgt + self.dropout2(tgt2)#残次连接tgt = self.norm2(tgt)#层归一化tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))#feed and forwardtgt = tgt + self.dropout3(tgt2)#残次连接tgt = self.norm3(tgt)#normal归一化return tgt

关于其多头注意力的构建与encoder是相同的,值得注意的是,detr的解码层是没有mask的。此外query的数量也是固定的。
最终,进行Transformer模型的组建:

def forward(self, src, mask, query_embed, pos_embed):# src: transformer输入,mask:图像掩码, query_embed:decoder预测输入embed, pos_embed:位置编码# flatten NxCxHxW to HWxNxCbs, c, h, w = src.shape# 获取encoder输入src = src.flatten(2).permute(2, 0, 1)# 获取位置编码pos_embed = pos_embed.flatten(2).permute(2, 0, 1)query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)# 获取输入掩码mask = mask.flatten(1)# torch.zeros_like:生成和括号内变量维度维度一致的全是零的内容。# tgt初始化,意义为初始化需要预测的目标。因为一开始不清楚需要什么样的目标,所以初始化为0,它会在decoder中# 不断被refine,但真正在学习的是query embedding,学习到的是整个数据集中目标物体的统计特征。而tgt在每一个epoch都会初始化。# tgt 也可以理解为上一层解码器的解码输出 shape=(100,N,256) 第一层的tgt=torch.zeros_like(query_embed) 为零矩阵,# query_pos 是可学习输出位置向量, 个人理解 解码器中的这个参数 全局共享 提供全局注意力 query_pos=(100,N,256)tgt = torch.zeros_like(query_embed)# 获取encoder输出memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)# 获取decoder输出,return_intermediate_dec为true时,得到decoder每一层的输出hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,pos=pos_embed, query_pos=query_embed)return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

完成整体Transformer的构建后,我们将backbone与Transformer一起构建为DETR模型。

DETR模型构建

再次跳转到detr.py文件,开始DETR模型的构建

model = DETR(backbone,transformer,num_classes=num_classes,num_queries=args.num_queries,aux_loss=args.aux_loss,)

DETR组合

关于DETR模型初始化组件代码:

super().__init__()self.num_queries = num_queriesself.transformer = transformerhidden_dim = transformer.d_model  # transformer输出channel# decoder后再接一个全连接,输出分类结果self.class_embed = nn.Linear(hidden_dim, num_classes + 1)# 利用MLP对框进行回归self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)# decoder预测输入,每帧预测num_queries个目标self.query_embed = nn.Embedding(num_queries, hidden_dim)# transformer输入前处理,backbone得到的是num_channels(2048)维度的输出,需要1*1的卷积降维到hidden_dimself.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)self.backbone = backboneself.aux_loss = aux_loss

再DETR中,其进行了其他组件的初始化,包含分类头(使用Linear实现,输出为7个维度num_class+1) 以及通过MLP实现回归给出预测框。输出4个值分别为x,y,w,h

在这里插入图片描述

至此,模型构建便完成了。

相关内容

热门资讯

labview数据类型转换字符... wx供重浩:创享日记 对话框发送:labview转换 获取完整无水印报告...
【分享】国内如何使用chatG... 上周,OpenAI宣布正式发布多模态预训练大模型GPT-4,其强大的能力...
软件智能:aaas系统中AI的... 概要(内容概述) <同一>将设计目标确定为“软件智能”的aaas中,AI的任务和AI能...
早日康复祝福语简短8字,早日康... 早日康复祝福语简短8字目录早日康复祝福语简短8字早日康复祝福语简短8字搞定手术祝福语8个字早日康复祝...
淘宝店铺名怎么改,淘宝店铺怎么... 淘宝店铺名怎么改目录淘宝店铺名怎么改淘宝店铺怎么改名淘宝店铺名可以修改吗,怎样修改怎么修改淘宝店铺名...
微信助手怎么查单删 极速百科网... 微信助手怎么查单删目录微信助手怎么查单删微信助手怎么查单删微信如何查单删 2016微信如何知道对方有...
陆家嘴都有什么旅游景点 极速百... 陆家嘴都有什么旅游景点目录陆家嘴都有什么旅游景点陆家嘴都有什么旅游景点陆家嘴有哪些旅游景点上海陆家嘴...
1229 - 拦截导弹的系统数... 1229 - 拦截导弹的系统数量求解 题目描述 某国为了防御敌国的导弹袭击,发展出一种...
如何做好项目缺陷管理 缺陷管理是项目管理工作中的重要环节。Excel表格是国内团队常用的缺陷管理工具,具备上...
Python生成器 1.生成器 生成器是一种特殊的迭代器,它是通过函数来实现的。生成器函数每次执行到yi...
Nginx可视化管理工具 - ... 一、介绍 nginx-proxy-manager 是一个反向代理管理系统,它基于Nginx,具有漂亮...
如何解除迅雷安全模式,迅雷怎样... 如何解除迅雷安全模式目录如何解除迅雷安全模式迅雷怎样解除安全模式迅雷VIP尊享版怎么解除安全模式?迅...
感谢朋友圈留言句子,适合发朋友... 感谢朋友圈留言句子目录感谢朋友圈留言句子适合发朋友圈表达感谢的句子20句发朋友圈的感谢短语有哪些?有...
关于韩娱的小说有没有什么好看的... 关于韩娱的小说有没有什么好看的目录关于韩娱的小说有没有什么好看的求好看的韩娱小说有没有好看的韩娱小说...
什么是表面粗糙度(什么是表面粗... 本篇文章极速百科给大家谈谈什么是表面粗糙度,以及什么是表面粗糙度?它对零件的使用性能有什么影响?对应...
永远用英语怎么说,“永远”除了... 永远用英语怎么说目录永远用英语怎么说“永远”除了“forever”的英文翻译~~还有哪些
少年音怎么练,怎么配出清爽的少... 少年音怎么练目录少年音怎么练怎么配出清爽的少年音?怎么学正太音少年音,像是龙马啊、镜音连啊不二啊那种...
情侣之间的爱称有哪些,情侣称呼... 情侣之间的爱称有哪些目录情侣之间的爱称有哪些情侣称呼有创意的爱称情侣之间好听的称呼都有什么?情侣爱称...
共享汽车怎么租车 极速百科网 ... 共享汽车怎么租车目录共享汽车怎么租车共享汽车怎么租车gofun出行有人开吗?使用方法是什么?共享汽车...
Python应用之爬虫基础:r... 引言 在生活中,大家都使用过浏览器,通过输入要搜索的内容以及鼠标点击等操...
jsp医疗辅助诊断管理系统se... 一、源码特点      JSP医疗辅助诊断管理系统是一套完善的java web信息管理系统ÿ...
db19密钥库和加密 创建密钥库ENCRYPTION_WALLET_LOCATION =(SOURCE =...
开局之年是什么意思(开局之年之... 本篇文章极速百科给大家谈谈开局之年是什么意思,以及开局之年之后是什么年对应的知识点,希望对各位有所帮...
抖音gga什么意思(抖音gg是... 本篇文章极速百科给大家谈谈抖音gga什么意思,以及抖音gg是什么意思对应的知识点,希望对各位有所帮助...
DMZ是什么(防火墙的dmz是... 今天给各位分享DMZ是什么的知识,其中也会对防火墙的dmz是什么进行解释,如果能碰巧解决你现在面临的...
风行SX6Sx6后视镜加热打不... 本篇文章极速百科给大家谈谈风行SX6Sx6后视镜加热打不开,以及东风风行sx6反光镜多少钱对应的知识...
CKA-17 Check Da... 文章目录Issue summary:Useful comment:1. 创建场景1.1...
elasticsearch的入... 目录一.数据聚合1.聚合的种类2.DSL实现聚合2.1.Bucket聚合语法2.2.聚合结果排序2....
成都男子误入停车场51秒收费8... 本篇文章极速百科给大家谈谈成都男子误入停车场51秒收费8元,属于乱收费吗,以及成都停车费贵对应的知识...
城市的路灯系统是如何控制开灯和... 本篇文章极速百科给大家谈谈城市的路灯系统是如何控制开灯和熄灯时间的?,以及路灯咋调制,路灯的时控开关...