


一般來說,與從頭開始訓練的基線模型相比(如果不使用預先訓練過的權重),微調後的模型可以獲得更好的整體性能。這一觀察結果不僅驗證了預訓練權重的有效性,而且驗證了不同的預訓練權重的影響是可變的。




一些特定類別的醫學圖像(MRI和病理圖像)在空間尺寸上往往非常大,在數量上又缺乏足夠的訓練樣本。因此,直接使用這些圖像來訓練模型是不切實際的。相反,人們通常在更小的空間尺度上將整個圖像重新採樣成不同的圖像補丁,這樣模型就可以用更少的GPU內存成本實現,並且可以得到更好的訓練。直觀地看,補丁大小是影響模型性能的最重要因素之一。一般而言,隨著補丁尺寸的增加,模型的性能增益逐漸增加。其可以通過RandomCrop實現,集成到transform裡面。

class RandomCrop(object):
    """Randomly crop a 3D patch from the image and label of a sample.

    Args:
        output_size: Sequence of three ints giving the patch size along the
            three array axes of ``sample['image']`` / ``sample['label']``.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        """Return ``{'image': patch, 'label': patch}`` cropped at one random offset.

        Both arrays are indexed with three axes and are cropped at the same
        location. Axes that are not strictly larger than the patch are
        zero-padded first.
        """
        image, label = sample['image'], sample['label']
        out_w, out_h, out_d = self.output_size[0], self.output_size[1], self.output_size[2]

        # Pad the sample if necessary: every axis must end up strictly larger
        # than the patch, otherwise np.random.randint(0, 0) below would raise.
        if label.shape[0] <= out_w or label.shape[1] <= out_h or label.shape[2] <= out_d:
            pw = max((out_w - label.shape[0]) // 2 + 3, 0)
            ph = max((out_h - label.shape[1]) // 2 + 3, 0)
            pd = max((out_d - label.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape
        w1 = np.random.randint(0, w - out_w)
        h1 = np.random.randint(0, h - out_h)
        d1 = np.random.randint(0, d - out_d)

        label = label[w1:w1 + out_w, h1:h1 + out_h, d1:d1 + out_d]
        image = image[w1:w1 + out_w, h1:h1 + out_h, d1:d1 + out_d]
        return {'image': image, 'label': label}


  def test_single_case(self, image):
        """Segment one 3D volume by sliding-window inference with test-time augmentation.

        The volume is zero-padded up to the patch size where needed, tiled with
        half-patch strides, and every patch is predicted for each TTA view. The
        softmax scores are averaged over views and over overlapping patches.

        Args:
            image: 3D numpy array indexed as (w, h, d).

        Returns:
            Tuple ``(label_map, score_map)``: the per-voxel argmax with the shape
            of ``image``, and the averaged class scores with shape
            ``(num_class, w, h, d)``.
        """
        w, h, d = image.shape
        tta = TTA(if_flip=self.opt.test['flip'], if_rot=self.opt.test['rotate'])
        patch_size = self.opt.model['input_size']
        stride_xy = patch_size[0]//2
        stride_z = patch_size[2]//2

        # If the image is smaller than the patch along an axis, pad that axis.
        w_pad = max(patch_size[0] - w, 0)
        h_pad = max(patch_size[1] - h, 0)
        d_pad = max(patch_size[2] - d, 0)
        add_pad = bool(w_pad or h_pad or d_pad)
        # Split the padding as evenly as possible between both sides.
        wl_pad, wr_pad = w_pad//2, w_pad - w_pad//2
        hl_pad, hr_pad = h_pad//2, h_pad - h_pad//2
        dl_pad, dr_pad = d_pad//2, d_pad - d_pad//2
        if add_pad:
            image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                           mode='constant', constant_values=0)
        ww, hh, dd = image.shape

        # Number of window positions per axis; the last window is clamped to the border.
        sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
        sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
        sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
        score_map = np.zeros((self.opt.model['num_class'], ) + image.shape).astype(np.float32)
        cnt = np.zeros(image.shape).astype(np.float32)

        for x in range(0, sx):
            xs = min(stride_xy * x, ww - patch_size[0])
            for y in range(0, sy):
                ys = min(stride_xy * y, hh - patch_size[1])
                for z in range(0, sz):
                    zs = min(stride_z * z, dd - patch_size[2])
                    test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
                    # Apply TTA: one prediction per augmented view of the patch.
                    test_patch_list = tta.img_list(test_patch)
                    y_list = []
                    for img in test_patch_list:
                        img = np.expand_dims(np.expand_dims(img, axis=0), axis=0).astype(np.float32)
                        img = torch.from_numpy(img).cuda()
                        with torch.no_grad():
                            if not self.opt.train['deeps']:
                                pred = self.net(img)
                            else:
                                # With deep supervision the net returns several
                                # outputs; the first one is used here.
                                pred = self.net(img)[0]
                            pred = F.softmax(pred, dim=1)
                        pred = pred.cpu().detach().numpy()
                        y_list.append(np.squeeze(pred))
                    # Undo the augmentations so all predictions are aligned, then average.
                    y_list = tta.img_list_inverse(y_list)
                    patch_score = np.mean(y_list, axis=0)
                    score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                    = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + patch_score
                    cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                    = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
        # Average the overlapping windows.
        score_map = score_map/np.expand_dims(cnt, axis=0)
        label_map = np.argmax(score_map, axis=0)
        if add_pad:
            # Crop the padding away again so outputs match the input size.
            label_map = label_map[wl_pad:wl_pad+w, hl_pad:hl_pad+h, dl_pad:dl_pad+d]
            score_map = score_map[:, wl_pad:wl_pad+w, hl_pad:hl_pad+h, dl_pad:dl_pad+d]
        return label_map, score_map



class SelectedCrop(object):
    """Random crop that oversamples patches containing foreground.

    With probability ``oversample_foreground_percent`` the patch is placed
    around a randomly chosen non-zero label voxel; otherwise, or when the label
    has no foreground at all, a plain ``RandomCrop`` is used.

    Args:
        output_size: Sequence of three ints, the patch size per axis.
        oversample_foreground_percent: Probability of a foreground-centred crop.
    """

    def __init__(self, output_size, oversample_foreground_percent=0.3):
        self.output_size = output_size
        self.percent = oversample_foreground_percent

    def __call__(self, sample):
        """Return ``{'image': patch, 'label': patch}`` of size ``output_size``."""
        image, label = sample['image'], sample['label']
        if np.random.random() >= self.percent:
            return RandomCrop(self.output_size)(sample)

        pixels = np.argwhere(label != 0)
        if len(pixels) == 0:
            # No foreground to centre on.
            return RandomCrop(self.output_size)(sample)

        selected_pixel = pixels[np.random.choice(len(pixels))]
        pw = self.output_size[0] // 2 + 1
        ph = self.output_size[1] // 2 + 1
        pd = self.output_size[2] // 2 + 1

        # Pad by about half a patch per side. The selected voxel's unpadded
        # coordinate then serves as the lower corner of a window in the padded
        # array, which places the voxel near the patch centre.
        image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
        label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
        bbox_x_lb = selected_pixel[0]
        bbox_y_lb = selected_pixel[1]
        bbox_z_lb = selected_pixel[2]

        label = label[bbox_x_lb:bbox_x_lb + self.output_size[0],
                      bbox_y_lb:bbox_y_lb + self.output_size[1],
                      bbox_z_lb:bbox_z_lb + self.output_size[2]]
        image = image[bbox_x_lb:bbox_x_lb + self.output_size[0],
                      bbox_y_lb:bbox_y_lb + self.output_size[1],
                      bbox_z_lb:bbox_z_lb + self.output_size[2]]
        return {'image': image, 'label': label}


ReSam策略,通過機器學習模型來提高所使用數據集的表徵能力。由於可用的樣本能力有時是有限的和異質的,因此可以通過隨機/非隨機ReSam策略獲得更好的子樣本數據集。在其實現中,ReSam主要包括四個步驟:1)間隔插值;2)窗口變換;3)掩模有效範圍的獲取;4)子圖像的生成。基於重組後的子樣本數據集,可以訓練一個性能更好的識別模型。

在醫學圖像中,重採樣是指將醫療圖像中大小不同的體素歸一化到相同的大小。體素是體積元素(Volume Pixel)的簡稱,一張3D醫學圖像可以看成是由若干個體素構成的,體素是一張3D醫療圖像在空間上的最小單元。
重採樣過程:Spacing(0.7422, 0.7422, 8.0)表示的是原始圖像體素的大小,也可以將Spacing想像成大小為(0.7422, 0.7422, 8.0)的長方體。而原始圖像的Size為(512, 512, 22),表示的是原始圖像在X軸、Y軸、Z軸中體素的個數。原始圖像的Size乘以對應的Spacing即可得到真實3D圖像大小(512*0.7422,512*0.7422,22*8.0)。圖像重採樣只是修改體素的大小,而真實3D圖像大小是保持不變的,因此假設我們將Spacing修改成(1.4844, 1.4844, 2.75),則修改之後其對應的Size應該為((512*0.7422)/1.4844,(512*0.7422)/1.4844,(22*8)/2.75),即(256, 256, 64)。

def resample(self, image, label, spacing, new_spacing=[1,1,1]):
    """Resize an image/label pair so that its voxel spacing becomes ``new_spacing``.

    The target shape is ``round(shape * spacing / new_spacing)``. Axis 0 of both
    inputs is moved to the last position before resizing, so the returned
    arrays have their axes in that moved order.

    Args:
        image: 3D array (presumably ordered (z, x, y) -- TODO confirm with caller).
        label: Array with the same layout as ``image``.
        spacing: Current spacing, in the same axis order as ``image``.
        new_spacing: Desired spacing, in the same axis order as ``image``.

    Returns:
        Tuple ``(out_img, out_seg)`` produced by ``augment_resize``.
    """
    src_spacing = np.array(spacing)
    dst_spacing = np.array(new_spacing)
    shape_before = np.array(image.shape)
    shape_after = np.round(shape_before * (src_spacing / dst_spacing))
    # Spacing actually reached after rounding the shape. Computed as in the
    # original code, but not used further.
    effective_factor = shape_after / shape_before
    dst_spacing = src_spacing / effective_factor

    # Move axis 0 to the end, then add a leading singleton axis for
    # augment_resize.
    image = np.expand_dims(np.moveaxis(image, 0, -1), axis=0)
    label = np.expand_dims(np.moveaxis(label, 0, -1), axis=0)

    # Reorder the target shape the same way the data axes were reordered.
    target_size = tuple(shape_after[[1, 2, 0]].astype(int))
    out_img, out_seg = augment_resize(sample_data=image, sample_seg=label, target_size=target_size)
    return out_img[0], out_seg[0]



  • CT:通過統計整個數據集中mask內像素的HU值範圍,clip出[0.5,99.5]百分位範圍的HU值,然後使用z-score方法進行歸一化;
  • MR:對每個患者數據單獨執行z-score歸一化。如果crop導致數據集的平均尺寸減小到1/4甚至更小,則只在mask內執行標準化,mask外的值設置為0。
def _get_voxels_in_foreground(self,voxels,label):
    mask = label> 0
    # image = list(voxels[mask][::10]) # no need to take every voxel
    image = list(voxels[mask])
    median = np.median(image)
    mean = np.mean(image)
    sd = np.std(image)
    percentile_99_5 = np.percentile(image, 99.5)
    percentile_00_5 = np.percentile(image, 00.5)
    return percentile_99_5,percentile_00_5, median,mean,sd

def do_preprocessing(self, minimun=0, maxmun=0, new_spacing=(3.22, 1.62, 1.62)):
    """Clip, rescale to [0, 1] and resample every case, then save the results.

    For each patient listed in ``dataset_pro.pkl`` the image and label arrays
    are loaded from ``<out_base_raw>/imagesTr``. Intensities are clipped either
    to ``[minimun, maxmun]`` (when ``minimun`` is truthy) or to the 0.5 / 99.5
    percentiles of that case's foreground voxels, then min-max scaled. The pair
    is resampled to ``new_spacing`` and written to ``out_base_preprocess``.

    Args:
        minimun: Lower clipping bound; a falsy value selects the percentile bounds.
        maxmun: Upper clipping bound, used only together with ``minimun``.
        new_spacing: Target spacing passed to ``self.resample``.
    """
    with open(join(self.out_base_raw, 'dataset_pro.pkl'), 'rb') as f:
        self.data_info = pickle.load(f)
    patient_names = self.data_info['patient_names']
    for i, name in enumerate(patient_names):
        print(f"Preprocessing {i}/{len(patient_names)}")
        voxels = np.load(join(self.out_base_raw, "imagesTr", name + "_image.npy"))
        label = np.load(join(self.out_base_raw, "imagesTr", name + "_label.npy"))
        if minimun:
            lower_bound = minimun
            upper_bound = maxmun
        else:
            upper_bound, lower_bound, _, _, _ = self._get_voxels_in_foreground(voxels, label)
        voxels = np.clip(voxels, lower_bound, upper_bound)
        # Convert to [0, 1]. A constant volume would divide by zero, so it is
        # mapped to all zeros instead.
        value_range = voxels.max() - voxels.min()
        if value_range > 0:
            voxels = (voxels - voxels.min()) / value_range
        else:
            voxels = np.zeros_like(voxels, dtype=np.float64)

        # Reorder the stored spacing so its last entry comes first, which is the
        # order passed to resample together with new_spacing.
        spacing = self.data_info['dataset_properties'][name]['spacing']
        spacing = (spacing[2], spacing[0], spacing[1])

        voxels, label = self.resample(voxels, label, spacing, new_spacing)
        np.save(join(self.out_base_preprocess, name + "_image.npy"), voxels)
        np.save(join(self.out_base_preprocess, name + "_label.npy"), label)
    save_pickle(self.data_info, join(self.out_base_preprocess, 'dataset_pro.pkl'))
    with open(self.out_base_preprocess + '/all.txt', 'w') as f:
        for train_patient in patient_names:
            # NOTE(review): the loop body was missing in the source; writing one
            # patient name per line is assumed -- confirm against the readers
            # of all.txt.
            f.write(train_patient + '\n')




包含兩種類型的數據增強,分別為GTAug-A(像素級變換)和GTAug-B(空間級變換):GTAug-A中包括隨機亮度對比、隨機噪聲、隨機伽馬和CLAHE,GTAug-B中包括位移、尺度、旋轉、水平翻轉和垂直翻轉等。

class RandomScale(object):
    """Resize image and label by one random isotropic factor.

    Args:
        scale_factor: Pair ``(low, high)``; the factor is drawn uniformly from it.
    """

    def __init__(self, scale_factor):
        self.scale_factor = scale_factor

    def __call__(self, sample):
        """Return the sample resized to ``int(shape * factor)`` along its first three axes."""
        image, label = sample['image'], sample['label']
        factor = np.random.uniform(self.scale_factor[0], self.scale_factor[1])
        target_shape = (
            int(image.shape[0] * factor),
            int(image.shape[1] * factor),
            int(image.shape[2] * factor),
        )
        # augment_resize is called with a leading singleton axis on both arrays.
        resized_image, resized_label = augment_resize(
            np.expand_dims(image, axis=0), np.expand_dims(label, axis=0), target_shape)
        # NOTE(review): np.squeeze drops every length-1 axis, so a resized axis
        # of length 1 would disappear as well.
        return {'image': np.squeeze(resized_image), 'label': np.squeeze(resized_label)}

class RandomRotation(object):
    """Rotate image and label by a random multiple of 90 degrees.

    The rotation is applied in the plane of the first two array axes (the
    default of ``np.rot90``), with the same ``k`` for image and label.
    """

    def __call__(self, sample):
        """Return the sample rotated by ``k * 90`` degrees, ``k`` drawn from {0, 1, 2, 3}."""
        image, label = sample['image'], sample['label']
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        return {'image': image, 'label': label}

class RandomMirroring(object):
    """Flip image and label along each enabled axis with probability 0.5.

    Args:
        axes: Axis indices (subset of 0, 1, 2) that may be mirrored.
    """

    def __init__(self, axes=(0, 1, 2)):
        self.axes = axes

    def __call__(self, sample):
        """Mirror the sample in place and return it as a dict.

        The arrays stored in ``sample`` are overwritten. One random number is
        drawn per enabled axis, in the order 0, 1, 2.
        """
        image, label = sample['image'], sample['label']
        for axis in (0, 1, 2):
            # The membership test comes first so that no random number is
            # consumed for a disabled axis.
            if axis in self.axes and np.random.uniform() < 0.5:
                image[...] = np.flip(image, axis=axis)
                label[...] = np.flip(label, axis=axis)
        return {'image': image, 'label': label}

class RandomNoise(object):
    """Add clipped Gaussian noise to the image; the label is left untouched.

    Args:
        mu: Offset added to the noise after clipping.
        sigma: Standard deviation of the noise before clipping.
    """

    def __init__(self, mu=0, sigma=0.1):
        self.mu = mu
        self.sigma = sigma

    def __call__(self, sample):
        """Return a new image with noise added and the original label object."""
        image, label = sample['image'], sample['label']
        dim0, dim1, dim2 = image.shape[0], image.shape[1], image.shape[2]
        raw = self.sigma * np.random.randn(dim0, dim1, dim2)
        # Clip to two standard deviations, then shift by mu.
        limit = 2 * self.sigma
        perturbation = np.clip(raw, -limit, limit) + self.mu
        return {'image': image + perturbation, 'label': label}

class GammaAdjust(object):
    """Random gamma correction applied to the image of a sample.

    Args:
        gamma_range: Pair ``(low, high)`` from which gamma is drawn.
        epsilon: Small constant that keeps the normalisation from dividing by zero.
        retain_stats: If True, the output is shifted and scaled so that it keeps
            the mean and standard deviation of the input image.
    """

    def __init__(self, gamma_range=(0.5, 2), epsilon=1e-7,retain_stats = False):
        self.gamma_range = gamma_range
        self.epsilon = epsilon
        self.retain_stats = retain_stats

    def __call__(self, sample):
        """Return the sample with a gamma-adjusted image and the original label."""
        image, label = sample['image'], sample['label']
        if self.retain_stats:
            mn = image.mean()
            sd = image.std()
        # Half of the time (when the range allows it) draw gamma below 1,
        # otherwise draw it from the part of the range that is >= 1.
        if np.random.random() < 0.5 and self.gamma_range[0] < 1:
            gamma = np.random.uniform(self.gamma_range[0], 1)
        else:
            gamma = np.random.uniform(max(self.gamma_range[0], 1), self.gamma_range[1])
        # Normalise to [0, 1], apply the power law, then map back to the
        # original intensity range.
        minm = image.min()
        rnge = image.max() - minm
        image = np.power(((image - minm) / float(rnge + self.epsilon)), gamma) * rnge + minm
        if self.retain_stats:
            # Centre, rescale, then shift. Adding the mean before rescaling
            # would scale the mean too and fail to restore it.
            image = image - image.mean()
            image = image / (image.std() + 1e-8) * sd
            image = image + mn
        return {'image': image, 'label': label}




模型實現技巧對於醫學分割模型至關重要。三類常用的實現技巧:深度監督(DeepS);類平衡損失(CBL),其中包括四個損失函數(CBL_{Dice}、CBL_{Focal}、CBL_{Tvers}、CBL_{WCE});以及實例規範化(IntNorm)。


DeepS是DSN中提出的一種輔助學習技巧,通過在一些中間隱藏層上以直接或間接的方式添加一個輔助分類器或分割器來實現對主幹網絡的監督。它可用於解決訓練梯度消失或收斂速度較慢的問題。對於圖像分割,這個技巧通常通過添加圖像級分類損失來實現。可以從最後三個解碼器層中提取特徵圖,並使用1*1卷積層將掩膜投射到相同的通道大小中。然後,通過雙線性插值將分割頭網絡不同層的輸出特徵圖上採樣到與輸入圖像相同的空間大小。


CBL通常用於學習一般的類權重,每個類的權重只與對象類別相關。與一些傳統的分割損失函數(交叉熵損失)相比,在類不平衡數據集上CBL可以提高模型的表示能力。在所使用的數據集中,CBL引入了有效樣本的數量來表示所選數據集的期望體量,並通過有效樣本的數量而不是原始樣本的數量來加權不同的類。四種常用的醫學圖像領域的CBL損失函數,包括骰子損失(CBL_{Dice})、焦點損失(CBL_{Focal})、Tversky損失(CBL_{Tvers})和加權交叉熵損失(CBL_{WCE})。

class CELoss(nn.Module):
    """Cross-entropy loss that accepts label maps with or without a channel axis.

    Args:
        weight: Optional per-class weight tensor passed to ``nn.CrossEntropyLoss``.
        reduction: Reduction mode passed to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, weight=None, reduction='mean'):
        # Without this call the module has no parameter/buffer registries and
        # methods such as .to() or .parameters() fail.
        super(CELoss, self).__init__()
        self.weight = weight
        self.reduction = reduction

    def forward(self, y_pred, y_true):
        """Return the cross-entropy between logits ``y_pred`` and labels ``y_true``.

        ``y_pred`` has the class axis at position 1. ``y_true`` holds class
        indices; if it has as many axes as ``y_pred``, its axis 1 is taken to be
        a singleton channel axis and is dropped.
        """
        y_true = y_true.long()
        # Use a local copy on the right device instead of rewriting self.weight
        # on every call.
        weight = self.weight
        if weight is not None:
            weight = weight.to(y_pred.device)
        if y_true.dim() == y_pred.dim():
            y_true = y_true[:, 0, ...]
        loss = nn.CrossEntropyLoss(weight=weight, reduction=self.reduction)
        return loss(y_pred, y_true)

class DiceLoss(nn.Module):
    """Soft Dice loss on softmax probabilities, averaged over the Dice terms.

    Args:
        smooth: Constant added to numerator and denominator of the Dice ratio.
    """

    def __init__(self, smooth=1e-8):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """Return ``mean(1 - dice)`` for logits ``y_pred`` and ground truth ``y_true``."""
        spatial_axes = identify_axis(y_pred.shape)
        probabilities = y_pred.softmax(dim=1)
        # The ground truth is converted to one-hot inside get_tp_fp_fn_tn.
        tp, fp, fn, _ = get_tp_fp_fn_tn(probabilities, y_true, spatial_axes)
        numerator = 2 * tp + self.smooth
        denominator = 2 * tp + fp + fn + self.smooth
        return (1 - numerator / denominator).mean()

# taken from https://github.com/JunMa11/SegLoss/blob/master/test/nnUNetV2/loss_functions/focal_loss.py
class FocalLoss(nn.Module):
    """Focal loss with optional label smoothing.

    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    Implements the loss proposed in 'Focal Loss for Dense Object Detection'
    (https://arxiv.org/abs/1708.02002):

        Focal_Loss = -1 * alpha * (1 - pt) ** gamma * log(pt)

    :param apply_nonlin: optional callable applied to the network output first
        (for example a softmax over dim 1); ``forward`` treats its result as
        class probabilities
    :param alpha: class weighting. ``None`` gives every class weight 1; a list or
        ndarray of length ``num_class`` is normalised to sum to 1; a float gives
        class ``balance_index`` weight ``alpha`` and all others ``1 - alpha``
    :param gamma: (float) gamma > 0 reduces the relative loss for well-classified
        examples (p > 0.5), putting more focus on hard, misclassified examples
    :param balance_index: (int) class index that receives ``alpha`` when alpha is a float
    :param smooth: (float) label-smoothing value in [0, 1]; it is also added to
        ``pt`` before the logarithm
    :param size_average: (bool) if True the per-element losses are averaged,
        otherwise they are summed
    """

    def __init__(self, apply_nonlin=None, alpha=0.25, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        """Return the focal loss.

        :param logit: tensor of shape (N, C, ...) with the class axis at position 1
        :param target: class indices, shape (N, 1, ...) or (N, ...)
        :raises TypeError: if ``alpha`` is not None, a list, an ndarray or a float
        """
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...) -> N*m,C
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)

        alpha = self.alpha
        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        # Build the one-hot target directly on the device of the predictions.
        idx = target.long().to(logit.device)
        one_hot_key = torch.zeros(target.size(0), num_class, device=logit.device)
        one_hot_key = one_hot_key.scatter_(1, idx, 1)

        if self.smooth:
            # Label smoothing: clamp the hard 0/1 targets.
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        # Per-element class weight looked up from the target class.
        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

class TverskyLoss(nn.Module):
    """Tversky loss on softmax probabilities.

    The index is ``(tp + eps) / (tp + eps + alpha * fn + beta * fp)``, so in
    this implementation ``alpha`` weights false negatives and ``beta`` weights
    false positives.

    Args:
        alpha: Weight of the false-negative term.
        beta: Weight of the false-positive term.
        eps: Clamp bound for the probabilities and smoothing constant.
    """

    def __init__(self, alpha=0.3, beta=0.7, eps=1e-7):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    def forward(self, y_pred, y_true):
        """Return one minus the Tversky index averaged over batch and classes."""
        axis = identify_axis(y_pred.shape)
        y_pred = nn.Softmax(dim=1)(y_pred)
        y_true = to_onehot(y_pred, y_true)
        y_pred = torch.clamp(y_pred, self.eps, 1. - self.eps)
        tp, fp, fn, _ = get_tp_fp_fn_tn(y_pred, y_true, axis)
        tversky = (tp + self.eps) / (tp + self.eps + self.alpha * fn + self.beta * fp)
        # tversky holds one value per (sample, class). The former
        # (C - tversky.sum()) / C also summed over the batch, so the loss
        # depended on batch size and went negative for batches larger than one.
        # For a batch of one the mean gives the same value as before.
        return 1 - tversky.mean()

def to_onehot(y_pred, y_true):
    """Convert a label map to a one-hot tensor shaped like ``y_pred``.

    Args:
        y_pred: Tensor of shape (b, c, ...); only its shape and device are used.
        y_true: Label map of shape (b, ...) or (b, 1, ...), or a tensor that
            already has the shape of ``y_pred``.

    Returns:
        ``y_true`` unchanged if its shape already equals that of ``y_pred``
        (it is then assumed to be one-hot), otherwise a float tensor with the
        shape of ``y_pred`` holding 1 at each voxel's class index along dim 1.
    """
    shp_x = y_pred.shape
    shp_y = y_true.shape
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # Insert the missing channel axis: (b, ...) -> (b, 1, ...).
            y_true = y_true.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(y_pred.shape, y_true.shape)]):
            # If this is the case then gt is probably already a one hot encoding.
            y_onehot = y_true
        else:
            y_true = y_true.long()
            y_onehot = torch.zeros(shp_x, device=y_pred.device)
            y_onehot.scatter_(1, y_true, 1)
    return y_onehot

def get_tp_fp_fn_tn(net_output, gt, axes=None, square=False):
    """Compute soft true/false positive/negative sums.

    net_output must be (b, c, x, y(, z)).
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z)))
    or a one hot encoding (b, c, x, y(, z)).

    :param net_output: predicted scores or probabilities
    :param gt: ground truth, converted with ``to_onehot``
    :param axes: axes to sum over; defaults to every axis from index 2 on. An
        empty sequence returns the element-wise terms without summing
    :param square: if True, each term is squared before summation
    :return: tuple ``(tp, fp, fn, tn)``
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    y_onehot = to_onehot(net_output, gt)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn

def sum_tensor(inp, axes, keepdim=False):
    """Sum ``inp`` over several axes.

    Args:
        inp: Tensor to reduce.
        axes: Iterable of axis indices; duplicates are ignored.
        keepdim: If True the reduced axes are kept with length 1.

    Returns:
        The reduced tensor.
    """
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        # Reduce the highest axis first so the remaining indices stay valid.
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp

def identify_axis(shape):
    """Return the spatial axes of a batched, channel-first tensor shape.

    Helper function to enable loss functions to be flexibly used for both 2D
    and 3D image segmentation - source: https://github.com/frankkramer-lab/MIScnn

    Args:
        shape: Tensor shape of length 4 (2D data) or 5 (3D data).

    Returns:
        ``[2, 3]`` for a 4-axis shape, ``[2, 3, 4]`` for a 5-axis shape.

    Raises:
        ValueError: If the shape has neither 4 nor 5 axes.
    """
    # Three dimensional
    if len(shape) == 5:
        return [2, 3, 4]
    # Two dimensional
    if len(shape) == 4:
        return [2, 3]
    # Exception - Unknown
    raise ValueError('Metric: Shape of tensor is neither 2D or 3D.')


OHEM主要思想是,根據輸入樣本的損失進行篩選,篩選出hard example,表示對分類和檢測影響較大的樣本,然後將篩選得到的這些樣本應用在隨機梯度下降中訓練。在實際操作中是將原來的一個ROI Network擴充為兩個ROI Network,這兩個ROI Network共享參數。其中前面一個ROI Network只有前向操作,主要用於計算損失;後面一個ROI Network包括前向和後向操作,以hard example作為輸入,計算損失並回傳梯度。這種算法的優點在於,對於數據的類別不平衡問題不需要採用設置正負樣本比例的方式來解決,且隨著數據集的增大,算法的提升更加明顯。

class OHEMLoss(nn.CrossEntropyLoss):
    """Online hard example mining: cross-entropy averaged over the hardest voxels.

    The network output must be raw logits (no softmax applied).

    Args:
        weight: Optional per-class weight tensor for the cross-entropy.
        ignore_index: Target value that contributes zero loss.
        k: Fraction (0-1] of voxels with the highest loss that is kept.
    """

    def __init__(self, weight=None, ignore_index=-100, k=0.7):
        super(OHEMLoss, self).__init__(weight=weight, ignore_index=ignore_index)
        self.k = k

    def forward(self, y_pred, y_true):
        """Return the mean cross-entropy over the top ``k`` fraction of voxels.

        ``y_pred`` has the class axis at position 1. If ``y_true`` has as many
        axes as ``y_pred``, its axis 1 is taken to be a singleton channel axis.
        """
        y_true = y_true.long()
        if y_true.dim() == y_pred.dim():
            y_true = y_true[:, 0, ...]
        weight = self.weight
        if weight is not None:
            weight = weight.to(y_pred.device)
        # Honour weight and ignore_index, which the constructor accepts but the
        # forward pass previously dropped.
        voxel_loss = nn.functional.cross_entropy(
            y_pred, y_true, weight=weight, ignore_index=self.ignore_index, reduction='none')
        flat = voxel_loss.reshape(-1)
        # Keep at least one voxel so the mean below is never taken over nothing.
        num_kept = min(max(int(flat.numel() * self.k), 1), flat.numel())
        hardest, _ = torch.topk(flat, num_kept, sorted=False)
        return hardest.mean()




兩種常用的推理技巧,即測試時間增強(TTA)和模型集成。這兩個技巧的實現細節如下:TTA是目前模型推理階段流行的數據增強機制。TTA不需要訓練就可以用來提高識別性能,因此它有潛力成為一種即插即用的產品。同時,可以提高模型校準能力,有利於視覺任務。從三個方面遵循與前述相同的圖像增強策略:1)在基線模型上實施TTA策略(TTA_{baseline});2)TTA_{GTAug-A};3)TTA_{GTAug-B}。模型集成策略旨在統一多個訓練模型,基於一定的集成機制在測試集上實現多模型融合結果,使最終結果能夠從每個模型中學習,提高整體泛化能力。常用的模型集成方法有投票、平均、疊加和非交叉疊加(混合)。

class TTA():
    """Test-time augmentation (flips and 90-degree rotations) for 3D patches.

    ``img_list`` returns the untouched patch followed by its augmented views.
    ``img_list_inverse`` undoes the augmentations on the matching predictions.

    Args:
        if_tta: Legacy switch that turns flips and rotations on or off together.
        if_flip: Enable the three axis flips; defaults to ``if_tta``.
        if_rot: Enable the three rotations; defaults to ``if_tta``.
    """

    def __init__(self, if_tta=None, if_flip=None, if_rot=None):
        # Input patches have shape (x, y, z); predictions are expected
        # channel-first (c, x, y, z).
        self.if_flip = bool(if_tta) if if_flip is None else bool(if_flip)
        self.if_rot = bool(if_tta) if if_rot is None else bool(if_rot)
        self.if_tta = self.if_flip or self.if_rot

    def img_list(self, img):
        """Return ``[img, *flips, *rotations]`` for a patch of shape (x, y, z)."""
        # The original patch always comes first; the inverse relies on it.
        out = [img]
        if self.if_flip:
            for i in range(3):
                out.append(np.flip(img, axis=i))
        if self.if_rot:
            # Rotations in the plane of the first two axes.
            for i in range(1, 4):
                out.append(np.rot90(img, k=i))
        return out

    def img_list_inverse(self, img_list):
        """Undo the augmentations on predictions given in ``img_list`` order.

        Each prediction is expected to carry a leading class axis (c, x, y, z),
        so spatial axis ``i`` of the patch is axis ``i + 1`` here.
        """
        if not self.if_tta:
            return img_list
        out = [img_list[0]]
        pos = 1
        if self.if_flip:
            for i in range(3):
                out.append(np.flip(img_list[pos], axis=i + 1))
                pos += 1
        if self.if_rot:
            for i in range(3):
                out.append(np.rot90(img_list[pos], k=-(i + 1), axes=(1, 2)))
                pos += 1
        return out

class TTA_2d():
    """Test-time augmentation (flips and 90-degree rotations) for 2D batches.

    Args:
        flip: Enable flips along the two spatial axes.
        rotate: Enable the three rotations in the spatial plane.
    """

    def __init__(self, flip=False, rotate=False):
        self.flip = flip
        self.rotate = rotate

    def img_list(self, img):
        """Return ``[img, *flips, *rotations]`` as numpy arrays.

        ``img`` is a torch tensor of shape (b, c, h, w).
        """
        img = img.detach().cpu().numpy()
        # The original batch always comes first; the inverse relies on it.
        out = [img]
        if self.flip:
            for i in range(2, 4):
                # Flipped views have negative strides; make them contiguous so
                # they can be handed back to torch.from_numpy.
                out.append(np.ascontiguousarray(np.flip(img, axis=i)))
        if self.rotate:
            for i in range(1, 4):
                out.append(np.ascontiguousarray(np.rot90(img, k=i, axes=(2, 3))))
        return out

    def img_list_inverse(self, img_list):
        """Undo the augmentations on predictions given in ``img_list`` order.

        Each prediction is a numpy array of shape (b, h, w).
        """
        out = [img_list[0]]
        # A running position keeps the indexing right for flip-only and
        # rotate-only configurations as well.
        pos = 1
        if self.flip:
            for i in range(2):
                out.append(np.flip(img_list[pos], axis=i + 1))
                pos += 1
        if self.rotate:
            for i in range(3):
                out.append(np.rot90(img_list[pos], k=-(i + 1), axes=(1, 2)))
                pos += 1
        return out


後處理操作的目的主要是通過不可學習的方法來提高模型性能。例如分割結果可以通過聚合全局圖像信息來進行細化。醫學圖像分析領域的兩種常用的結果後處理方案:最大成分抑制(ABL-CS,即僅保留最大連通成分)和去除小區域(RSA)。ABL-CS的目的是基於有機體物理特性的知識,去除分割結果中的一些錯誤區域。例如,對於心臟分割任務,我們都知道每個人只有一個心臟,所以如果在獲得的掩模中有小的分割區域,我們需要去除這個小區域。RSA:設置一個像素級的閾值來刪除一些太小的實例掩碼。

import imp
import skimage.morphology as morph
import numpy as np
from scipy.ndimage import label

def abl(image: np.ndarray, for_which_classes: list, volume_per_voxel: float = None,
                                                   minimum_valid_object_size: dict = None):
    """Remove all but the largest connected component, individually for each class.

    ``image`` is modified in place and also returned.

    :param image: integer label map
    :param for_which_classes: can be None (then every positive value in the image
        is processed). Should be list of int. Can also be something like
        [(1, 2), 2, 4]. Here (1, 2) will be treated as a joint region, not
        individual classes (example LiTS: (1, 2) uses all foreground classes together)
    :param volume_per_voxel: factor that converts voxel counts to object sizes;
        defaults to 1
    :param minimum_valid_object_size: if given, a non-largest object is removed
        only when its size is smaller than the entry for its class. Keys must
        match entries in for_which_classes
    :return: the filtered label map (same array as ``image``)
    """
    # Imported locally under a private name: the module-level name ``label`` is
    # rebound by a later ``from skimage.measure import label`` in this file.
    from scipy.ndimage import label as _connected_components

    if for_which_classes is None:
        for_which_classes = np.unique(image)
        for_which_classes = for_which_classes[for_which_classes > 0]
    assert 0 not in for_which_classes, "cannot remove background"

    if volume_per_voxel is None:
        volume_per_voxel = 1

    for c in for_which_classes:
        if isinstance(c, (list, tuple)):
            c = tuple(c)  # otherwise it cant be used as key in the dict
            mask = np.zeros_like(image, dtype=bool)
            for cl in c:
                mask[image == cl] = True
        else:
            mask = image == c
        # Get labelmap and number of objects.
        lmap, num_objects = _connected_components(mask.astype(int))
        if num_objects == 0:
            continue

        # Collect object sizes.
        object_sizes = {}
        for object_id in range(1, num_objects + 1):
            object_sizes[object_id] = (lmap == object_id).sum() * volume_per_voxel

        # We always keep the largest object; objects tied for largest are kept too.
        maximum_size = max(object_sizes.values())
        for object_id, size in object_sizes.items():
            if size == maximum_size:
                continue
            remove = True
            if minimum_valid_object_size is not None:
                # Only remove objects that are smaller than the threshold.
                remove = size < minimum_valid_object_size[c]
            if remove:
                image[(lmap == object_id) & mask] = 0
    return image

def rsa(image: np.ndarray, for_which_classes: list, volume_per_voxel: float = None, minimum_valid_object_size: dict = None):
    """Remove small objects, smaller than minimum_valid_object_size, individually for each class.

    ``image`` is modified in place and also returned.

    :param image: integer label map
    :param for_which_classes: can be None (then every positive value in the image
        is processed). Should be list of int. Can also be something like
        [(1, 2), 2, 4]. Here (1, 2) will be treated as a joint region, not
        individual classes (example LiTS: (1, 2) uses all foreground classes together)
    :param volume_per_voxel: factor that converts voxel counts to object sizes;
        defaults to 1
    :param minimum_valid_object_size: objects whose size is smaller than the entry
        for their class are removed. Keys must match entries in
        for_which_classes. If None, nothing is removed
    :return: the filtered label map (same array as ``image``)
    """
    # Imported locally under a private name: the module-level name ``label`` is
    # rebound by a later ``from skimage.measure import label`` in this file.
    from scipy.ndimage import label as _connected_components

    if minimum_valid_object_size is None:
        # Without a threshold there is nothing to compare object sizes against.
        return image

    if for_which_classes is None:
        for_which_classes = np.unique(image)
        for_which_classes = for_which_classes[for_which_classes > 0]
    assert 0 not in for_which_classes, "cannot remove background"

    if volume_per_voxel is None:
        volume_per_voxel = 1
    for c in for_which_classes:
        if isinstance(c, (list, tuple)):
            c = tuple(c)  # tuples can be used as dict keys, lists cannot
            mask = np.zeros_like(image, dtype=bool)
            for cl in c:
                mask[image == cl] = True
        else:
            mask = image == c
        # Get labelmap and number of objects.
        lmap, num_objects = _connected_components(mask.astype(int))

        for object_id in range(1, num_objects + 1):
            component = lmap == object_id
            # Every object below the threshold is removed, the largest included.
            if component.sum() * volume_per_voxel < minimum_valid_object_size[c]:
                image[component & mask] = 0
    return image


# Remove small regions via connected-component analysis
import SimpleITK as sitk
import os
import argparse
from pathlib import Path

def RemoveSmallConnectedCompont(sitk_maskimg, rate=0.5):
    """Remove connected components that are small relative to the largest one.

    Two steps:
        step 1: connected-component analysis splits the input mask into N components
        step 2: every component whose size is below ``largest size * rate`` is removed

    :param sitk_maskimg: input binary image, read with
        ``sitk.ReadImage(path, sitk.sitkUInt8)``. The ``sitk.sitkUInt8`` pixel
        type must be given, otherwise ``sitk.ConnectedComponent`` reports an error
    :param rate: removal ratio, 0.5 by default: components smaller than half of
        the largest component are removed. The largest component is always kept
    :return: binary float32 image (1 = kept component, 0 = everything else)
        carrying the spatial metadata of the input
    """
    # Step 1: connected-component analysis.
    cc = sitk.ConnectedComponent(sitk_maskimg)
    stats = sitk.LabelIntensityStatisticsImageFilter()
    stats.Execute(cc, sitk_maskimg)

    # Size of every component, keyed by its label.
    sizes = {l: stats.GetPhysicalSize(l) for l in stats.GetLabels()}

    labelmaskimage = sitk.GetArrayFromImage(cc)
    outmask = labelmaskimage.copy()
    outmask[...] = 0

    if sizes:
        # On ties the first label with the maximum size wins.
        maxlabel = max(sizes, key=sizes.get)
        maxsize = sizes[maxlabel]
        # Step 2: keep components with size >= maxsize * rate, and always the largest.
        not_remove = [l for l, size in sizes.items() if size >= maxsize * rate or l == maxlabel]
        for l in not_remove:
            outmask[labelmaskimage == l] = 1

    outmask = outmask.astype('float32')

    out = sitk.GetImageFromArray(outmask)
    # Copy origin, spacing and direction so the output overlays the input.
    # Copying the origin alone would lose e.g. the slice thickness.
    out.CopyInformation(sitk_maskimg)

    return out  # to save image: sitk.WriteImage(out, 'largecc.nii.gz')

if __name__ == '__main__':
    # Command-line entry point: read one mask, drop its small connected
    # components and write the result.
    cli = argparse.ArgumentParser(description="remove small connected domains")
    cli.add_argument('--input', type=str, default="./123.nii.gz")
    cli.add_argument("--output", type=str, default='./123.nii.gz')
    options = cli.parse_args()

    # Single image; the mask is read as 8-bit unsigned integers.
    mask_image = sitk.ReadImage(options.input, sitk.sitkUInt8)
    # A different rate can be passed here to change how much is removed.
    cleaned = RemoveSmallConnectedCompont(mask_image, rate=0.5)
    sitk.WriteImage(cleaned, options.output)


from scipy.ndimage.morphology import binary_fill_holes
import numpy as np
from scipy import ndimage
import nibabel as nib
from skimage.measure import label
import matplotlib.pyplot as plt

def _holes_within_size(background, hole_min, hole_max):
    """Return a boolean mask of background components with size in [hole_min, hole_max]."""
    # scipy's default structuring element connects face neighbours only
    # (4-connectivity in 2D, 6-connectivity in 3D).
    from scipy import ndimage

    labels, _ = ndimage.label(background)
    component_sizes = np.bincount(labels.ravel())
    keep = (component_sizes >= hole_min) & (component_sizes <= hole_max)
    # Label 0 is the target region itself, never a hole.
    keep[0] = False
    return keep[labels]


def hole_filling(bw, hole_min, hole_max, fill_2d=True):
    """Fill holes of a binary mask whose size lies in ``[hole_min, hole_max]``.

    Every connected component of the background counts as a candidate hole.
    Components larger than ``hole_max`` (normally the surrounding background) or
    smaller than ``hole_min`` are left untouched.

    Args:
        bw: 2D or 3D array; values > 0 are foreground.
        hole_min: Minimum number of pixels of a hole to fill, usually 0.
        hole_max: Maximum number of pixels of a hole to fill.
        fill_2d: For 3D input only. True fills each slice ``bw[:, i, :]``
            independently; False fills true 3D cavities.

    Returns:
        Boolean array of the same shape, True for foreground and filled holes.

    Raises:
        ValueError: If ``bw`` is neither 2D nor 3D.
    """
    bw = np.asarray(bw) > 0
    if bw.ndim == 2:
        fill_out = _holes_within_size(~bw, hole_min, hole_max)
    elif bw.ndim == 3:
        if fill_2d:
            fill_out = np.zeros_like(bw)
            for zz in range(bw.shape[1]):
                # In ~bw the target is 0 and background plus holes are 1.
                fill_out[:, zz, :] = _holes_within_size(~bw[:, zz, :], hole_min, hole_max)
        else:
            fill_out = _holes_within_size(~bw, hole_min, hole_max)
    else:
        raise ValueError('hole_filling expects a 2D or 3D array, got %d dimensions' % bw.ndim)

    # OR: filled holes are True, and so is the original target.
    return np.logical_or(bw, fill_out)

bw: array,待填補的數組
hole_min: 孔洞像素的個數最小值,一般為0
hole_max: 孔洞像素的個數最大值
fill_2d: True:二維填充;False:三維填充。只有當孔洞像素個數在 [hole_min, hole_max] 之間時才會被填補。


有關語義分割的奇技淫巧有哪些? - 知乎 (zhihu.com)
從Kaggle學語義分割技巧 - mdnice 墨滴
分割網絡中的奇技淫巧_小白學視覺的博客-CSDN博客

Deep Learning for Medical Image Segmentation:Tricks, Challenges and Future Directions
hust-linyi/MedISeg (github.com)

