Loss Function in Metric Learning

General Idea:

For Classification Task:
Given feature vectors and their corresponding labels, compute a matrix of distances/similarities between pairs. Depending on the assumptions and viewpoint taken, there are several different designs for the loss function.
The code below is from https://github.com/bnulihaixia/Deep_metric.

The Main Scopes

  1. dynamic learning rate
  2. Euclidean distance, similarity, or JS divergence
  3. whether to consider all samples or select only the hard ones
  4. how to identify hard samples: by an absolute threshold? a relative one? above/below the mean? with a sampling probability?
    If so, how is that probability computed?
  5. whether to use different weights when computing the loss
  6. whether to use slicing (BDW), similar in spirit to batch normalization
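
Scope 2 above comes down to which pairwise matrix is built from a batch of embeddings. A minimal NumPy sketch of the two common choices (function names are mine, not from the repo):

```python
import numpy as np

def pairwise_euclidean(x):
    """Pairwise Euclidean distances via ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a.b"""
    sq = np.sum(x ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T
    return np.sqrt(np.maximum(d2, 0.0))   # clamp tiny negatives from rounding

def pairwise_cosine(x):
    """Cosine similarity: dot products of L2-normalized rows."""
    xn = x / np.linalg.norm(x, axis=1, keepdims=True)
    return xn @ xn.T
```

For L2-normalized features the two are monotonically related (d^2 = 2 - 2s), which is why the losses below switch freely between them.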



Details

BatchAll (almost the same as A_BatchAll, minus the a_lr)

  • dynamic learning rate
            pos_logit = torch.sum(torch.exp(self.alpha * (1 - pos_pair)))
            neg_logit = torch.sum(torch.exp(self.alpha * (1 - neg_pair)))
            a_lr = 1 - (pos_logit / (pos_logit + neg_logit)).data[0]
            ......
            loss_ = a_lr*torch.sum(valid_triplets)
  • pos_num : neg_num = 1:1 via the repeat operation (every positive pair is tiled against every negative pair)
            pos_pair = pos_pair.repeat(num_neg_instances, 1)
            neg_pair = neg_pair.repeat((num_instances-1), 1).t()
  • self.margin
          triplet_mat = pos_pair - neg_pair + self.margin
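
Stitching the three snippets together, a NumPy sketch of the whole BatchAll computation for one anchor (the alpha/margin defaults here are illustrative, not necessarily the repo's):

```python
import numpy as np

def batch_all_loss(pos_pair, neg_pair, alpha=2.0, margin=0.1):
    """pos_pair: distances of positive pairs for one anchor, shape (P,);
       neg_pair: distances of negative pairs, shape (N,)."""
    # adaptive learning-rate factor from the pos/neg logits
    pos_logit = np.sum(np.exp(alpha * (1 - pos_pair)))
    neg_logit = np.sum(np.exp(alpha * (1 - neg_pair)))
    a_lr = 1 - pos_logit / (pos_logit + neg_logit)
    # tile so every positive meets every negative (the 1:1 repeat trick)
    P, N = len(pos_pair), len(neg_pair)
    pos_mat = np.tile(pos_pair[:, None], (1, N))
    neg_mat = np.tile(neg_pair[None, :], (P, 1))
    triplet_mat = pos_mat - neg_mat + margin
    valid = triplet_mat[triplet_mat > 0]   # hinge: only violating triplets count
    return a_lr * np.sum(valid)
```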

A_hard_pair

  • dynamic learning rate
  • Focus on the hard pairs: the loss is nonzero only when a positive pair exceeds the limit (too far apart, dist > 0.8) or a negative pair falls inside the margin (dist < 1.1); both thresholds are on absolute distance
            pos_loss = torch.log(torch.sum(torch.exp(self.beta * (pos_pair - 0.8))))
            neg_loss = torch.log(torch.sum(torch.exp(self.beta * (1.1 - neg_pair))))

A_triplet/Triplet

  • dynamic learning rate
  • a triplet contributes when the positive distance exceeds the negative distance by too much (a relative criterion)
            triplet_mat = torch.log(torch.exp(self.beta*(pos_pair - neg_pair)) + 1)
            triplet_mask = triplet_mat > 0.65
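
A NumPy sketch of this soft-margin (softplus) triplet with the hard-triplet mask; beta and the 0.65 threshold follow the snippet:

```python
import numpy as np

def soft_triplet_loss(pos_pair, neg_pair, beta=1.0, threshold=0.65):
    """Softplus of the gap between each positive and each negative distance."""
    gaps = pos_pair[:, None] - neg_pair[None, :]   # all (pos, neg) combinations
    triplet_mat = np.log1p(np.exp(beta * gaps))    # log(exp(.) + 1)
    hard = triplet_mat[triplet_mat > threshold]    # keep hard triplets only
    return hard.mean() if hard.size else 0.0
```

Note that at pos == neg the softplus equals log(2) ≈ 0.693, so the 0.65 mask keeps every triplet whose margin is not yet comfortably satisfied.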

BatchHard

Select the maximum positive distance and the minimum negative distance (the hardest example of each kind) for every anchor in the batch
(a much more elegant way to compute neg_dist/pos_dist):

        hard_pos = torch.max(pos_dist_mat, dim=0)[0]
        hard_neg = torch.min(neg_dist_mat, dim=0)[0]
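
The mining step can be sketched in NumPy with label masks over a full pairwise distance matrix (rather than the repo's per-class slices):

```python
import numpy as np

def batch_hard(dist, labels):
    """Per anchor: hardest (farthest) positive and hardest (closest) negative."""
    same = labels[:, None] == labels[None, :]
    eye = np.eye(len(labels), dtype=bool)
    hard_pos = np.where(same & ~eye, dist, -np.inf).max(axis=1)  # farthest same-class
    hard_neg = np.where(~same, dist, np.inf).min(axis=1)         # closest other-class
    return hard_pos, hard_neg
```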

BDWDistWeightNeighborloss (Slice + DistWeightNeighborloss)

BinDevianceLoss (Branch)

  • Uses the similarity matrix instead of Euclidean distance; for normalized features, a high similarity score corresponds to a short distance. Only negative pairs whose similarity exceeds the positive pair's (minus a 0.05 slack) are selected.
             neg_pair = torch.masked_select(neg_pair, neg_pair > pos_pair[0] - 0.05)
  • a constant margin constrains both terms
            pos_loss = torch.mean(torch.log(1 + torch.exp(-2*(pos_pair - self.margin))))
            neg_loss = 0.04*torch.mean(torch.log(1 + torch.exp(50*(neg_pair - self.margin))))
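
A NumPy sketch of both steps; I use pos_pair.min() where the snippet indexes sorted pos_pair, and the 2/50/0.04 scaling constants are copied from the snippet:

```python
import numpy as np

def bin_deviance_loss(pos_pair, neg_pair, margin=0.5):
    """pos_pair / neg_pair: similarity scores of positive / negative pairs."""
    # keep only negatives nearly as similar as the weakest positive
    neg_pair = neg_pair[neg_pair > pos_pair.min() - 0.05]
    pos_loss = np.mean(np.log1p(np.exp(-2 * (pos_pair - margin))))
    neg_loss = 0.0
    if neg_pair.size:
        neg_loss = 0.04 * np.mean(np.log1p(np.exp(50 * (neg_pair - margin))))
    return pos_loss + neg_loss
```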

CenterLoss

(ref:https://blog.csdn.net/u014380165/article/details/76946339)

(Figures omitted; the reference derives the standard center loss, L_C = 1/2 * Σ_i ||x_i − c_{y_i}||^2, the squared distance of each feature x_i to its class center c_{y_i}, used jointly with the softmax loss.)

Compute the distance to the class center for every feature; a smaller within-cluster distance indicates a better result.
Steps:

  1. store the centers and inputs
  2. calculate the center dist
  3. take the closest neighbouring center as the negative sample and the input farthest from its own center as the positive sample (this is also the hard-sample selection):
            dist_an.append(centers_dist[i][targets_ != target].min())
            center_diff = inputs_list[i] - centers[i]
            center_diff_norm = torch.cat([torch.norm(temp) for temp in center_diff])
            dist_ap.append(center_diff_norm.max())
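
The steps above can be sketched in NumPy (here the centers are simply class means; in the repo they are stored/learned parameters):

```python
import numpy as np

def center_hard_pairs(inputs, labels):
    """For each class: dist_an = distance from its center to the closest other
       center; dist_ap = distance from the center to its farthest member."""
    classes = np.unique(labels)
    centers = np.stack([inputs[labels == c].mean(axis=0) for c in classes])
    dist_ap, dist_an = [], []
    for i, c in enumerate(classes):
        other = np.delete(centers, i, axis=0)
        dist_an.append(np.linalg.norm(other - centers[i], axis=1).min())
        members = inputs[labels == c]
        dist_ap.append(np.linalg.norm(members - centers[i], axis=1).max())
    return np.array(dist_ap), np.array(dist_an)
```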

CenterNCA

Center + base (the distinguishing feature of NCA is hard to tell from the code alone, perhaps the "base" setting; here, positives and negatives are selected as individual samples measured against the centers).

CenterPair

A constant limit clamps the pos/neg distances when selecting samples:

        loss = torch.mean(pos_dist.clamp(min=0.15) -
                          torch.log(torch.sum(torch.exp(-neg_dist.clamp(max=0.6)), 0)))

ClusterNCA

Labels come not from the initialization but from the k-means clustering result (the number of clusters must be specified).

Contrastive Loss

Samples within the same class are taken as positive examples.



DistWeightLoss

  • select positive samples according to the similarity probability (the weight)
  • select the hard negative pairs
            pos_pair = torch.sort(pos_pair)[0]
            sampled_index = torch.multinomial(torch.exp(5*pos_pair), 1)
            neg_pair = torch.masked_select(neg_pair, neg_pair > pos_min - 0.01)
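
A NumPy sketch of the two selection rules (the snippet does not show whether pos_pair holds distances or similarities, so they are treated here as generic scores; exp(5 * score) upweights the larger ones):

```python
import numpy as np

def dist_weight_select(pos_pair, neg_pair, rng=None):
    rng = np.random.default_rng(0) if rng is None else rng
    pos_pair = np.sort(pos_pair)
    # sample one positive with probability proportional to exp(5 * score)
    w = np.exp(5 * pos_pair)
    pos = pos_pair[rng.choice(len(pos_pair), p=w / w.sum())]
    # keep negatives whose score exceeds the smallest positive minus a slack
    hard_neg = neg_pair[neg_pair > pos_pair[0] - 0.01]
    return pos, hard_neg
```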

DistWeightContrastiveLoss

Samples are selected with a Gaussian probability, under a constant limit.

DistanceMatchLoss

Treat the neg/pos distance distributions as Gaussian, and draw samples using the fitted Gaussian parameters:

            neg_pair = neg_dist[i]
            neg_mean, neg_std = GaussDistribution(neg_pair)
            prob = torch.exp(torch.pow(neg_pair - neg_mean, 2) / (2 * torch.pow(neg_std, 2)))
            neg_index = torch.multinomial(prob, 3*num_instances, replacement=False)
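
A NumPy sketch of the sampling step. Note the exponent in the snippet is positive, i.e. the reciprocal of a Gaussian density: samples far from the mean (the tails of the distribution) get the highest sampling probability.

```python
import numpy as np

def gauss_sample(pair_dist, k, rng=None):
    rng = np.random.default_rng(0) if rng is None else rng
    mean, std = pair_dist.mean(), pair_dist.std()
    # positive exponent (as in the snippet): upweights the distribution's tails
    prob = np.exp((pair_dist - mean) ** 2 / (2 * std ** 2))
    idx = rng.choice(len(pair_dist), size=k, replace=False, p=prob / prob.sum())
    return pair_dist[idx]
```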

and different weights are used to compute the loss:

            base = [0.95, 1.05, 1.12]
            muls = [4, 8, 16]
            pos_diff = [pos_pair[i] - base[i] for i in range(len(base))]
            pos_diff = torch.cat([1.0 / muls[i] * torch.log(1 + torch.exp(pos_diff[i])) for i in range(len(base))])

Gaussian LDA

A different way to compute the loss:

            pos_logit = torch.sum(torch.exp(self.alpha*(1 - pos_neig)))
            neg_logit = torch.sum(torch.exp(self.alpha*(1 - neg_neig)))
            loss_ = -torch.log(pos_logit/(pos_logit + neg_logit))

GaussianMetric

In the code there is nothing actually Gaussian; the main idea is to select the neg/pos pairs relative to the mean of the samples.
(It seems these losses all come down to how negative and positive samples are selected.)

NCA Loss

Selects pos/neg pairs within the top-K neighbourhood.
A "base" term guards against illegal floating-point operations (overflow in the exponentials).
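
That "base" trick is the standard numerically stable log-sum-exp: subtract the maximum before exponentiating so exp() never overflows. A NumPy sketch:

```python
import numpy as np

def stable_logsumexp(x):
    base = np.max(x)                   # the "base" keeps exp() in range
    return base + np.log(np.sum(np.exp(x - base)))
```

Mathematically identical to log(sum(exp(x))), but safe even when x contains values like 1000, where the naive form would overflow to inf.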

