網(wǎng)上有很多關(guān)于pos機p1,PyTorch中實現(xiàn)MultiHead和CBAM的知識,也有很多人為大家解答關(guān)于pos機p1的問題,今天pos機之家(www.afbey.com)為大家整理了關(guān)于這方面的知識,讓我們一起來看下吧!
本文目錄一覽:
1、pos機p1
pos機p1
自從Transformer在“注意力就是你所需要的”的工作中被引入以來,在自然語言處理領(lǐng)域已經(jīng)發(fā)生了一個轉(zhuǎn)變,即用基于注意力的網(wǎng)絡(luò)取代循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)。在當前的文獻中,已經(jīng)有很多很棒的文章描述了這種方法。下面是我在評論中發(fā)現(xiàn)的兩個最好的:帶注釋的Transformer和Transformer的可視化解釋。
然而,在研究了如何在計算機視覺中實現(xiàn)注意力(建議閱讀:Understanding attention Modules, CBAM, Papers with code - attention, Self-Attention, Self-Attention and Conv),我注意到其中只有少數(shù)清楚地描述了注意力機制,包括詳細代碼和理論解釋。因此,本文的目標是詳細描述計算機視覺中兩個最重要的注意力模塊,并將它們應(yīng)用到使用PyTorch的實際案例中。文章結(jié)構(gòu)如下:
注意力模塊介紹計算機視覺中的注意方法基于注意的網(wǎng)絡(luò)的實現(xiàn)和結(jié)果結(jié)論注意力模塊介紹在機器學習中,注意力是一種模仿認知注意力的技術(shù),被定義為選擇并專注于相關(guān)刺激的能力。換句話說,注意力是一種試圖增強重要部分同時淡出不相關(guān)信息的方法。
盡管這種機制可以分為幾個系列,但是我們這里專注于自注意力,因為它是計算機視覺任務(wù)中最受歡迎的注意力類型。這是指將單個序列的不同位置關(guān)聯(lián)起來,以計算同一序列的表示。
為了更好地理解這個概念,讓我們想想下面的句子:Bank of a river。如果我們看不到River這個詞,那么Bank這個詞就失去了它的上下文信息,我們同意這一點嗎?這就是自注意力背后的主旨。它試圖給每個單詞提供上下文信息,因為單詞的個別意思不能代表它們在句子中的意思。
正如《An Intuitive Explanation of Self-attention》中所解釋的,如果我們考慮上面給出的例子,自我注意的作用是將句子中的每個詞與其他詞進行比較,并重新衡量每個詞的詞嵌入向量,以包括上下文相關(guān)性。輸出模塊的輸入是沒有上下文信息的每個單詞的嵌入,輸出是類似的有上下文信息的嵌入。
計算機視覺中的注意力方法這里列出了一個不斷更新的注意力模塊列表。從上面列出的,我們關(guān)注兩個最流行的計算機視覺任務(wù):多頭注意力和卷積塊注意模塊(CBAM)。
多頭注意力
多頭注意力是一種注意機制模塊,它可以多次并行運行一個注意力模塊。因此,要理解它的邏輯,首先需要理解Attention模塊。兩個最常用的注意力函數(shù)是加性注意力和點積注意力,后者是這項工作感興趣的一個。
Attention模塊的基本結(jié)構(gòu)是有兩個向量列表x1和x2,一個是attention,另一個是attached。 向量 x2 生成一個“查詢”,而向量 x1 創(chuàng)建一個“鍵”和一個“值”。 注意力函數(shù)背后的想法是將查詢和設(shè)置的鍵值對映射到輸出。 “輸出計算為值的加權(quán)總和,其中分配給每個值的權(quán)重由查詢與相應(yīng)鍵的兼容性函數(shù)計算,在”[Attention is all you need]論文中 輸出計算如下:
正如本次討論中提到的,鍵/值/查詢概念來自檢索系統(tǒng)。 例如,當在 Youtube 上輸入查詢來搜索某個視頻時,搜索引擎會將您的查詢與數(shù)據(jù)庫中與候選視頻鏈接的一組鍵(視頻標題、描述等)進行映射。 然后,它會為您呈現(xiàn)最匹配的視頻(值)。
在轉(zhuǎn)向多頭注意力之前,讓我們運行這個點積注意力,這是這個模塊的擴展。 下面是 PyTorch 中的實現(xiàn)。 輸入是[128, 32, 1, 256],其中128對應(yīng)batch,32對應(yīng)序列長度,1對應(yīng)head的數(shù)量(對于多個attention head我們會增加),256是特征的數(shù)量 .
class ScaledDotProductAttention(nn.Module):''' Scaled Dot-Product Attention '''def __init__(self, temperature, attn_dropout=0.0):super().__init__()self.temperature = temperatureself.dropout = nn.Dropout(attn_dropout)def forward(self, q, k, v, mask=None):attn = torch.matmul(q / self.temperature, k.transpose(2, 3))if mask is not None:attn = attn.masked_fill(mask == 0, -1e9)attn = self.dropout(F.softmax(attn, dim=-1))output = torch.matmul(attn, v)return output, attn# Attentionquery = torch.rand(128, 32, 1, 256)key = value = torch.rand(128, 16, 1, 256)query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)multihead_attn = ScaledDotProductAttention(temperature=query.size(2))attn_output, attn_weights = multihead_attn(query, key, value)attn_output = attn_output.transpose(1, 2)print(f'attn_output: {attn_output.size()}, attn_weights: {attn_weights.size()}')# Self-attentionquery = torch.rand(128, 32, 1, 256)query = query.transpose(1, 2)multihead_attn = ScaledDotProductAttention(temperature=query.size(2))attn_output, attn_weights = multihead_attn(query, query, query)attn_output = attn_output.transpose(1, 2)print(f'attn_output: {attn_output.size()}, attn_weights: {attn_weights.size()}')
輸出是:
attn_output: [128, 32, 1, 256], attn_weights: [128, 1, 32, 32]attn_output: [128, 32, 1, 256], attn_weights: [128, 1, 32, 16]
這個基本實現(xiàn)的一些要點:
輸出將具有與查詢輸入大小相同的形狀。每個數(shù)據(jù)的注意力權(quán)重必須是一個矩陣,其中行數(shù)對應(yīng)于查詢的序列長度,列數(shù)對應(yīng)于鍵的序列長度。Dot-Product Attention 中沒有可學習的參數(shù)。所以,回到多頭注意力,多頭注意力會同事并行運行這個解釋過的注意力模塊幾次。 然后將獨立的注意力輸出連接起來并線性轉(zhuǎn)換為預(yù)期的維度。 這是實現(xiàn):
class MultiHeadAttention(nn.Module):''' Multi-Head Attention module '''def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):super().__init__()self.n_head = n_headself.d_k = d_kself.d_v = d_vself.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)self.fc = nn.Linear(n_head * d_v, d_model, bias=False)self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)self.dropout = nn.Dropout(dropout)self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)def forward(self, q, k, v, mask=None):d_k, d_v, n_head = self.d_k, self.d_v, self.n_headsz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)residual = q# Pass through the pre-attention projection: b x lq x (n*dv)# Separate different heads: b x lq x n x dvq = self.w_qs(q).view(sz_b, len_q, n_head, d_k)k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)# Transpose for attention dot product: b x n x lq x dvq, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)if mask is not None:mask = mask.unsqueeze(1) # For head axis broadcasting.q, attn = self.attention(q, k, v, mask=mask)# Transpose to move the head dimension back: b x lq x n x dv# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)q = self.dropout(self.fc(q))q += residualq = self.layer_norm(q)return q, attnquery = torch.rand(128, 32, 256)multihead_attn = MultiHeadAttention(n_head=8, d_model=256, d_k=32, d_v=32)attn_output, attn_weights = multihead_attn(query, query, query)print(f'attn_output: {attn_output.size()}, attn_weights: {attn_weights.size()}')query = torch.rand(128, 32, 256)multihead_attn = MultiHeadAttention(n_head=8, d_model=256, d_k=256, d_v=512)attn_output, attn_weights = multihead_attn(query, query, query)print(f'attn_output: {attn_output.size()}, attn_weights: {attn_weights.size()}')
輸出是:
attn_output: [128, 32, 256], attn_weights: [128, 8, 32, 32]attn_output: [128, 32, 256], attn_weights: [128, 8, 32, 32]
從代碼中,我們看到:
例如,用于查詢的線性層的輸入是 [128, 32, 256]。 但是,正如本文所述,線性層接受任意形狀的張量,其中只有最后一個維度必須與您在構(gòu)造函數(shù)中指定的 in_features 參數(shù)匹配。 輸出將具有與輸入完全相同的形狀,只有最后一個維度會更改為您在構(gòu)造函數(shù)中指定為 out_features 的任何內(nèi)容。 對于我們的例子,輸入形狀是一組 128 * 32 = 4096 和 256 個特征。 因此,我們將密集網(wǎng)絡(luò)(線性層)應(yīng)用于序列長度的每個元素和批次的每個數(shù)據(jù)。
此外,我們添加了殘差連接和層歸一化,因為它是在 Transformer 神經(jīng)網(wǎng)絡(luò)中實現(xiàn)的。 但是,如果您只想實現(xiàn)多頭注意力模塊,則應(yīng)該排除這些。
那么,此時你可能想知道,為什么我們要實現(xiàn) Multi-Head Attention 而不是一個簡單的 Attention 模塊? 根據(jù)論文Attention is all you need,“多頭注意力允許模型共同關(guān)注來自不同位置的不同表示子空間的信息。 用一個注意力頭,平均值會抑制這一點?!?換句話說,將特征劃分為頭部允許每個注意力模塊只關(guān)注一組特征,從而為每個單詞編碼多個關(guān)系和細微差別提供更大的能力。
在結(jié)束之前,我只想提一下,我們已經(jīng)使用了這個注意力模塊,就好像我們在處理序列一樣,但這篇文章是關(guān)于圖像的。 如果您已經(jīng)理解了所有解釋的內(nèi)容,那么序列和圖像之間的唯一區(qū)別就是輸入向量。 對應(yīng)于序列長度的,對于圖像來說就是像素。 因此,如果輸入是 [batch=128, no_channels=256, height=24, width="360px",height="auto" />
query = torch.rand(128, 256, 24, 24)query_ = torch.reshape(query, (query.size(0), -1 , query.size(1)))multihead_attn = MultiHeadAttention(n_head=8, d_model=256, d_k=32, d_v=32)attn_output, attn_weights = multihead_attn(query_, query_, query_)attn_output = attn_output.reshape(*list(query.size()))print(f'attn_output: {attn_output.size()}, attn_weights: {attn_weights.size()}')
輸出是:
attn_output: [128, 256, 24, 24], attn_weights: [128, 8, 576, 576]
卷積塊注意力模塊 (CBAM)
2018 年,S. Woo 等人。 (2018) 發(fā)布了一個名為卷積塊注意力模塊 (CBAM) 的新注意力模塊,與卷積操作一樣,它強調(diào)了沿通道和空間軸的有意義的特征。 與多頭注意力相比,這種注意力是專門為前饋卷積神經(jīng)網(wǎng)絡(luò)而設(shè)計的,可以應(yīng)用于深度網(wǎng)絡(luò)中的每個卷積塊。
CBAM 包含兩個連續(xù)的子模塊,稱為通道注意模塊 (CAM) 和空間注意模塊 (SAM)。 在談到卷積時,這兩個概念可能是最重要的兩個概念。 通道是指每個像素的特征或通道的數(shù)量,而空間是指維度(h x w)的特征圖。
這是實現(xiàn):
class BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planesself.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else Noneself.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)if self.bn is not None:x = self.bn(x)if self.relu is not None:x = self.relu(x)return xclass Flatten(nn.Module):def forward(self, x):return x.view(x.size(0), -1)class ChannelGate(nn.Module):def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):super(ChannelGate, self).__init__()self.gate_channels = gate_channelsself.mlp = nn.Sequential(Flatten(),nn.Linear(gate_channels, gate_channels // reduction_ratio),nn.ReLU(),nn.Linear(gate_channels // reduction_ratio, gate_channels))self.pool_types = pool_typesdef forward(self, x):channel_att_sum = Nonefor pool_type in self.pool_types:if pool_type=='avg':avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))channel_att_raw = self.mlp( avg_pool )elif pool_type=='max':max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))channel_att_raw = self.mlp( max_pool )elif pool_type=='lp':lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))channel_att_raw = self.mlp( lp_pool )elif pool_type=='lse':# LSE pool onlylse_pool = logsumexp_2d(x)channel_att_raw = self.mlp( lse_pool )if channel_att_sum is None:channel_att_sum = channel_att_rawelse:channel_att_sum = channel_att_sum + channel_att_rawscale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)return x * scaledef logsumexp_2d(tensor):tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()return outputsclass ChannelPool(nn.Module):def forward(self, x):return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )class SpatialGate(nn.Module):def __init__(self):super(SpatialGate, self).__init__()kernel_size = 7self.compress = ChannelPool()self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)def forward(self, x):x_compress = self.compress(x)x_out = self.spatial(x_compress)scale = torch.sigmoid(x_out) # broadcastingreturn x * scaleclass CBAM(nn.Module):def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=True):super(CBAM, self).__init__()self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)self.no_spatial = no_spatialif not no_spatial:self.SpatialGate = SpatialGate()def forward(self, x):x_out = self.ChannelGate(x)if not self.no_spatial:x_out = self.SpatialGate(x_out)return x_outquery = torch.rand(128, 256, 24, 24)attn = CBAM(gate_channels=256)attn_output = attn(query)print(attn_output.size())
輸出是:
attn_output: [128, 256, 24, 24]
基于注意力的網(wǎng)絡(luò)的實現(xiàn)和結(jié)果
在上面介紹的理論部分之后,本節(jié)重點介紹兩個注意力層在實際案例中的實現(xiàn)。
具體來說,我們選擇了 STL 數(shù)據(jù)集,并在一些圖像中包含了一個白色補丁,如下所示。 任務(wù)是創(chuàng)建一個神經(jīng)網(wǎng)絡(luò),對兩種類型的圖像進行分類。
from torchvision.datasets import STL10dataset = STL10("stl10", split='train', download=True)def getBatch(BS=10, offset=0, display_labels=False):xs = []labels = []for i in range(BS):x, y = dataset[offset + i]x = (np.array(x)-128.0)/128.0x = x.transpose(2, 0, 1)np.random.seed(i + 10)corrupt = np.random.randint(2)if corrupt: # To corrupt the image, we'll just copy a patch from somewhere elsepos_x = np.random.randint(96-16)pos_y = np.random.randint(96-16)x[:, pos_x:pos_x+16, pos_y:pos_y+16] = 1xs.append(x)labels.append(corrupt)if display_labels == True:print(labels)return np.array(xs), np.array(labels)
STL 圖像。 標記為 1 的圖像屬于圖像有白斑的類別,而標記為 0 的圖像是沒有白斑的。
class ConvPart(nn.Module):def __init__(self):super().__init__()self.c1a = nn.Conv2d(3, 32, 5, padding=2)self.p1 = nn.MaxPool2d(2)self.c2a = nn.Conv2d(32, 32, 5, padding=2)self.p2 = nn.MaxPool2d(2)self.c3 = nn.Conv2d(32, 32, 5, padding=2)self.bn1a = nn.BatchNorm2d(32)self.bn2a = nn.BatchNorm2d(32)def forward(self, x):z = self.bn1a(F.leaky_relu(self.c1a(x)))z = self.p1(z)z = self.bn2a(F.leaky_relu(self.c2a(z)))z = self.p2(z)z = self.c3(z)return zclass Net(nn.Module):def __init__(self):super().__init__()self.conv = ConvPart()self.final = nn.Linear(32, 1)self.optim = torch.optim.Adam(self.parameters(), lr=1e-4)def forward(self, x):z = self.conv(x)z = z.mean(3).mean(2)p = torch.sigmoid(self.final(z))[:, 0]return p, _class NetMultiheadAttention(nn.Module):def __init__(self):super().__init__()self.conv = ConvPart()self.attn1 = MultiHeadAttention(n_head=4, d_model=32, d_k=8, d_v=8)self.final = nn.Linear(32, 1)self.optim = torch.optim.Adam(self.parameters(), lr=1e-4)def forward(self, x):z = self.conv(x)q = torch.reshape(z, (z.size(0), -1 , z.size(1)))q, w = self.attn1(q, q, q)q = torch.reshape(q, (z.size(0), z.size(1), z.size(2), z.size(3)))z = q.mean(3).mean(2)p = torch.sigmoid(self.final(z))[:, 0]return p, qclass NetCBAM(nn.Module):def __init__(self):super().__init__()self.conv = ConvPart()self.attn1 = CBAM(gate_channels=32)self.final = nn.Linear(32, 1)self.optim = torch.optim.Adam(self.parameters(), lr=1e-4)def forward(self, x):z = self.conv(x)q = self.attn1(z)z = q.mean(3).mean(2)p = torch.sigmoid(self.final(z))[:, 0]return p, q
這是運行訓(xùn)練的代碼。
import timeimport numpy as npimport torchimport torch.nn as nnimport torch.nn.functional as Fimport matplotlib.pyplot as pltimport IPython.display as displaydevice = 'cuda' if torch.cuda.is_available() else torch.device('cpu')print(device)def plot_without_attention(tr_err, ts_err, tr_acc, ts_acc, img):plt.clf()fig, axs = plt.subplots(1, 4, figsize=(20, 5))axs[0].plot(tr_err, label='tr_err')axs[0].plot(ts_err, label='ts_err')axs[0].legend()axs[1].plot(tr_acc, label='tr_err')axs[1].plot(ts_acc, label='ts_err')axs[1].legend()axs[2].axis('off')axs[3].axis('off')display.clear_output(wait=True)display.display(plt.gcf())time.sleep(0.01)def plot_with_attention(tr_err, ts_err, tr_acc, ts_acc, img, att_out, no_images=6):plt.clf()fig, axs = plt.subplots(1+no_images, 4, figsize=(20, (no_images+1)*5))axs[0, 0].plot(tr_err, label='tr_err')axs[0, 0].plot(ts_err, label='ts_err')axs[0, 0].legend()axs[0, 1].plot(tr_acc, label='tr_err')axs[0, 1].plot(ts_acc, label='ts_err')axs[0, 1].legend()axs[0, 2].axis('off')axs[0, 3].axis('off')for img_no in range(6):im = img[img_no].cpu().detach().numpy().transpose(1, 2, 0)*0.5 + 0.5axs[img_no+1, 0].imshow(im)for i in range(3):att_out_img = att_out[img_no, i+1].cpu().detach().numpy()axs[img_no+1, i+1].imshow(att_out_img)display.clear_output(wait=True)display.display(plt.gcf())time.sleep(0.01)def train(model, att_flag=False):net = model.to(device)tr_err, ts_err = [], []tr_acc, ts_acc = [], []for epoch in range(50):errs, accs = [], []net.train()for i in range(4000//BATCH_SIZE):net.optim.zero_grad()x, y = getBatch(BATCH_SIZE, i*BATCH_SIZE)x = torch.FloatTensor(x).to(device)y = torch.FloatTensor(y).to(device)p, q = net.forward(x)loss = -torch.mean(y*torch.log(p+1e-8) + (1-y)*torch.log(1-p+1e-8))loss.backward()errs.append(loss.cpu().detach().item())pred = torch.round(p)accs.append(torch.sum(pred == y).cpu().detach().item()/BATCH_SIZE)net.optim.step() tr_err.append(np.mean(errs))tr_acc.append(np.mean(accs))errs, accs = [], []net.eval()for i in range(1000//BATCH_SIZE):x, y = getBatch(BATCH_SIZE, i*BATCH_SIZE+4000)x = torch.FloatTensor(x).to(device)y = torch.FloatTensor(y).to(device)p, q = net.forward(x) loss = -torch.mean(y*torch.log(p+1e-8) + (1-y)*torch.log(1-p+1e-8))errs.append(loss.cpu().detach().item())pred = torch.round(p)accs.append(torch.sum(pred == y).cpu().detach().item()/BATCH_SIZE)ts_err.append(np.mean(errs)) ts_acc.append(np.mean(accs))if att_flag == False:plot_without_attention(tr_err, ts_err, tr_acc, ts_acc, x[0])else:plot_with_attention(tr_err, ts_err, tr_acc, ts_acc, x, q)print(f'Min train error: {np.min(tr_err)}')print(f'Min test error: {np.min(ts_err)}')
CNN輸出:
Min train error: 0.0011167450276843738Min test error: 0.05411996720208516
訓(xùn)練:
model = Net()train(model, att_flag=False)
CNN + Multi-Head attention:添加注意力層時性能有所提高,但注意力圖沒有突出顯示帶有白斑的圖像部分。
Min train error: 9.811600781858942e-06Min test error: 0.04209221125441423
訓(xùn)練:
model = NetMultiheadAttention()train(model, att_flag=True)
由于存在一些過擬合并且注意力層沒有完成它應(yīng)該做的事情,我使用卷積層重新實現(xiàn)了這一層。
CNN + 1DConv-based Multi-Head attention:這一次,穩(wěn)定性和性能顯著提升。 此外,還可以觀察注意力層的輸出如何突出顯示包含它的圖像的白色塊。
Min train error: 0.00025470180017873645Min test error: 0.014278276459193759
注意力代碼
class Attention(nn.Module):def __init__(self, mem_in=32, query_in=32, key_size=32, output_size=32):super(Attention, self).__init__()self.key = nn.Conv1d(mem_in, key_size, 1, padding=0)self.value = nn.Conv1d(mem_in, output_size, 1, padding=0)self.query = nn.Conv1d(query_in, key_size, 1, padding=0)self.key_size = key_sizedef forward(self, x1, x2):queries = self.query(x1) # Batch x Values x Keyskeys = self.key(x2) # Batch x Keysize x Keysvalues = self.value(x2) # Batch x Values x Keysu = torch.sum(queries.unsqueeze(2) * keys.unsqueeze(3), 1)/np.sqrt(self.key_size)w = F.softmax(u, dim=1)out = torch.sum(w.unsqueeze(1) * values.unsqueeze(3), 2)return out, wclass MultiheadAttention(nn.Module):def __init__(self, mem_in=32, query_in=32, key_size=32, output_size=32, num_heads=4):super(MultiheadAttentionModified, self).__init__()self.layers = nn.ModuleList([Attention(mem_in, query_in, key_size, output_size) for i in range(num_heads)])self.proj_down = nn.Conv1d(num_heads*output_size, query_in, 1, padding=0)self.mixing_layer1 = nn.Conv1d(query_in, query_in, 1, padding=0)self.mixing_layer2 = nn.Conv1d(query_in, query_in, 1, padding=0)self.norm1 = nn.LayerNorm(query_in)self.norm2 = nn.LayerNorm(query_in)def forward(self, query, context):x1 = query.reshape(query.size(0), query.size(1), -1)x2 = context.reshape(context.size(0), context.size(1), -1)# Apply attention for each headz1, ws = [], []for i in range(len(self.layers)):z, w = self.layers[i](x1, x2)z1.append(z)ws.append(w)z1 = torch.cat(z1, 1)# Project down. Layer norm is a bit fiddly here - it wants the dimensions to normalize over to be the last dimensionsz2 = self.norm1((self.proj_down(z1) + x2).transpose(1, 2).contiguous()).transpose(1, 2).contiguous()# Mixing layerz3 = self.norm2((self.mixing_layer2(F.relu(self.mixing_layer1(z2))) + z2).transpose(1, 2).contiguous()).transpose(1, 2).contiguous()if len(query.size()) == 4:z3 = z3.reshape(query.size(0), query.size(1), query.size(3), query.size(3)) return z3, z1class NetMultiheadAttention(nn.Module):def __init__(self):super().__init__()self.conv = ConvPart()self.attn1 = MultiheadAttention(mem_in=32, query_in=32)self.final = nn.Linear(32, 1)self.optim = torch.optim.Adam(self.parameters(), lr=1e-4)def forward(self, x):z = self.conv(x)q = torch.reshape(z, (z.size(0) , z.size(1), -1))q, w = self.attn1(q, q)q = torch.reshape(q, (z.size(0), z.size(1), z.size(2), z.size(3)))z = q.mean(3).mean(2)p = torch.sigmoid(self.final(z))[:, 0]return p, q
CNN + CBAM attention:這個是最好的結(jié)果。 顯然可以觀察到注意力層輸出中的白斑,并且訓(xùn)練非常穩(wěn)定,實現(xiàn)了所有模型的最低驗證損失。
Min train error: 2.786791462858673e-05Min test error: 0.028047989653949175
訓(xùn)練
model = NetCBAM()train(model, att_flag=True)總結(jié)
本文介紹了多頭注意力和 CBAM 模塊,這是計算機視覺中最流行的兩個注意力模塊。 此外,它還包括 PyTorch 中的一個實現(xiàn),我們從包含白斑(手動添加)的 CIFAR 數(shù)據(jù)集中對圖像進行分類。
對于未來的工作,我認為將位置編碼與注意力一起包括在內(nèi)是很有趣的。以后我們會翻譯這方面的文章。
作者:Javier Fernandez
deephub翻譯組
以上就是關(guān)于pos機p1,PyTorch中實現(xiàn)MultiHead和CBAM的知識,后面我們會繼續(xù)為大家整理關(guān)于pos機p1的知識,希望能夠幫助到大家!
