計算機視覺3-沐神筆記篇

語義分割和數(shù)據(jù)集

語義分割(Semantic Segmentation)是計算機視覺中的一項任務(wù),旨在將圖像中的每個像素標記為屬于不同語義類別的一部分砸烦。與傳統(tǒng)的圖像分類任務(wù)不同驱富,語義分割需要對圖像中的每個像素進行分類牙甫,從而實現(xiàn)對圖像的像素級別理解和分割。

語義分割的目標是將圖像中的不同物體或區(qū)域進行分割热康,并為每個像素分配一個特定的語義標簽消返。這意味著圖像中的每個像素都被標記為屬于不同的類別载弄,如人、車撵颊、樹等。通過進行語義分割惫叛,我們可以獲得關(guān)于圖像中各個對象和區(qū)域的詳細信息倡勇,為場景理解、目標檢測嘉涌、圖像分析等任務(wù)提供基礎(chǔ)妻熊。

在語義分割中,常用的方法是使用深度學(xué)習(xí)模型仑最,特別是卷積神經(jīng)網(wǎng)絡(luò)(CNN)扔役。卷積神經(jīng)網(wǎng)絡(luò)在圖像處理中具有出色的性能,能夠從圖像中學(xué)習(xí)到豐富的特征表示警医。通常亿胸,語義分割模型使用編碼-解碼架構(gòu),其中編碼器負責(zé)提取圖像的特征表示预皇,而解碼器則將特征映射到像素級別的預(yù)測侈玄。

近年來,一些先進的語義分割模型如U-Net吟温、SegNet序仙、DeepLab等已被提出,并在圖像分割任務(wù)中取得了顯著的進展鲁豪。這些模型結(jié)合了卷積神經(jīng)網(wǎng)絡(luò)的強大特征提取能力和適應(yīng)性潘悼,使得語義分割在許多應(yīng)用領(lǐng)域取得了重要的突破,如自動駕駛爬橡、醫(yī)學(xué)圖像分析治唤、遙感圖像解譯等。

總而言之堤尾,語義分割是一項重要的計算機視覺任務(wù)肝劲,旨在對圖像進行像素級別的分割和分類,為對圖像中各個對象和區(qū)域的理解提供了強大的工具和方法。

我們一直使用方形邊界框來標注和預(yù)測圖像中的目標辞槐。 本節(jié)將探討語義分割(semantic segmentation)問題掷漱,它重點關(guān)注于如何將圖像分割成屬于不同語義類別的區(qū)域。 與目標檢測不同榄檬,語義分割可以識別并理解圖像中每一個像素的內(nèi)容:其語義區(qū)域的標注和預(yù)測是像素級的卜范。

語義分割中圖像有關(guān)狗、貓和背景的標簽

Pascal VOC2012 語義分割數(shù)據(jù)集

最重要的語義分割數(shù)據(jù)集之一是Pascal VOC2012鹿榜。 下面我們深入了解一下這個數(shù)據(jù)集海雪。
VOC(Visual Object Classes)格式是一種常用的圖像標注和物體檢測數(shù)據(jù)集格式,通常用于計算機視覺中的目標檢測舱殿、圖像分割和分類任務(wù)奥裸。VOC格式由PASCAL VOC(Pattern Analysis, Statistical Modeling and Computational Learning Visual Object Classes)項目定義,并在圖像識別研究中得到廣泛應(yīng)用沪袭。

VOC格式的數(shù)據(jù)集通常由以下幾個組成部分組成:

  1. 圖像文件(Image Files):包含原始的圖像文件湾宙,以常見的圖像格式(如JPEG、PNG等)保存冈绊。

  2. 標注文件(Annotation Files):以XML格式保存侠鳄,每個標注文件對應(yīng)于一個圖像。標注文件包含了圖像中每個目標物體的邊界框(Bounding Box)和類別標簽(Class Label)死宣。每個邊界框由左上角和右下角的坐標表示伟恶,以及相應(yīng)的類別標簽。

  3. 類別標簽文件(Class Labels File):以文本文件形式提供毅该,包含數(shù)據(jù)集中所使用的所有類別標簽博秫。每個類別標簽占據(jù)一行,可以是物體類別的名稱或數(shù)字標識符鹃骂。

VOC格式的數(shù)據(jù)集通常按照特定目標檢測任務(wù)的需求進行標注和組織台盯,提供了標準化的數(shù)據(jù)格式,方便各種目標檢測算法和模型的訓(xùn)練和評估畏线。此外静盅,VOC格式還定義了一些評估指標,如平均精度(mAP)寝殴,用于評估目標檢測模型的性能蒿叠。

需要注意的是,VOC格式只是一種數(shù)據(jù)集組織和標注的規(guī)范蚣常,并不限定特定的圖像處理庫或軟件工具市咽。在使用VOC格式的數(shù)據(jù)集時,可以選擇適合自己任務(wù)的圖像處理庫(如OpenCV抵蚊、TensorFlow施绎、PyTorch等)或相關(guān)工具進行數(shù)據(jù)的讀取溯革、處理和訓(xùn)練。

VOC_ROOT     #根目錄
    ├── JPEGImages         # 存放源圖片
    │     ├── aaaa.jpg     
    │     ├── bbbb.jpg  
    │     └── cccc.jpg
    ├── Annotations        # 存放xml文件谷醉,與JPEGImages中的圖片一一對應(yīng)致稀,解釋圖片的內(nèi)容等等
    │     ├── aaaa.xml 
    │     ├── bbbb.xml 
    │     └── cccc.xml 
    └── ImageSets          
        └── Main
          ├── train.txt    # txt文件中每一行包含一個圖片的名稱
          └── val.txt

下面將read_voc_images函數(shù)定義為將所有輸入的圖像和標簽讀入內(nèi)存。

#@save
def read_voc_images(voc_dir, is_train=True):
    """讀取所有VOC圖像并標注"""
    txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',
                             'train.txt' if is_train else 'val.txt')
    mode = torchvision.io.image.ImageReadMode.RGB
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    features, labels = [], []
    for i, fname in enumerate(images):
        features.append(torchvision.io.read_image(os.path.join(
            voc_dir, 'JPEGImages', f'{fname}.jpg')))
        labels.append(torchvision.io.read_image(os.path.join(
            voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))
    return features, labels

train_features, train_labels = read_voc_images(voc_dir, True)

下面我們繪制前5個輸入圖像及其標簽俱尼。 在標簽圖像中抖单,白色和黑色分別表示邊框和背景,而其他顏色則對應(yīng)不同的類別遇八。

n = 5
imgs = train_features[0:n] + train_labels[0:n]
imgs = [img.permute(1,2,0) for img in imgs]
d2l.show_images(imgs, 2, n);

接下來矛绘,我們列舉RGB顏色值和類名。

#@save
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

#@save
VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

通過上面定義的兩個常量刃永,我們可以方便地查找標簽中每個像素的類索引货矮。 我們定義了voc_colormap2label函數(shù)來構(gòu)建從上述RGB顏色值到類別索引的映射,而voc_label_indices函數(shù)將RGB值映射到在Pascal VOC2012數(shù)據(jù)集中的類別索引揽碘。

#@save
def voc_colormap2label():
    """構(gòu)建從RGB到VOC類別索引的映射"""
    colormap2label = torch.zeros(256 ** 3, dtype=torch.long)
    for i, colormap in enumerate(VOC_COLORMAP):
        colormap2label[
            (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i
    return colormap2label

#@save
def voc_label_indices(colormap, colormap2label):
    """將VOC標簽中的RGB值映射到它們的類別索引"""
    colormap = colormap.permute(1, 2, 0).numpy().astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]

例如次屠,在第一張樣本圖像中,飛機頭部區(qū)域的類別索引為1雳刺,而背景索引為0。

y = voc_label_indices(train_labels[0], voc_colormap2label())
y[105:115, 130:140], VOC_CLASSES[1]
(tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]),
 'aeroplane')

在之前的實驗裸违,我們通過再縮放圖像使其符合模型的輸入形狀掖桦。 然而在語義分割中,這樣做需要將預(yù)測的像素類別重新映射回原始尺寸的輸入圖像供汛。 這樣的映射可能不夠精確枪汪,尤其在不同語義的分割區(qū)域。 為了避免這個問題怔昨,我們將圖像裁剪為固定尺寸雀久,而不是再縮放。 具體來說趁舀,我們使用圖像增廣中的隨機裁剪赖捌,裁剪輸入圖像和標簽的相同區(qū)域。

#@save
def voc_rand_crop(feature, label, height, width):
    """隨機裁剪特征和標簽圖像"""
    rect = torchvision.transforms.RandomCrop.get_params(
        feature, (height, width))
    feature = torchvision.transforms.functional.crop(feature, *rect)
    label = torchvision.transforms.functional.crop(label, *rect)
    return feature, label

imgs = []
for _ in range(n):
    imgs += voc_rand_crop(train_features[0], train_labels[0], 200, 300)

imgs = [img.permute(1, 2, 0) for img in imgs]
d2l.show_images(imgs[::2] + imgs[1::2], 2, n);

自定義語義分割數(shù)據(jù)集類

我們通過繼承高級API提供的Dataset類矮烹,自定義了一個語義分割數(shù)據(jù)集類VOCSegDataset越庇。 通過實現(xiàn)getitem函數(shù),我們可以任意訪問數(shù)據(jù)集中索引為idx的輸入圖像及其每個像素的類別索引奉狈。 由于數(shù)據(jù)集中有些圖像的尺寸可能小于隨機裁剪所指定的輸出尺寸卤唉,這些樣本可以通過自定義的filter函數(shù)移除掉。 此外仁期,我們還定義了normalize_image函數(shù)桑驱,從而對輸入圖像的RGB三個通道的值分別做標準化竭恬。

#@save
class VOCSegDataset(torch.utils.data.Dataset):
    """一個用于加載VOC數(shù)據(jù)集的自定義數(shù)據(jù)集"""

    def __init__(self, is_train, crop_size, voc_dir):
        self.transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.crop_size = crop_size
        features, labels = read_voc_images(voc_dir, is_train=is_train)
        self.features = [self.normalize_image(feature)
                         for feature in self.filter(features)]
        self.labels = self.filter(labels)
        self.colormap2label = voc_colormap2label()
        print('read ' + str(len(self.features)) + ' examples')

    def normalize_image(self, img):
        return self.transform(img.float() / 255)

    def filter(self, imgs):
        return [img for img in imgs if (
            img.shape[1] >= self.crop_size[0] and
            img.shape[2] >= self.crop_size[1])]

    def __getitem__(self, idx):
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)
        return (feature, voc_label_indices(label, self.colormap2label))

    def __len__(self):
        return len(self.features)

讀取數(shù)據(jù)集

我們通過自定義的VOCSegDataset類來分別創(chuàng)建訓(xùn)練集和測試集的實例。 假設(shè)我們指定隨機裁剪的輸出圖像的形狀為320\times 480熬的, 下面我們可以查看訓(xùn)練集和測試集所保留的樣本個數(shù)痊硕。

crop_size = (320, 480)
voc_train = VOCSegDataset(True, crop_size, voc_dir)
voc_test = VOCSegDataset(False, crop_size, voc_dir)

設(shè)批量大小為64,我們定義訓(xùn)練集的迭代器悦析。 打印第一個小批量的形狀會發(fā)現(xiàn):與圖像分類或目標檢測不同寿桨,這里的標簽是一個三維數(shù)組。

batch_size = 64
train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True,
                                    drop_last=True,
                                    num_workers=d2l.get_dataloader_workers())
for X, Y in train_iter:
    print(X.shape)
    print(Y.shape)
    break
torch.Size([64, 3, 320, 480])
torch.Size([64, 320, 480])

整合所有組件

最后强戴,我們定義以下load_data_voc函數(shù)來下載并讀取Pascal VOC2012語義分割數(shù)據(jù)集亭螟。 它返回訓(xùn)練集和測試集的數(shù)據(jù)迭代器。

#@save
def load_data_voc(batch_size, crop_size):
    """加載VOC語義分割數(shù)據(jù)集"""
    voc_dir = d2l.download_extract('voc2012', os.path.join(
        'VOCdevkit', 'VOC2012'))
    num_workers = d2l.get_dataloader_workers()
    train_iter = torch.utils.data.DataLoader(
        VOCSegDataset(True, crop_size, voc_dir), batch_size,
        shuffle=True, drop_last=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        VOCSegDataset(False, crop_size, voc_dir), batch_size,
        drop_last=True, num_workers=num_workers)
    return train_iter, test_iter

小結(jié)

  • 語義分割通過將圖像劃分為屬于不同語義類別的區(qū)域骑歹,來識別并理解圖像中像素級別的內(nèi)容预烙。

  • 語義分割的一個重要的數(shù)據(jù)集叫做Pascal VOC2012。

  • 由于語義分割的輸入圖像和標簽在像素上一一對應(yīng)道媚,輸入圖像會被隨機裁剪為固定尺寸而不是縮放扁掸。

轉(zhuǎn)置卷積

到目前為止,我們所見到的卷積神經(jīng)網(wǎng)絡(luò)層最域,例如卷積層和匯聚層谴分,通常會減少下采樣輸入圖像的空間維度(高和寬)。 然而如果輸入和輸出圖像的空間維度相同镀脂,在以像素級分類的語義分割中將會很方便牺蹄。 例如,輸出像素所處的通道維可以保有輸入像素在同一位置上的分類結(jié)果薄翅。

為了實現(xiàn)這一點沙兰,尤其是在空間維度被卷積神經(jīng)網(wǎng)絡(luò)層縮小后,我們可以使用另一種類型的卷積神經(jīng)網(wǎng)絡(luò)層翘魄,它可以增加上采樣中間層特征圖的空間維度鼎天。 本節(jié)將介紹 轉(zhuǎn)置卷積, 用于逆轉(zhuǎn)下采樣導(dǎo)致的空間尺寸減小暑竟。

轉(zhuǎn)置卷積(Transpose Convolution)斋射,也被稱為反卷積(Deconvolution)或上采樣卷積(Upsampling Convolution),是卷積神經(jīng)網(wǎng)絡(luò)(CNN)中的一種操作光羞。轉(zhuǎn)置卷積可以將低維特征圖(例如绩鸣,輸入圖像的較低分辨率特征圖)通過反向操作進行上采樣,從而得到更高分辨率的特征圖纱兑。

轉(zhuǎn)置卷積的原理可以簡單描述如下:在傳統(tǒng)的卷積操作中呀闻,我們使用卷積核(filter)對輸入特征圖進行卷積運算,從而得到下采樣(降低分辨率)后的特征圖潜慎。而轉(zhuǎn)置卷積則是對下采樣后的特征圖進行上采樣操作捡多,通過填充空白像素和應(yīng)用反向的卷積核蓖康,將低分辨率的特征圖還原為高分辨率的特征圖。

在實際實現(xiàn)中垒手,轉(zhuǎn)置卷積可以使用多種方式來完成蒜焊,其中最常見的方法是使用反向卷積操作。在反向卷積中科贬,輸入特征圖中的每個像素都會與卷積核中的權(quán)重進行相乘泳梆,并在輸出特征圖中的對應(yīng)位置進行求和。通過在輸出特征圖的像素之間插入填充值榜掌,可以實現(xiàn)上采樣的效果优妙。

轉(zhuǎn)置卷積在深度學(xué)習(xí)中的應(yīng)用非常廣泛,特別是在圖像分割憎账、物體檢測和圖像生成等任務(wù)中套硼。通過使用轉(zhuǎn)置卷積,可以將低分辨率的特征圖還原為高分辨率胞皱,從而有助于提高模型的性能和精度邪意。

基本操作

2\times 2的輸入張量計算卷積核為2\times 2的轉(zhuǎn)置卷積。

卷積核為 2x2的轉(zhuǎn)置卷積反砌。陰影部分是中間張量的一部分雾鬼,也是用于計算的輸入和卷積核張量元素。

我們可以對輸入矩陣X和卷積核矩陣K實現(xiàn)基本的轉(zhuǎn)置卷積運算trans_conv宴树。

def trans_conv(X, K):
    h, w = K.shape
    Y = torch.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            Y[i: i + h, j: j + w] += X[i, j] * K
    return Y
X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
trans_conv(X, K)
tensor([[ 0.,  0.,  1.],
        [ 0.,  4.,  6.],
        [ 4., 12.,  9.]])

或者呆贿,當輸入X和卷積核K都是四維張量時,我們可以使用高級API獲得相同的結(jié)果森渐。

X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, bias=False)
tconv.weight.data = K
tconv(X)
tensor([[[[ 0.,  0.,  1.],
          [ 0.,  4.,  6.],
          [ 4., 12.,  9.]]]], grad_fn=<ConvolutionBackward0>)

填充、步幅和多通道

與常規(guī)卷積不同冒晰,在轉(zhuǎn)置卷積中同衣,填充被應(yīng)用于的輸出(常規(guī)卷積將填充應(yīng)用于輸入)。 例如壶运,當將高和寬兩側(cè)的填充數(shù)指定為1時耐齐,轉(zhuǎn)置卷積的輸出中將刪除第一和最后的行與列。

tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, padding=1, bias=False)
tconv.weight.data = K
tconv(X)
tensor([[[[4.]]]], grad_fn=<ConvolutionBackward0>)

在轉(zhuǎn)置卷積中蒋情,步幅被指定為中間結(jié)果(輸出)埠况,而不是輸入。將步幅從1更改為2會增加中間張量的高和權(quán)重棵癣。


卷積核為2x2辕翰,步幅為2的轉(zhuǎn)置卷積。陰影部分是中間張量的一部分狈谊,也是用于計算的輸入和卷積核張量元素喜命。
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
tconv.weight.data = K
tconv(X)
tensor([[[[0., 0., 0., 1.],
          [0., 0., 2., 3.],
          [0., 2., 0., 3.],
          [4., 6., 6., 9.]]]], grad_fn=<ConvolutionBackward0>)

與矩陣變換的聯(lián)系

轉(zhuǎn)置卷積為何以矩陣變換命名呢沟沙? 讓我們首先看看如何使用矩陣乘法來實現(xiàn)卷積。 在下面的示例中壁榕,我們定義了一個3\times 3的輸入X和2\times 2卷積核K矛紫,然后使用corr2d函數(shù)計算卷積輸出Y。

X = torch.arange(9.0).reshape(3, 3)
K = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
Y = d2l.corr2d(X, K)
Y
def kernel2matrix(K):
    k, W = torch.zeros(5), torch.zeros((4, 9))
    k[:2], k[3:5] = K[0, :], K[1, :]
    W[0, :5], W[1, 1:6], W[2, 3:8], W[3, 4:] = k, k, k, k
    return W

W = kernel2matrix(K)
W

接下來牌里,我們將卷積核K重寫為包含大量0的稀疏權(quán)重矩陣W颊咬。 權(quán)重矩陣的形狀是(4,9)牡辽,其中非0元素來自卷積核K喳篇。為了判斷卷積的操作是否等于某種矩陣的變換。

Y == torch.matmul(W, X.reshape(-1)).reshape(2, 2)
tensor([[True, True],
        [True, True]])

同樣催享,我們可以使用矩陣乘法來實現(xiàn)轉(zhuǎn)置卷積杭隙。 在下面的示例中,我們將上面的常規(guī)卷積2 \times 2的輸出Y作為轉(zhuǎn)置卷積的輸入因妙。 想要通過矩陣相乘來實現(xiàn)它痰憎,我們只需要將權(quán)重矩陣W的形狀轉(zhuǎn)置為(9, 4)

Z = trans_conv(Y, K)
Z == torch.matmul(W.T, Y.reshape(-1)).reshape(3, 3)
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])

小結(jié)

全卷積網(wǎng)絡(luò)

全卷積網(wǎng)絡(luò)(Fully Convolutional Network攀涵,F(xiàn)CN)是一種深度學(xué)習(xí)網(wǎng)絡(luò)架構(gòu)铣耘,用于解決圖像分割任務(wù)。與傳統(tǒng)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)用于圖像分類不同以故,F(xiàn)CN通過使用全卷積層來接受和生成任意大小的輸入和輸出蜗细,從而實現(xiàn)像素級別的圖像分割。

傳統(tǒng)的CNN網(wǎng)絡(luò)通常包含多個卷積層和池化層怒详,這些層的作用是逐漸減小特征圖的尺寸炉媒,以便進行最終的分類。然而昆烁,這種結(jié)構(gòu)無法產(chǎn)生與輸入圖像相同大小的輸出吊骤。為了解決這個問題,F(xiàn)CN引入了全卷積層静尼,用于替代傳統(tǒng)CNN網(wǎng)絡(luò)中的全連接層白粉。

全卷積層(Fully Convolutional Layer)是指在卷積神經(jīng)網(wǎng)絡(luò)中將全連接層替換為卷積層的操作。全卷積層使用1x1的卷積核鼠渺,保持特征圖的空間尺寸不變鸭巴,但可以改變特征圖的通道數(shù)。這樣一來拦盹,網(wǎng)絡(luò)的輸出將是一個與輸入圖像具有相同空間分辨率的特征圖鹃祖。

在FCN中,通過使用多個全卷積層掌敬,網(wǎng)絡(luò)可以逐步提取和學(xué)習(xí)不同尺度的特征表示惯豆。然后池磁,利用上采樣或反卷積操作將特征圖恢復(fù)到與輸入圖像相同的尺寸,得到像素級別的預(yù)測結(jié)果楷兽。為了提高分割精度地熄,通常還會在網(wǎng)絡(luò)中引入跳躍連接(Skip Connections),將不同層級的特征圖進行融合芯杀,以獲取更豐富的語義信息端考。

FCN廣泛應(yīng)用于語義分割、實例分割和語義分割的其他相關(guān)任務(wù)揭厚。通過使用FCN却特,可以實現(xiàn)對圖像中不同物體和區(qū)域的像素級別分割,為場景理解筛圆、目標檢測和圖像分析等任務(wù)提供強大的工具和方法裂明。

下面我們了解一下全卷積網(wǎng)絡(luò)模型最基本的設(shè)計,全卷積網(wǎng)絡(luò)先使用卷積神經(jīng)網(wǎng)絡(luò)抽取圖像特征,然后通過1\times 1卷積層將通道數(shù)變換為類別個數(shù)太援,通過轉(zhuǎn)置卷積層將特征圖的高和寬變換為輸入圖像的尺寸闽晦。 因此,模型輸出與輸入圖像的高和寬相同提岔,且最終輸出通道包含了該空間位置像素的類別預(yù)測仙蛉。

全卷積網(wǎng)絡(luò)

下面,我們使用在ImageNet數(shù)據(jù)集上預(yù)訓(xùn)練的ResNet-18模型來提取圖像特征碱蒙,并將該網(wǎng)絡(luò)記為pretrained_net荠瘪。 ResNet-18模型的最后幾層包括全局平均匯聚層和全連接層,然而全卷積網(wǎng)絡(luò)中不需要它們赛惩。

pretrained_net = torchvision.models.resnet18(pretrained=True)
list(pretrained_net.children())[-3:]
[Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (downsample): Sequential(
       (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
       (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
   )
   (1): BasicBlock(
     (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   )
 ),
 AdaptiveAvgPool2d(output_size=(1, 1)),
 Linear(in_features=512, out_features=1000, bias=True)]

接下來哀墓,我們創(chuàng)建一個全卷積網(wǎng)絡(luò)net。 它復(fù)制了ResNet-18中大部分的預(yù)訓(xùn)練層喷兼,除了最后的全局平均匯聚層和最接近輸出的全連接層麸祷。

net = nn.Sequential(*list(pretrained_net.children())[:-2])

給定高度為320和寬度為480的輸入,net的前向傳播將輸入的高和寬減小至原來的1/32褒搔,即10和15。

X = torch.rand(size=(1, 3, 320, 480))
net(X).shape
torch.Size([1, 512, 10, 15])

接下來使用1\times1卷積層將輸出通道數(shù)轉(zhuǎn)換為Pascal VOC2012數(shù)據(jù)集的類數(shù)(21類)喷面。 最后需要將特征圖的高度和寬度增加32倍星瘾,從而將其變回輸入圖像的高和寬。

卷積層輸出形狀的計算方法: 由于(320-64+16\times2+32)/32=10(480-64+16\times2+32)/32=15惧辈,我們構(gòu)造一個步幅為32的轉(zhuǎn)置卷積層琳状,并將卷積核的高和寬設(shè)為
64,填充為16盒齿。我們可以看到如果步幅為s念逞,填充為s/2(假設(shè)s/2是整數(shù))且卷積核的高和寬為2s困食,轉(zhuǎn)置卷積核會將輸入的高和寬分別放大s倍。

num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,
                                    kernel_size=64, padding=16, stride=32))

初始化轉(zhuǎn)置卷積層

在圖像處理中翎承,我們有時需要將圖像放大硕盹,即上采樣(upsampling)。 雙線性插值(bilinear interpolation) 是常用的上采樣方法之一叨咖,它也經(jīng)常用于初始化轉(zhuǎn)置卷積層瘩例。
為了解釋雙線性插值,假設(shè)給定輸入圖像甸各,我們想要計算上采樣輸出圖像上的每個像素垛贤。


雙線性插值的上采樣可以通過轉(zhuǎn)置卷積層實現(xiàn),內(nèi)核由以下bilinear_kernel函數(shù)構(gòu)造趣倾。 限于篇幅聘惦,我們只給出bilinear_kernel函數(shù)的實現(xiàn),不討論算法的原理儒恋。

def bilinear_kernel(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = (torch.arange(kernel_size).reshape(-1, 1),
          torch.arange(kernel_size).reshape(1, -1))
    filt = (1 - torch.abs(og[0] - center) / factor) * \
           (1 - torch.abs(og[1] - center) / factor)
    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight

讓我們用雙線性插值的上采樣實驗它由轉(zhuǎn)置卷積層實現(xiàn)善绎。 我們構(gòu)造一個將輸入的高和寬放大2倍的轉(zhuǎn)置卷積層,并將其卷積核用bilinear_kernel函數(shù)初始化碧浊。

conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,
                                bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4));

讀取圖像X涂邀,將上采樣的結(jié)果記作Y。為了打印圖像箱锐,我們需要調(diào)整通道維的位置比勉。

img = torchvision.transforms.ToTensor()(d2l.Image.open('../img/catdog.jpg'))
X = img.unsqueeze(0)
Y = conv_trans(X)
out_img = Y[0].permute(1, 2, 0).detach()

可以看到,轉(zhuǎn)置卷積層將圖像的高和寬分別放大了2倍驹止。 除了坐標刻度不同浩聋,雙線性插值放大的圖像和在之前打印出的原圖看上去沒什么兩樣。

d2l.set_figsize()
print('input image shape:', img.permute(1, 2, 0).shape)
d2l.plt.imshow(img.permute(1, 2, 0));
print('output image shape:', out_img.shape)
d2l.plt.imshow(out_img);
input image shape: torch.Size([561, 728, 3])
output image shape: torch.Size([1122, 1456, 3])

全卷積網(wǎng)絡(luò)用雙線性插值的上采樣初始化轉(zhuǎn)置卷積層臊恋。對于1\times 1卷積層衣洁,我們使用Xavier初始化參數(shù)。

W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W);

讀取數(shù)據(jù)集

語義分割讀取數(shù)據(jù)集抖仅。 指定隨機裁剪的輸出圖像的形狀為320\times 480:高和寬都可以被32整除坊夫。

batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)

訓(xùn)練

現(xiàn)在我們可以訓(xùn)練全卷積網(wǎng)絡(luò)了。 這里的損失函數(shù)和準確率計算與圖像分類中的并沒有本質(zhì)上的不同撤卢,因為我們使用轉(zhuǎn)置卷積層的通道來預(yù)測像素的類別环凿,所以需要在損失計算中指定通道維。 此外放吩,模型基于每個像素的預(yù)測類別是否正確來計算準確率智听。

def loss(inputs, targets):
    return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)

num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

預(yù)測

在預(yù)測時,我們需要將輸入圖像在各個通道做標準化,并轉(zhuǎn)成卷積神經(jīng)網(wǎng)絡(luò)所需要的四維輸入格式到推。

def predict(img):
    X = test_iter.dataset.normalize_image(img).unsqueeze(0)
    pred = net(X.to(devices[0])).argmax(dim=1)
    return pred.reshape(pred.shape[1], pred.shape[2])

為了可視化預(yù)測的類別給每個像素考赛,我們將預(yù)測類別映射回它們在數(shù)據(jù)集中的標注顏色。

def label2image(pred):
    colormap = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
    X = pred.long()
    return colormap[X, :]

測試數(shù)據(jù)集中的圖像大小和形狀各異莉测。 由于模型使用了步幅為32的轉(zhuǎn)置卷積層颜骤,因此當輸入圖像的高或?qū)挓o法被32整除時,轉(zhuǎn)置卷積層輸出的高或?qū)挄c輸入圖像的尺寸有偏差悔雹。 為了解決這個問題复哆,我們可以在圖像中截取多塊高和寬為32的整數(shù)倍的矩形區(qū)域,并分別對這些區(qū)域中的像素做前向傳播腌零。 請注意梯找,這些區(qū)域的并集需要完整覆蓋輸入圖像。 當一個像素被多個區(qū)域所覆蓋時益涧,它在不同區(qū)域前向傳播中轉(zhuǎn)置卷積層輸出的平均值可以作為softmax運算的輸入锈锤,從而預(yù)測類別。
為簡單起見,我們只讀取幾張較大的測試圖像,并從圖像的左上角開始截取形狀為320\times480的區(qū)域用于預(yù)測疹鳄。 對于這些測試圖像恢总,我們逐一打印它們截取的區(qū)域伺通,再打印預(yù)測結(jié)果,最后打印標注的類別。

voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
    crop_rect = (0, 0, 320, 480)
    X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
    pred = label2image(predict(X))
    imgs += [X.permute(1,2,0), pred.cpu(),
             torchvision.transforms.functional.crop(
                 test_labels[i], *crop_rect).permute(1,2,0)]
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);

小結(jié)

  • 全卷積網(wǎng)絡(luò)先使用卷積神經(jīng)網(wǎng)絡(luò)抽取圖像特征,然后通過1\times 1卷積層將通道數(shù)變換為類別個數(shù)呼巴,最后通過轉(zhuǎn)置卷積層將特征圖的高和寬變換為輸入圖像的尺寸。

  • 在全卷積網(wǎng)絡(luò)中御蒲,我們可以將轉(zhuǎn)置卷積層初始化為雙線性插值的上采樣衣赶。

風(fēng)格遷移

攝影愛好者也許接觸過濾波器。它能改變照片的顏色風(fēng)格厚满,從而使風(fēng)景照更加銳利或者令人像更加美白府瞄。但一個濾波器通常只能改變照片的某個方面。如果要照片達到理想中的風(fēng)格碘箍,可能需要嘗試大量不同的組合遵馆。這個過程的復(fù)雜程度不亞于模型調(diào)參。

風(fēng)格遷移(Style transfer)是一種計算機視覺技術(shù)丰榴,旨在將一幅圖像的風(fēng)格與另一幅圖像的內(nèi)容相結(jié)合团搞,生成一個新的圖像,使其看起來既保留原始內(nèi)容的特征多艇,又具有其他圖像的風(fēng)格。

風(fēng)格遷移的方法通诚裎牵基于神經(jīng)網(wǎng)絡(luò)峻黍,特別是卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks, CNN)复隆。下面是一個常見的風(fēng)格遷移過程的簡要步驟:

  1. 準備輸入圖像:選擇一幅作為內(nèi)容圖像和一幅作為風(fēng)格圖像。

  2. 構(gòu)建預(yù)訓(xùn)練模型:使用預(yù)訓(xùn)練的卷積神經(jīng)網(wǎng)絡(luò)模型(如VGGNet)作為基礎(chǔ)模型姆涩。該模型已經(jīng)在大規(guī)模圖像數(shù)據(jù)集上進行了訓(xùn)練挽拂,具有學(xué)習(xí)圖像特征的能力。

  3. 提取特征:將內(nèi)容圖像和風(fēng)格圖像輸入到卷積神經(jīng)網(wǎng)絡(luò)中骨饿,并提取出它們在不同層次的特征表示亏栈。

  4. 計算內(nèi)容損失:通過比較內(nèi)容圖像和生成圖像在某些中間層的特征表示,計算內(nèi)容損失宏赘,用于確保生成圖像保留了內(nèi)容圖像的特征绒北。

  5. 計算風(fēng)格損失:使用風(fēng)格圖像和生成圖像的特征表示,計算它們之間的風(fēng)格差異察署。這通常通過計算它們的協(xié)方差矩陣或Gram矩陣來實現(xiàn)闷游。

  6. 定義總損失函數(shù):將內(nèi)容損失和風(fēng)格損失加權(quán)相加,得到總的損失函數(shù)贴汪。

  7. 優(yōu)化生成圖像:通過最小化總損失函數(shù)脐往,使用梯度下降等優(yōu)化算法來更新生成圖像的像素值,使其逐漸接近目標圖像的內(nèi)容和風(fēng)格扳埂。

  8. 迭代優(yōu)化:重復(fù)執(zhí)行第7步业簿,直到生成圖像達到滿意的效果或達到指定的迭代次數(shù)。

通過上述步驟阳懂,風(fēng)格遷移算法可以生成具有內(nèi)容圖像特征和風(fēng)格圖像風(fēng)格的新圖像梅尤。這種技術(shù)在藝術(shù)創(chuàng)作、圖像處理和視覺效果等領(lǐng)域有廣泛應(yīng)用希太,可以用于圖像風(fēng)格化克饶、圖像增強、圖像合成等任務(wù)誊辉。

基于卷積神經(jīng)網(wǎng)絡(luò)的風(fēng)格遷移矾湃。實線箭頭和虛線箭頭分別表示前向傳播和反向傳播

接下來,我們通過前向傳播(實線箭頭方向)計算風(fēng)格遷移的損失函數(shù)堕澄,并通過反向傳播(虛線箭頭方向)迭代模型參數(shù)邀跃,即不斷更新合成圖像。 風(fēng)格遷移常用的損失函數(shù)由3部分組成:

  1. 內(nèi)容損失使合成圖像與內(nèi)容圖像在內(nèi)容特征上接近蛙紫;

  2. 風(fēng)格損失使合成圖像與風(fēng)格圖像在風(fēng)格特征上接近拍屑;

  3. 全變分損失則有助于減少合成圖像中的噪點。

最后坑傅,當模型訓(xùn)練結(jié)束時僵驰,我們輸出風(fēng)格遷移的模型參數(shù),即得到最終的合成圖像。

在下面蒜茴,我們將通過代碼來進一步了解風(fēng)格遷移的技術(shù)細節(jié)星爪。

閱讀內(nèi)容和風(fēng)格圖像

首先,我們讀取內(nèi)容和風(fēng)格圖像粉私。 從打印出的圖像坐標軸可以看出顽腾,它們的尺寸并不一樣。

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
content_img = d2l.Image.open('../img/rainier.jpg')
d2l.plt.imshow(content_img);
style_img = d2l.Image.open('../img/autumn-oak.jpg')
d2l.plt.imshow(style_img);

預(yù)處理和后處理

下面诺核,定義圖像的預(yù)處理函數(shù)和后處理函數(shù)抄肖。 預(yù)處理函數(shù)preprocess對輸入圖像在RGB三個通道分別做標準化,并將結(jié)果變換成卷積神經(jīng)網(wǎng)絡(luò)接受的輸入格式窖杀。 后處理函數(shù)postprocess則將輸出圖像中的像素值還原回標準化之前的值漓摩。 由于圖像打印函數(shù)要求每個像素的浮點數(shù)值在0~1之間,我們對小于0和大于1的值分別取0和1陈瘦。

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return transforms(img).unsqueeze(0)

def postprocess(img):
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

抽取圖像特征

我們使用基于ImageNet數(shù)據(jù)集預(yù)訓(xùn)練的VGG-19模型來抽取圖像特征 幌甘。

pretrained_net = torchvision.models.vgg19(pretrained=True)

為了抽取圖像的內(nèi)容特征和風(fēng)格特征,我們可以選擇VGG網(wǎng)絡(luò)中某些層的輸出痊项。 一般來說锅风,越靠近輸入層,越容易抽取圖像的細節(jié)信息鞍泉;反之皱埠,則越容易抽取圖像的全局信息。 為了避免合成圖像過多保留內(nèi)容圖像的細節(jié)咖驮,我們選擇VGG較靠近輸出的層边器,即內(nèi)容層,來輸出圖像的內(nèi)容特征托修。我們還從VGG中選擇不同層的輸出來匹配局部和全局的風(fēng)格忘巧,這些圖層也稱為風(fēng)格層。VGG網(wǎng)絡(luò)使用了5個卷積塊睦刃。 實驗中砚嘴,我們選擇第四卷積塊的最后一個卷積層作為內(nèi)容層,選擇每個卷積塊的第一個卷積層作為風(fēng)格層涩拙。 這些層的索引可以通過打印pretrained_net實例獲取际长。

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

使用VGG層抽取特征時,我們只需要用到從輸入層到最靠近輸出層的內(nèi)容層或風(fēng)格層之間的所有層兴泥。 下面構(gòu)建一個新的網(wǎng)絡(luò)net工育,它只保留需要用到的VGG的所有層。

net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers + style_layers) + 1)])

給定輸入X搓彻,如果我們簡單地調(diào)用前向傳播net(X)如绸,只能獲得最后一層的輸出嘱朽。 由于我們還需要中間層的輸出,因此這里我們逐層計算怔接,并保留內(nèi)容層和風(fēng)格層的輸出燥翅。

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

下面定義兩個函數(shù):get_contents函數(shù)對內(nèi)容圖像抽取內(nèi)容特征; get_styles函數(shù)對風(fēng)格圖像抽取風(fēng)格特征蜕提。 因為在訓(xùn)練時無須改變預(yù)訓(xùn)練的VGG的模型參數(shù),所以我們可以在訓(xùn)練開始之前就提取出內(nèi)容特征和風(fēng)格特征靶端。 由于合成圖像是風(fēng)格遷移所需迭代的模型參數(shù)谎势,我們只能在訓(xùn)練過程中通過調(diào)用extract_features函數(shù)來抽取合成圖像的內(nèi)容特征和風(fēng)格特征。

def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

定義損失函數(shù)

下面我們來描述風(fēng)格遷移的損失函數(shù)杨名。 它由內(nèi)容損失脏榆、風(fēng)格損失和全變分損失3部分組成。

內(nèi)容損失

與線性回歸中的損失函數(shù)類似台谍,內(nèi)容損失通過平方誤差函數(shù)衡量合成圖像與內(nèi)容圖像在內(nèi)容特征上的差異须喂。 平方誤差函數(shù)的兩個輸入均為extract_features函數(shù)計算所得到的內(nèi)容層的輸出。

def content_loss(Y_hat, Y):
    # 我們從動態(tài)計算梯度的樹中分離目標:
    # 這是一個規(guī)定的值趁蕊,而不是一個變量坞生。
    return torch.square(Y_hat - Y.detach()).mean()

風(fēng)格損失

def gram(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)

自然地,風(fēng)格損失的平方誤差函數(shù)的兩個格拉姆矩陣輸入分別基于合成圖像與風(fēng)格圖像的風(fēng)格層輸出掷伙。這里假設(shè)基于風(fēng)格圖像的格拉姆矩陣gram_Y已經(jīng)預(yù)先計算好了是己。

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

全變分損失

def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

損失函數(shù)

風(fēng)格轉(zhuǎn)移的損失函數(shù)是內(nèi)容損失、風(fēng)格損失和總變化損失的加權(quán)和任柜。 通過調(diào)節(jié)這些權(quán)重超參數(shù)卒废,我們可以權(quán)衡合成圖像在保留內(nèi)容、遷移風(fēng)格以及去噪三方面的相對重要性宙地。

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分別計算內(nèi)容損失摔认、風(fēng)格損失和全變分損失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 對所有損失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

初始化合成圖像

在風(fēng)格遷移中,合成的圖像是訓(xùn)練期間唯一需要更新的變量宅粥。因此参袱,我們可以定義一個簡單的模型SynthesizedImage,并將合成的圖像視為模型參數(shù)粹胯。模型的前向傳播只需返回模型參數(shù)即可蓖柔。

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight

下面,我們定義get_inits函數(shù)风纠。該函數(shù)創(chuàng)建了合成圖像的模型實例况鸣,并將其初始化為圖像X。風(fēng)格圖像在各個風(fēng)格層的格拉姆矩陣styles_Y_gram將在訓(xùn)練前預(yù)先計算好竹观。

def get_inits(X, device, lr, styles_Y):
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

訓(xùn)練模型

在訓(xùn)練模型進行風(fēng)格遷移時镐捧,我們不斷抽取合成圖像的內(nèi)容特征和風(fēng)格特征潜索,然后計算損失函數(shù)。下面定義了訓(xùn)練循環(huán)懂酱。

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X

現(xiàn)在我們訓(xùn)練模型: 首先將內(nèi)容圖像和風(fēng)格圖像的高和寬分別調(diào)整為300和450像素竹习,用內(nèi)容圖像來初始化合成圖像。

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

我們可以看到列牺,合成圖像保留了內(nèi)容圖像的風(fēng)景和物體整陌,并同時遷移了風(fēng)格圖像的色彩。例如瞎领,合成圖像具有與風(fēng)格圖像中一樣的色彩塊泌辫,其中一些甚至具有畫筆筆觸的細微紋理。

小結(jié)

  • 風(fēng)格遷移常用的損失函數(shù)由3部分組成:(1)內(nèi)容損失使合成圖像與內(nèi)容圖像在內(nèi)容特征上接近九默;(2)風(fēng)格損失令合成圖像與風(fēng)格圖像在風(fēng)格特征上接近震放;(3)全變分損失則有助于減少合成圖像中的噪點。

  • 我們可以通過預(yù)訓(xùn)練的卷積神經(jīng)網(wǎng)絡(luò)來抽取圖像的特征驼修,并通過最小化損失函數(shù)來不斷更新合成圖像來作為模型參數(shù)殿遂。

  • 我們使用格拉姆矩陣表達風(fēng)格層輸出的風(fēng)格。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末乙各,一起剝皮案震驚了整個濱河市墨礁,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌觅丰,老刑警劉巖饵溅,帶你破解...
    沈念sama閱讀 219,188評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異妇萄,居然都是意外死亡蜕企,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,464評論 3 395
  • 文/潘曉璐 我一進店門冠句,熙熙樓的掌柜王于貴愁眉苦臉地迎上來轻掩,“玉大人,你說我怎么就攤上這事懦底〈侥粒” “怎么了?”我有些...
    開封第一講書人閱讀 165,562評論 0 356
  • 文/不壞的土叔 我叫張陵聚唐,是天一觀的道長丐重。 經(jīng)常有香客問我,道長杆查,這世上最難降的妖魔是什么扮惦? 我笑而不...
    開封第一講書人閱讀 58,893評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮亲桦,結(jié)果婚禮上崖蜜,老公的妹妹穿的比我還像新娘浊仆。我一直安慰自己,他們只是感情好豫领,可當我...
    茶點故事閱讀 67,917評論 6 392
  • 文/花漫 我一把揭開白布抡柿。 她就那樣靜靜地躺著,像睡著了一般等恐。 火紅的嫁衣襯著肌膚如雪洲劣。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,708評論 1 305
  • 那天课蔬,我揣著相機與錄音闪檬,去河邊找鬼。 笑死购笆,一個胖子當著我的面吹牛,可吹牛的內(nèi)容都是我干的虚循。 我是一名探鬼主播同欠,決...
    沈念sama閱讀 40,430評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼横缔!你這毒婦竟也來了铺遂?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,342評論 0 276
  • 序言:老撾萬榮一對情侶失蹤茎刚,失蹤者是張志新(化名)和其女友劉穎襟锐,沒想到半個月后,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體膛锭,經(jīng)...
    沈念sama閱讀 45,801評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡粮坞,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,976評論 3 337
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了初狰。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片莫杈。...
    茶點故事閱讀 40,115評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖奢入,靈堂內(nèi)的尸體忽然破棺而出筝闹,到底是詐尸還是另有隱情,我是刑警寧澤腥光,帶...
    沈念sama閱讀 35,804評論 5 346
  • 正文 年R本政府宣布关顷,位于F島的核電站,受9級特大地震影響武福,放射性物質(zhì)發(fā)生泄漏议双。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,458評論 3 331
  • 文/蒙蒙 一艘儒、第九天 我趴在偏房一處隱蔽的房頂上張望聋伦。 院中可真熱鬧夫偶,春花似錦、人聲如沸觉增。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,008評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽逾礁。三九已至说铃,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間嘹履,已是汗流浹背腻扇。 一陣腳步聲響...
    開封第一講書人閱讀 33,135評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留砾嫉,地道東北人幼苛。 一個月前我還...
    沈念sama閱讀 48,365評論 3 373
  • 正文 我出身青樓,卻偏偏與公主長得像焕刮,于是被迫代替她去往敵國和親舶沿。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,055評論 2 355

推薦閱讀更多精彩內(nèi)容