1、Restormer: Efficient Transformer for High-Resolution Image Restoration
(1)總體結(jié)構(gòu)
總體結(jié)構(gòu)與U-net相近满败。(a)展示了作者改進(jìn)的multi-Dconv head transposed attention 劫瞳,其主要做法是將空間的attention窄驹,轉(zhuǎn)移到通道上,從而可以處理高分辨率圖像。(b)是作者改進(jìn)的Gated-Dconv feed-forward network 掌逛,算是一個(gè)錦上添花的改進(jìn)立镶。
(2)參數(shù)設(shè)置
對(duì)應(yīng)圖中的L1-L4壁袄,分別取值為4,6媚媒,6嗜逻,8。attention heads的數(shù)目依次為 1缭召,2栈顷,4,8嵌巷,特征的通道數(shù)依次為 48萄凤,96,192搪哪,384靡努。L_r的取值為4。優(yōu)化器AdamW晓折,學(xué)習(xí)率由3e-4降至1e-6惑朦,使用cosine annealing策略。此外漓概,使用漸進(jìn)學(xué)習(xí)方式行嗤,在不同的epoch,圖像大小不斷增大垛耳,batchsize數(shù)目不斷變小栅屏。最后,使用了?horizontal and vertical flips數(shù)據(jù)增強(qiáng)堂鲜。
總體來說栈雳,模型訓(xùn)練有許多trick。從表7來看缔莲,模型Flops相對(duì)較小哥纫,但是參數(shù)量較大。
(3)代碼實(shí)現(xiàn)
A.?Multi-DConv Head Transposed Self-Attention (MDTA)
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
? ? def __init__(self, dim, num_heads, bias):
? ? ? ? super(Attention, self).__init__()
? ? ? ? self.num_heads = num_heads #這里是attention的head數(shù)目
? ? ? ? self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
? ? ? ? self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) # q*w,K*w,v*w
? ? ? ? self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) # 可分離卷積
? ? ? ? self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
? ? def forward(self, x):
? ? ? ? b,c,h,w = x.shape
? ? ? ? qkv = self.qkv_dwconv(self.qkv(x))
? ? ? ? q,k,v = qkv.chunk(3, dim=1)?
? ? ? ? q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
? ? ? ? k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
? ? ? ? v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
? ? ? ? q = torch.nn.functional.normalize(q, dim=-1)
? ? ? ? k = torch.nn.functional.normalize(k, dim=-1)
? ? ? ? attn = (q @ k.transpose(-2, -1)) * self.temperature
? ? ? ? attn = attn.softmax(dim=-1)
? ? ? ? out = (attn @ v)
? ? ? ? out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
? ? ? ? out = self.project_out(out)
? ? ? ? return out
B. Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
? ? def __init__(self, dim, ffn_expansion_factor, bias):
? ? ? ? super(FeedForward, self).__init__()
? ? ? ? hidden_features = int(dim*ffn_expansion_factor)
? ? ? ? self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
? ? ? ? self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
? ? ? ? self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
? ? def forward(self, x):
? ? ? ? x = self.project_in(x)
? ? ? ? x1, x2 = self.dwconv(x).chunk(2, dim=1)
? ? ? ? x = F.gelu(x1) * x2
? ? ? ? x = self.project_out(x)