腎小球的病理圖像分割

項目目標

在不同的組織制備管道中分割人類腎臟組織圖像中的腎小球區(qū)域。腎小球是一種功能組織單位(FTU):以毛細血管為中心的三維細胞塊仪芒,因此該塊中的每個細胞與同一塊中的任何其他細胞都在擴散距離之內。

項目數(shù)據(jù)

提供的數(shù)據(jù)包括11張新鮮冷凍和9張福爾馬林固定石蠟包埋(FFPE)PAS腎臟圖像:8 張用于訓練,5+7張用于測試。每個都有大約50k像素大小兰珍,并保存為高分辨率tiff圖像侍郭。為了使如此大的圖像適合神經(jīng)網(wǎng)絡的訓練,必須將它們切成小塊掠河。根據(jù)檢測到的目標大小亮元,此數(shù)據(jù)的適當圖塊大小應為 1024*1024。對此使用分辨率低4倍的256*256瓦片(tiles)唠摹,可以在最終設置上運行更高分辨率的瓦片爆捞。瓦片數(shù)(8211+1893)

數(shù)據(jù)處理辦法

重疊裁剪

Overlap-tile策略搭配patch(圖像分塊)一起使用。當內存資源有限從而無法對整張大圖進行預測時勾拉,可以對圖像先進行鏡像padding煮甥,然后按序將padding后的圖像分割成固定大小的patch盗温。這樣,能夠實現(xiàn)對任意大的圖像進行無縫分割成肘,同時每個圖像塊也獲得了相應的上下文信息卖局。另外,在數(shù)據(jù)量較少的情況下双霍,每張圖像都被分割成多個patch砚偶,相當于起到了擴充數(shù)據(jù)量的作用。更重要的是洒闸,這種策略不需要對原圖進行縮放染坯,每個位置的像素值與原圖保持一致,不會因為縮放而帶來誤差丘逸。overlap-tile策略的思想是:對圖像的某一塊像素點(黃框內部分)進行預測時单鹿,需要該圖像塊周圍的像素點(藍色框內)提供上下文信息(context),以獲得更準確的預測鸣个。

def make_grid(shape, window=256, min_overlap=32):
    """
        Return Array of size (N,4), where N - number of tiles,
        2nd axis represente slices: x1,x2,y1,y2 
    """
    x, y = shape
    nx = x // (window - min_overlap) + 1
    x1 = np.linspace(0, x, num=nx, endpoint=False, dtype=np.int64)
    x1[-1] = x - window
    x2 = (x1 + window).clip(0, x)
    ny = y // (window - min_overlap) + 1
    y1 = np.linspace(0, y, num=ny, endpoint=False, dtype=np.int64)
    y1[-1] = y - window
    y2 = (y1 + window).clip(0, y)
    slices = np.zeros((nx,ny, 4), dtype=np.int64)
    
    for i in range(nx):
        for j in range(ny):
            slices[i,j] = x1[i], x2[i], y1[j], y2[j]    
    return slices.reshape(nx*ny,4)

數(shù)據(jù)增強策略

本項目用到的操作包括模糊圖像羞反、中心模糊、高斯噪聲囤萤、色調飽和度值昼窗、對比度受限自適應直方圖均衡、隨機亮度對比度等涛舍,以及常用的翻轉澄惊、旋轉、仿射變換富雅。在訓練集上只使用旋轉掸驱、翻轉變換柏卤。

def get_aug(p=1.0):
    return Compose([
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9,
                         border_mode=cv2.BORDER_REFLECT),
        OneOf([
            ElasticTransform(p=.3),
            GaussianBlur(p=.3),
            GaussNoise(p=.3),
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            # IAAPiecewiseAffine(p=0.3),
        ], p=0.3),
        OneOf([
            HueSaturationValue(15,25,0),
            CLAHE(clip_limit=2),
            RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
        ], p=0.3),
    ], p=p)

badcase分析顯示有些更暗更小的東西與正常腎小球不相似吟税。它們在切片的邊界上分布得更密集弟跑,并且在這些結構中似乎有更少的細胞核饶唤,是纖維狀新月形腎小球蒲每。通過調整數(shù)據(jù)增強策略可以有一定的改善垒拢。

Dataset類

mean = np.array([0.7720342, 0.74582646, 0.76392896])
std = np.array([0.24745085, 0.26182273, 0.25782376])

def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    def __init__(self, path, fold=0, train=True, tfms=None, seed=2020, nfolds= 4, include_pl=False):
        self.path=path
       
        if include_pl:
            ids = np.concatenate([pd.read_csv(os.path.join(self.path,'train.csv')).id.values,
                     pd.read_csv(os.path.join(self.path,'sample_submission.csv')).id.values])
        else:
            ids = pd.read_csv(os.path.join(self.path,'train.csv')).id.values      
        kf = KFold(n_splits=nfolds,random_state=seed,shuffle=True)
        ids = set(ids[list(kf.split(ids))[fold][0 if train else 1]])
        print(f"number of {'train' if train else 'val'} images is {len(ids)}")
        
        self.fnames = ['train/'+fname for fname in os.listdir(os.path.join(self.path,'train')) if int(fname.split('_')[0]) in ids]
        # +['test/'+fname for fname in os.listdir(os.path.join(self.path,'test')) if fname.split('_')[0] in ids]

        self.train = train
        self.tfms = tfms

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.cvtColor(cv2.imread(os.path.join(self.path,fname)), cv2.COLOR_BGR2RGB)

        if self.fnames[idx][:5]=='train':
            mask = cv2.imread(os.path.join(self.path,'masks',fname[6:]),cv2.IMREAD_GRAYSCALE)
        else:
            mask = cv2.imread(os.path.join(self.path,'test_masks',fname[5:]),cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            augmented = self.tfms(image=img,mask=mask)
            img,mask = augmented['image'],augmented['mask']

        data={'img':img2tensor((img/255.0 - mean)/std), 'mask':img2tensor(mask)}
        return data

模型設計

使用的模型基于一個 U 形網(wǎng)絡(UneXt50擂送,見下圖)翘悉。 Unet 架構:編碼器部分創(chuàng)建不同級別的特征表示啤贩,而解碼器將特征組合并生成預測作為分割掩碼待秃。編碼器和解碼器之間的跳過連接允許有效地利用編碼器中間卷積層的特征,而無需信息通過整個編碼器和解碼器痹屹。后者對于將預測掩碼鏈接到檢測對象的特定像素特別重要章郁。后來人們意識到 ImageNet 預訓練的計算機視覺模型可以顯著提高分割模型的質量,因為編碼器的架構經(jīng)過優(yōu)化志衍,編碼器容量高(與原始 Unet 中使用的編碼器相比)暖庄,以及具有遷移學習的強大功能聊替。

原始Unet
本項目設計

使用半監(jiān)督 Imagenet 預訓練的 ResNeXt50 模型作為主干。 在 Pytorch 中雄驹,它提供了 EfficientNet B2-B3 的性能佃牛,在計算成本上具有更快的收斂速度,以及EfficientNet B0 的 GPU RAM 要求医舆。

對 ResNet 有效性的解釋主要有三種:

  • 使網(wǎng)絡更容易在某些層學到恒等變換(identity mapping)俘侠。在某些層執(zhí)行恒等變換是一種構造性解,使更深的模型的性能至少不低于較淺的模型蔬将。這也是作者原始論文指出的動機爷速。(ResNet解決了深網(wǎng)絡的梯度問題,自然能學習到更多抽象特征霞怀,所以效果好還是因為夠深惫东。)
    [1512.03385] Deep Residual Learning for Image Recognition
  • 殘差網(wǎng)絡是很多淺層網(wǎng)絡的集成(ensemble),層數(shù)的指數(shù)級那么多毙石。主要的實驗證據(jù)是:把 ResNet 中的某些層直接刪掉廉沮,模型的性能幾乎不下降。
    [1605.06431] Residual Networks Behave Like Ensembles of Relatively Shallow Networks
  • 殘差網(wǎng)絡使信息更容易在各層之間流動徐矩,包括在前向傳播時提供特征重用滞时,在反向傳播時緩解梯度信號消失

ResNeXt 同時采用 VGG 堆疊的思想Inception 的 split-transform-merge 思想滤灯。ResNeXt 提出的主要原因在于:傳統(tǒng)的要提高模型的準確率坪稽,都是加深或加寬網(wǎng)絡,但是隨著超參數(shù)數(shù)量的增加(比如channels數(shù)鳞骤,filter size等等)窒百,網(wǎng)絡設計的難度和計算開銷也會增加。因此ResNeXt 結構可以在不增加參數(shù)復雜度的前提下提高準確率豫尽,同時還減少了超參數(shù)的數(shù)量篙梢。
一般增強一個CNN的表達能力有三種手段:一是增加網(wǎng)絡層次即加深網(wǎng)絡二是增加網(wǎng)絡模塊寬度美旧;三是改善CNN網(wǎng)絡結構設計)渤滞。ResNeXt的做法可歸為上面三種方法的第三種。它引入了新的用于構建CNN網(wǎng)絡的模塊陈症,提出了一個cardinatity的概念,用于作為模型復雜度的另外一個度量震糖。Cardinatity指的是一個block中所具有的相同分支的數(shù)目录肯。作者進行了一系列對比實驗,有力證明在保證相似計算復雜度及模型參數(shù)大小的前提下吊说,提升cardinatity比提升height或width可取得更好的模型表達能力论咏。下面三種ResNeXt網(wǎng)絡模塊的變形优炬。它們在數(shù)學計算上是完全等價的,而第三種包含有Group convolution操作的正是最終ResNeXt網(wǎng)絡所采用的操作厅贪。

ResNeXt的分類效果為什么比Resnet好?
ResNeXt的精妙之處在于蠢护,該思路沿用到nlp里就有了multi-head attention。
第一养涮,ResNext中引入cardinality葵硕,實際上仍然還是一個Group的概念。不同的組之間實際上是不同的subspace贯吓,而他們的確能學到更diverse的表示懈凹。
第二,這種分組的操作或許能起到網(wǎng)絡正則化的作用悄谐。實際上介评,增加一個cardinality維度之后,會使得卷積核學到的關系更加稀疏爬舰。同時在整體的復雜度不變的情況下们陆,其中Network-in-Neuron的思想,會大大降低了每個sub-network的復雜度情屹,那么其過擬合的風險相比于ResNet也將會大大降低坪仇。

class UneXt(nn.Module):
    def __init__(self, m, stride=1, **kwargs):
        super().__init__()
        #encoder
        # m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models',
        #                    'resnext101_32x4d_swsl')
#         m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models',
#                            'resnext50_32x4d_swsl', pretrained=False)
        #m = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=4)
        #m = torchvision.models.resnext50_32x4d(pretrained=False)
        # m = torch.hub.load(
        #     'moskomule/senet.pytorch',
        #     'se_resnet101',
        #     pretrained=True,)

        #m=torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
        self.enc0 = nn.Sequential(m.conv1, m.bn1, nn.ReLU(inplace=True))
        self.enc1 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1),
                            m.layer1) #256
        self.enc2 = m.layer2 #512
        self.enc3 = m.layer3 #1024
        self.enc4 = m.layer4 #2048
        #aspp with customized dilatations
        self.aspp = ASPP(2048,256,out_c=512,dilations=[stride*1,stride*2,stride*3,stride*4])
        self.drop_aspp = nn.Dropout2d(0.5)
        #decoder
        self.dec4 = UnetBlock(512,1024,256)
        self.dec3 = UnetBlock(256,512,128)
        self.dec2 = UnetBlock(128,256,64)
        self.dec1 = UnetBlock(64,64,32)
        self.fpn = FPN([512,256,128,64],[16]*4)
        self.drop = nn.Dropout2d(0.1)
        self.final_conv = ConvLayer(32+16*4, 1, ks=1, norm_type=None, act_cls=None)

    def forward(self, x):
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.aspp(enc4)
        dec3 = self.dec4(self.drop_aspp(enc5),enc3)
        dec2 = self.dec3(dec3,enc2)
        dec1 = self.dec2(dec2,enc1)
        dec0 = self.dec1(dec1,enc0)
        x = self.fpn([enc5, dec3, dec2, dec1], dec0)
        x = self.final_conv(self.drop(x))
        x = F.interpolate(x,scale_factor=2,mode='bilinear')
        return x

class UnetBlock(nn.Module):
    def __init__(self, up_in_c:int, x_in_c:int, nf:int=None, blur:bool=False,
                 self_attention:bool=False, **kwargs):
        super().__init__()
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        ni = up_in_c//2 + x_in_c
        nf = nf if nf is not None else max(up_in_c//2,32)
        self.conv1 = ConvLayer(ni, nf, norm_type=None, **kwargs)
        self.conv2 = ConvLayer(nf, nf, norm_type=None,
            xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, up_in:Tensor, left_in:Tensor) -> Tensor:
        s = left_in
        up_out = self.shuf(up_in)
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))

也嘗試使用了Efficientnet作為encoder構建Unet網(wǎng)絡

pretrained_root = '/home/ruanshijian/hubmap/'
efficient_net_encoders = {
    "efficientnet-b0": {
        "out_channels": (3, 32, 24, 40, 112, 320),
        "stage_idxs": (3, 5, 9, 16),
        "weight_path": pretrained_root + "efficientnet-b0-08094119.pth"
    },
    "efficientnet-b1": {
        "out_channels": (3, 32, 24, 40, 112, 320),
        "stage_idxs": (5, 8, 16, 23),
        "weight_path": pretrained_root + "efficientnet-b1-dbc7070a.pth"
    },
    "efficientnet-b2": {
        "out_channels": (3, 32, 24, 48, 120, 352),
        "stage_idxs": (5, 8, 16, 23),
        "weight_path": pretrained_root + "efficientnet-b2-27687264.pth"
    },
    "efficientnet-b3": {
        "out_channels": (3, 40, 32, 48, 136, 384),
        "stage_idxs": (5, 8, 18, 26),
        "weight_path": pretrained_root + "efficientnet-b3-c8376fa2.pth"
    },
    "efficientnet-b4": {
        "out_channels": (3, 48, 32, 56, 160, 448),
        "stage_idxs": (6, 10, 22, 32),
        "weight_path": pretrained_root + "efficientnet-b4-e116e8b3.pth"
    },
    "efficientnet-b5": {
        "out_channels": (3, 48, 40, 64, 176, 512),
        "stage_idxs": (8, 13, 27, 39),
        "weight_path": pretrained_root + "efficientnet-b5-586e6cc6.pth"
    },
    "efficientnet-b6": {
        "out_channels": (3, 56, 40, 72, 200, 576),
        "stage_idxs": (9, 15, 31, 45),
        "weight_path": pretrained_root + "efficientnet-b6-c76e70fd.pth"
    },
    "efficientnet-b7": {
        "out_channels": (3, 64, 48, 80, 224, 640),
        "stage_idxs": (11, 18, 38, 55),
        "weight_path": pretrained_root + "efficientnet-b7-dcc49843.pth"
    }
}

import sys
sys.path.insert(0, '/home/ruanshijian/hubmap/EfficientNet-PyTorch')
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import get_model_params

class EfficientNetEncoder(EfficientNet):
    def __init__(self, stage_idxs, out_channels, model_name, depth=5):

        blocks_args, global_params = get_model_params(model_name, override_params=None)
        super().__init__(blocks_args, global_params)

        cfg = efficient_net_encoders[model_name]

        self._stage_idxs = stage_idxs
        self._out_channels = out_channels
        self._depth = depth
        self._in_channels = 3

        del self._fc
        self.load_state_dict(torch.load(cfg['weight_path']))

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self._conv_stem, self._bn0, self._swish),
            self._blocks[:self._stage_idxs[0]],
            self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
            self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
            self._blocks[self._stage_idxs[2]:],
        ]

    def forward(self, x):
        stages = self.get_stages()

        block_number = 0.
        drop_connect_rate = self._global_params.drop_connect_rate

        features = []
        for i in range(self._depth + 1):

            # Identity and Sequential stages
            if i < 2:
                x = stages[i](x)

            # Block stages need drop_connect rate
            else:
                for module in stages[i]:
                    drop_connect = drop_connect_rate * block_number / len(self._blocks)
                    block_number += 1.
                    x = module(x, drop_connect)

            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        state_dict.pop("_fc.bias")
        state_dict.pop("_fc.weight")
        super().load_state_dict(state_dict, **kwargs)


class EffUnet(nn.Module):
    def __init__(self, model_name, stride=1):
        super().__init__()

        cfg = efficient_net_encoders[model_name]
        stage_idxs = cfg['stage_idxs']
        out_channels = cfg['out_channels']

        self.encoder = EfficientNetEncoder(stage_idxs, out_channels, model_name)

        # aspp with customized dilatations
        self.aspp = ASPP(out_channels[-1], 256, out_c=384,
                         dilations=[stride * 1, stride * 2, stride * 3, stride * 4])
        self.drop_aspp = nn.Dropout2d(0.5)
        # decoder
        self.dec4 = UnetBlock(384, out_channels[-2], 256)
        self.dec3 = UnetBlock(256, out_channels[-3], 128)
        self.dec2 = UnetBlock(128, out_channels[-4], 64)
        self.dec1 = UnetBlock(64, out_channels[-5], 32)
        self.fpn = FPN([384, 256, 128, 64], [16] * 4)
        self.drop = nn.Dropout2d(0.1)
        self.final_conv = ConvLayer(32 + 16 * 4, 1, ks=1, norm_type=None, act_cls=None)

    def forward(self, x):
        enc0, enc1, enc2, enc3, enc4 = self.encoder(x)[-5:]
        enc5 = self.aspp(enc4)
        dec3 = self.dec4(self.drop_aspp(enc5), enc3)
        dec2 = self.dec3(dec3, enc2)
        dec1 = self.dec2(dec2, enc1)
        dec0 = self.dec1(dec1, enc0)
        x = self.fpn([enc5, dec3, dec2, dec1], dec0)
        x = self.final_conv(self.drop(x))
        x = F.interpolate(x, scale_factor=2, mode='bilinear')
        return x

PixelShuffle是一種上采樣方法,可以對縮小后的特征圖進行有效的放大屁商⊙毯埽可以替代插值或解卷積的方法實現(xiàn)upscale。pixelshuffle算法的實現(xiàn)流程如圖蜡镶,其實現(xiàn)的功能是:將一個H × W的低分辨率輸入圖像(Low Resolution)雾袱,通過Sub-pixel操作將其變?yōu)閞H*rW的高分辨率圖像(High Resolution)。但是其實現(xiàn)過程不是直接通過插值等方式產(chǎn)生這個高分辨率圖像官还,而是通過卷積先得到r^2個通道的特征圖(特征圖大小和輸入低分辨率圖像一致)芹橡,然后通過周期篩選(periodic shuffing)的方法得到這個高分辨率的圖像,其中r為上采樣因子(upscaling factor)望伦,也就是圖像的擴大倍率林说。

簡單一句話,PixelShuffle層做的事情就是將輸入feature map像素重組輸出高分辨率的feature map屯伞,是一種上采樣方法腿箩,具體表達為:N*(C*r*r)*W*H---->>N*C*(H*r)*(W*r)

  1. upsample是利用傳統(tǒng)插值方法進行上采樣。往往會在upsample后接一個conv劣摇,進行學習珠移。任務:超分,目標檢測。
  2. 轉置卷積應該是上采樣力度最大的钧惧,所以有些時候的結果看起來會不太真實暇韧。任務:GAN,分割浓瞪,超分懈玻。
  3. pixel shuffle最開始也是用在超分上的,把channel通道放大r^2倍乾颁,然后再分給H涂乌,W成rH,rW钮孵,達到上采樣的效果骂倘。目前超分用這個應該是主流。任務:超分巴席。

此外历涝,在ASPP模塊中還加入了OC注意力模塊

class BaseOC_Module(nn.Module):
    """
    Implementation of the BaseOC module
    Parameters:
        in_features / out_features: the channels of the input / output feature maps.
        dropout: we choose 0.05 as the default value.
        size: you can apply multiple sizes. Here we only use one size.
    Return:
        features fused with Object context information.
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Module, self).__init__()
        self.stages = []
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout2d(dropout)
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
        return output

class BaseOC_Context_Module(nn.Module):
    """
    Output only the context features.
    Parameters:
        in_features / out_features: the channels of the input / output feature maps.
        dropout: specify the dropout ratio
        fusion: We provide two different fusion method, "concat" or "add"
        size: we find that directly learn the attention weights on even 1/8 feature maps is hard.
    Return:
        features after "concat" or "add"
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Context_Module, self).__init__()
        self.stages = []
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(context)
        return output

添加了特征金字塔網(wǎng)絡(FPN):解碼器的不同上采樣塊和輸出層之間的附加跳過連接。因此漾唉,最終預測是基于 U-net 輸出與中間層調整大小的輸出串聯(lián)接產(chǎn)生的荧库。這些跳躍連接為梯度傳導提供了捷徑以提高模型性能和收斂速度。由于中間層有許多通道赵刑,它們的上采樣和用作最后一層的輸入會在計算時間和內存方面引入大量開銷分衫。因此,在調整大小之前應用 3*3+3*3 卷積(分解)以減少通道數(shù)般此。淺層的網(wǎng)絡更關注于細節(jié)信息蚪战,高層的網(wǎng)絡更關注于語義信息,而高層的語義信息能夠幫助我們準確的檢測出目標铐懊,設計思想就是同時利用低層特征和高層特征邀桑,分別在不同的層同時進行預測,這是因為一幅圖像中可能具有多個不同大小的目標科乎,區(qū)分不同的目標可能需要不同的特征壁畸,對于簡單的目標僅僅需要淺層的特征就可以檢測到它,對于復雜的目標就需要利用復雜的特征來檢測它茅茂。整個過程就是首先在原始圖像上面進行深度卷積捏萍,然后分別在不同的特征層上面進行預測。它的優(yōu)點是在不同的層上面輸出對應的目標空闲,不需要經(jīng)過所有的層才輸出對應的目標(即對于有些目標來說令杈,不需要進行多余的前向操作),這樣可以在一定程度上對網(wǎng)絡進行加速操作碴倾,同時可以提高算法的檢測性能逗噩。它的缺點是獲得的特征不魯棒悔常,都是一些弱特征(因為很多的特征都是從較淺的層獲得的)。

class FPN(nn.Module):
    def __init__(self, input_channels:list, output_channels:list):
        super().__init__()
        self.convs = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(in_ch, out_ch*2, kernel_size=3, padding=1),
             nn.ReLU(inplace=True), nn.BatchNorm2d(out_ch*2),
             nn.Conv2d(out_ch*2, out_ch, kernel_size=3, padding=1))
            for in_ch, out_ch in zip(input_channels, output_channels)])

    def forward(self, xs:list, last_layer):
        hcs = [F.interpolate(c(x),scale_factor=2**(len(self.convs)-i),mode='bilinear')
               for i,(c,x) in enumerate(zip(self.convs, xs))]
        hcs.append(last_layer)
        return torch.cat(hcs, dim=1)

在編碼器和解碼器之間添加的 Atrous Spatial Pyramid Pooling (ASPP) 塊给赞。傳統(tǒng) U 形網(wǎng)絡的缺陷是由一個小的感受野造成的。因此矫户,如果模型需要對大對象的分割做出決定片迅,特別是對于大圖像分辨率,它可能會因為只能查看對象的一部分而感到困惑皆辽。增加感受野并實現(xiàn)圖像不同部分之間交互的一種方法是使用具有不同擴張的卷積塊組合(在 ASPP 塊中具有不同速率的 Atrous 卷積)柑蛇。雖然原始論文使用 6、12驱闷、18 速率耻台,但它們可以針對特定任務和特定圖像分辨率進行定制,以最大限度地提高性能空另。另外在 ASPP 塊中使用分組卷積來減少模型參數(shù)的數(shù)量盆耽。

class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, groups=1):
        super().__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                stride=1, padding=padding, dilation=dilation, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    def __init__(self, inplanes=512, mid_c=256, dilations=[6, 12, 18, 24], out_c=None):
        super().__init__()
        self.aspps = [_ASPPModule(inplanes, mid_c, 1, padding=0, dilation=1)] + \
            [_ASPPModule(inplanes, mid_c, 3, padding=d, dilation=d,groups=4) for d in dilations]
        self.aspps = nn.ModuleList(self.aspps)
        self.global_pool = nn.Sequential(nn.AdaptiveMaxPool2d((1, 1)),
                        nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False),
                        nn.BatchNorm2d(mid_c), nn.ReLU())
        out_c = out_c if out_c is not None else mid_c
        self.out_conv = nn.Sequential(nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False),
                                    nn.BatchNorm2d(out_c), nn.ReLU(inplace=True))
        self.conv1 = nn.Conv2d(mid_c*(2+len(dilations)), out_c, 1, bias=False)
        self._init_weight()

    def forward(self, x):
        x0 = self.global_pool(x)
        xs = [aspp(x) for aspp in self.aspps]
        x0 = F.interpolate(x0, size=xs[0].size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x0] + xs, dim=1)
        return self.out_conv(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

添加卷積注意力模塊(CBAM),這是一種用于前饋卷積神經(jīng)網(wǎng)絡的簡單而有效的注意力模塊扼菠。 給定一個中間特征圖摄杂,CBAM模塊會沿著兩個獨立的維度(通道和空間)依次推斷注意力圖,然后將注意力圖與輸入特征圖相乘以進行自適應特征優(yōu)化循榆。 由于CBAM是輕量級的通用模塊析恢,因此可以忽略的該模塊的開銷而將其無縫集成到任何CNN架構中,并且可以與基礎CNN一起進行端到端訓練秧饮。

注意力不僅要告訴我們重點關注哪里映挂,還要提高關注點的表示。 目標是通過使用注意機制來增加表現(xiàn)力盗尸,關注重要特征并抑制不必要的特征柑船。為了強調空間和通道這兩個維度上的有意義特征,依次應用通道和空間注意模塊振劳,來分別在通道和空間維度上學習關注什么椎组、在哪里關注。此外历恐,通過了解要強調或抑制的信息也有助于網(wǎng)絡內的信息流動寸癌。

CBAM 包含2個獨立的子模塊, 通道注意力模塊(Channel Attention Module弱贼,CAM) 和空間注意力模塊(Spartial Attention Module蒸苇,SAM) ,分別進行通道與空間上的 Attention 吮旅。

通道注意力模塊:通道維度不變溪烤,壓縮空間維度味咳。該模塊關注輸入圖片中有意義的信息(分類任務就關注因為什么分成了不同類別)。
圖解:將輸入的feature map經(jīng)過兩個并行的MaxPool層和AvgPool層檬嘀,將特征圖從C*H*W變?yōu)镃*1*1的大小槽驶,然后經(jīng)過Share MLP模塊,在該模塊中鸳兽,它先將通道數(shù)壓縮為原來的1/r(Reduction掂铐,減少率)倍,再擴張到原通道數(shù)揍异,經(jīng)過ReLU激活函數(shù)得到兩個激活后的結果全陨。將這兩個輸出結果進行逐元素相加,再通過一個sigmoid激活函數(shù)得到Channel Attention的輸出結果衷掷,再將這個輸出結果乘原圖辱姨,變回C*H*W的大小。

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, rotio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.sharedMLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
            nn.Conv2d(in_planes // rotio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.sharedMLP(self.avg_pool(x))
        maxout = self.sharedMLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

空間注意力模塊:空間維度不變戚嗅,壓縮通道維度雨涛。該模塊關注的是目標的位置信息。
圖解:將Channel Attention的輸出結果通過最大池化和平均池化得到兩個1*H*W的特征圖懦胞,然后經(jīng)過Concat操作對兩個特征圖進行拼接镜悉,通過7*7卷積變?yōu)?通道的特征圖(實驗證明7*7效果比3*3好),再經(jīng)過一個sigmoid得到Spatial Attention的特征圖医瘫,最后將輸出結果乘原圖變回C*H*W大小侣肄。

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3,7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2,1,kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avgout, maxout], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)
class CBAM(nn.Module):
    def __init__(self, planes):
        super(cbam,self).__init__()
        self.ca = ChannelAttention(planes)# planes是feature map的通道個數(shù)
        self.sa = SpatialAttention()
     def forward(self, x):
        x = self.ca(x) * x  # 廣播機制
        x = self.sa(x) * x  # 廣播機制

損失和度量

在圖像分割任務中,經(jīng)常出現(xiàn)類別分布不均勻的情況醇份,例如:工業(yè)產(chǎn)品的瑕疵檢測稼锅、道路提取及病變區(qū)域提取等。我們可使用lovasz loss解決這個問題僚纷。
Lovasz loss基于子模損失(submodular losses)的凸Lovasz擴展矩距,對神經(jīng)網(wǎng)絡的mean IoU損失進行優(yōu)化。Lovasz loss根據(jù)分割目標的類別數(shù)量可分為兩種:lovasz hinge loss和lovasz softmax loss. 其中l(wèi)ovasz hinge loss適用于二分類問題怖竭,lovasz softmax loss適用于多分類問題锥债。
Jaccard index :


優(yōu)化的IOU loss:
其定義是離散的loss,不能直接求導痊臭,所以無法直接用來作為loss function哮肚。為了克服這個離散的問題,本文將其做了光滑的延拓(smooth extensions)广匙,從而可以使得其作為分割網(wǎng)絡的loss function允趟。變形為:
目前想要優(yōu)化的loss function,其自變量為網(wǎng)絡分割結果和label不匹配的集合鸦致。將其做光滑的延拓不是一件簡單的事情潮剪,更一般的說涣楷,對任意的離散函數(shù)找到其光滑的延拓很難。好在變化后的公式是submodular的抗碰,submodular的函數(shù)已經(jīng)有成熟數(shù)學工具可以將其做光滑延拓狮斗,而且延拓后的函數(shù)總是凸的,這樣就大大方便了優(yōu)化弧蝇。該數(shù)學工具即為lovasz extension
即轉成具有凸解形式:

代碼實現(xiàn)

  • 為什么用這么復雜情龄,看起來也不簡單的數(shù)學工具來對Jaccard loss進行smooth extension,直接像Dice loss那樣計算Jaccard loss不行嗎捍壤?

基于該想法的工作已經(jīng)在16年發(fā)表了出來Optimizing Intersection-Over-Union in Deep Neural Networks for Image Segmentation,雖然本文沒有與其進行比較鞍爱,但作者在github中說本文對Jaccard loss光滑延拓得到的loss要比Dice loss那樣簡單的光滑化(連續(xù)畫處理)效果好鹃觉。

  • Dice loss與IOU loss哪個用于網(wǎng)絡模型的訓練比較好?

都不太好睹逃。兩者都存在訓練過程不穩(wěn)定的問題盗扇,在和很小的情況下會得到較大的梯度,會影響正常的反向傳播沉填。一般情況下疗隶,使用兩者對應的損失函數(shù)的原因是分割的真實目的是最大化這兩個度量指標,而交叉熵是一種代理形式翼闹,利用了其在反向傳播中易于最大化優(yōu)化的特點斑鼻。
所以,正常情況下是使用交叉熵損失函數(shù)來訓練網(wǎng)絡模型猎荠,用Dice或IOU系數(shù)來衡量模型的性能坚弱。因為,交叉熵損失函數(shù)得到的交叉熵值關于logits的梯度計算形式類似:p-g(p是softmax的輸出結果关摇,g是ground truth)荒叶,這樣的關系式自然在求梯度的時候容易的多。而Dice系數(shù)的可微形式输虱,loss值為2pg/(p^2 + g^2)或2pg/(p+g),其關于p的梯度形式顯然是比較復雜的些楣,且在極端情況下(p,g的值都非常小時)計算得到的梯度值可能會非常大宪睹,進而會導致訓練不穩(wěn)定愁茁。

在本項目中采用了對稱的lovasz損失,不僅考慮預測的分割和提供的掩碼亭病,還要考慮逆預測和逆掩碼(否定情況的預測掩膜)埋市。

def symmetric_lovasz(outputs, targets):
    return 0.5*(lovasz_hinge(outputs, targets) + lovasz_hinge(-outputs, 1.0 - targets))

lovasz對分割的效果出類拔萃相比bce或者dice等loss可以提升一個檔次,但是有時的效果一般命贴,猜測是優(yōu)化不同的metric道宅,不同loss帶來的效果不同食听,也可能是數(shù)據(jù)帶來的問題。

模型推理

def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    def __init__(self, data):
        self.data = data
        if self.data.count != 3:
            subdatasets = self.data.subdatasets
            self.layers = []
            if len(subdatasets) > 0:
                for i, subdataset in enumerate(subdatasets, 0):
                    self.layers.append(rasterio.open(subdataset))
        self.shape = self.data.shape
        self.mask_grid = make_grid(self.data.shape, window=WINDOW, min_overlap=MIN_OVERLAP)

        
    def __len__(self):
        return len(self.mask_grid)
        
    def __getitem__(self, idx):
        x1, x2, y1, y2 = self.mask_grid[idx]
        if self.data.count == 3:
            img = data.read([1,2,3], window=Window.from_slices((x1, x2), (y1, y2)))
            img = np.moveaxis(img, 0, -1)
        else:
            img = np.zeros((WINDOW, WINDOW, 3), dtype=np.uint8)
            for i, layer in enumerate(self.layers):
                img[:,:,i] = layer.read(window=Window.from_slices((x1, x2),(y1, y2)))

        img = cv2.resize(img, (NEW_SIZE, NEW_SIZE),interpolation = cv2.INTER_AREA)
        vetices = torch.tensor([x1, x2, y1, y2])
        return img2tensor((img/255.0 - mean)/std), vetices
def Make_prediction(img, tta = True):
    pred = None
    with torch.no_grad():
        for model in models:
            p_tta = None
            p = model(img)
            p = torch.sigmoid(p).detach()
            if p_tta is None:
                p_tta = p
            else:
                p_tta += p
            if tta:
                #x,y,xy flips as TTA
                flips = [[-1],[-2],[-2,-1]]
                for f in flips:
                    imgf = torch.flip(img, f)
                    p = model(imgf)
                    p = torch.flip(p, f)
                    p_tta += torch.sigmoid(p).detach()
                p_tta /= (1+len(flips))
            if pred is None:
                pred = p_tta
            else:
                pred += p_tta
        pred /= len(models)
    return pred
WINDOW=1024
MIN_OVERLAP=300
NEW_SIZE=256
NUM_CLASSES=1
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
names, predictions = [],[]

df_sample = pd.read_csv("../input/hubmap-kidney-segmentation/sample.csv")
# df_sample = df_sample.replace(np.nan, '', regex=True)
th = 0.4   
for idx, row in tqdm(df_sample.iterrows(),total=len(df_sample)):
    imageId = row['id']
    data = rasterio.open(os.path.join(DATA_PATH, imageId+'.tiff'), transform = identity, num_threads='all_cpus')
    preds = np.zeros(data.shape, dtype=np.uint8)
    dataset = HuBMAPDataset(data)
    dataloader = DataLoader(dataset, batch_size, num_workers=0, shuffle=False, pin_memory=True)
    for i, (img, vertices) in enumerate(dataloader):
        img = img.to(DEVICE)
        pred = Make_prediction(img)
        pred = pred.squeeze().cpu().numpy()
        vertices = vertices.numpy()
        for p, vert in zip(pred, vertices):
            x1, x2, y1, y2 = vert
            p = cv2.resize(p, (WINDOW, WINDOW))
            preds[x1:x2,y1:y2] +=  (p > th).astype(np.uint8)
    preds = (preds > th).astype(np.uint8)
    #convert to rle
    rle = rle_encode_less_memory(preds)
    names.append(imageId)
    predictions.append(rle)
    del preds, dataset, dataloader
    gc.collect()
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
  • 序言:七十年代末污茵,一起剝皮案震驚了整個濱河市樱报,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌泞当,老刑警劉巖迹蛤,帶你破解...
    沈念sama閱讀 219,270評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異襟士,居然都是意外死亡盗飒,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,489評論 3 395
  • 文/潘曉璐 我一進店門陋桂,熙熙樓的掌柜王于貴愁眉苦臉地迎上來逆趣,“玉大人,你說我怎么就攤上這事嗜历⌒” “怎么了?”我有些...
    開封第一講書人閱讀 165,630評論 0 356
  • 文/不壞的土叔 我叫張陵梨州,是天一觀的道長痕囱。 經(jīng)常有香客問我,道長暴匠,這世上最難降的妖魔是什么鞍恢? 我笑而不...
    開封第一講書人閱讀 58,906評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮每窖,結果婚禮上有序,老公的妹妹穿的比我還像新娘。我一直安慰自己岛请,他們只是感情好旭寿,可當我...
    茶點故事閱讀 67,928評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著崇败,像睡著了一般盅称。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上后室,一...
    開封第一講書人閱讀 51,718評論 1 305
  • 那天缩膝,我揣著相機與錄音,去河邊找鬼岸霹。 笑死疾层,一個胖子當著我的面吹牛,可吹牛的內容都是我干的贡避。 我是一名探鬼主播痛黎,決...
    沈念sama閱讀 40,442評論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼予弧,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了湖饱?” 一聲冷哼從身側響起掖蛤,我...
    開封第一講書人閱讀 39,345評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎井厌,沒想到半個月后蚓庭,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,802評論 1 317
  • 正文 獨居荒郊野嶺守林人離奇死亡仅仆,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 37,984評論 3 337
  • 正文 我和宋清朗相戀三年器赞,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片墓拜。...
    茶點故事閱讀 40,117評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡港柜,死狀恐怖,靈堂內的尸體忽然破棺而出撮弧,到底是詐尸還是另有隱情,我是刑警寧澤姚糊,帶...
    沈念sama閱讀 35,810評論 5 346
  • 正文 年R本政府宣布贿衍,位于F島的核電站,受9級特大地震影響救恨,放射性物質發(fā)生泄漏贸辈。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,462評論 3 331
  • 文/蒙蒙 一肠槽、第九天 我趴在偏房一處隱蔽的房頂上張望擎淤。 院中可真熱鬧,春花似錦秸仙、人聲如沸嘴拢。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,011評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽席吴。三九已至,卻和暖如春捞蛋,著一層夾襖步出監(jiān)牢的瞬間孝冒,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,139評論 1 272
  • 我被黑心中介騙來泰國打工拟杉, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留庄涡,地道東北人。 一個月前我還...
    沈念sama閱讀 48,377評論 3 373
  • 正文 我出身青樓搬设,卻偏偏與公主長得像穴店,于是被迫代替她去往敵國和親撕捍。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 45,060評論 2 355

推薦閱讀更多精彩內容