Triplet Loss是深度學(xué)習(xí)中的一種損失函數(shù)箱季,用于訓(xùn)練差異性較小的樣本裸删,如人臉等, Feed數(shù)據(jù)包括錨(Anchor)示例换帜、正(Positive)示例、負(fù)(Negative)示例鹤啡,通過優(yōu)化錨示例與正示例的距離小于錨示例與負(fù)示例的距離惯驼,實(shí)現(xiàn)樣本的相似性計(jì)算。
數(shù)據(jù)集:MNIST
目標(biāo):通過Triplet Loss訓(xùn)練模型递瑰,實(shí)現(xiàn)手寫圖像的相似性計(jì)算祟牲。
工程:https://github.com/SpikeKing/triplet-loss-mnist
模型
Triplet Loss的核心是錨示例、正示例抖部、負(fù)示例共享模型说贝,通過模型,將錨示例與正示例聚類慎颗,遠(yuǎn)離負(fù)示例乡恕。
Triplet Loss Model的結(jié)構(gòu)如下:
- 輸入:三個輸入,即錨示例俯萎、正示例傲宜、負(fù)示例,不同示例的結(jié)構(gòu)相同夫啊;
- 模型:一個共享模型函卒,支持替換為任意網(wǎng)絡(luò)結(jié)構(gòu);
- 輸出:一個輸出撇眯,即三個模型輸出的拼接报嵌。
Shared Model選擇常用的卷積模型,輸出為全連接的128維數(shù)據(jù):
Triplet Loss損失函數(shù)的計(jì)算公式如下:
訓(xùn)練
模型參數(shù):
- batch_size:32
- epochs:2
超參數(shù):
- 邊界Margin的值設(shè)置為
1
熊榛。
算法收斂較好锚国,Loss線性下降:
TF Graph:
驗(yàn)證
算法效率(TPS): 每秒48163次 (0.0207625 ms)
MNIST驗(yàn)證集的效果:
[INFO] trainer - clz 0
[INFO] trainer - distance - min: -15.4567, max: 1.98611, avg: -6.50481
[INFO] acc: 0.996632996633
[INFO] trainer - clz 1
[INFO] trainer - distance - min: -13.09, max: 3.43779, avg: -6.66867
[INFO] acc: 0.99214365881
[INFO] trainer - clz 2
[INFO] trainer - distance - min: -14.2524, max: 2.49437, avg: -5.60508
[INFO] acc: 0.991021324355
[INFO] trainer - clz 3
[INFO] trainer - distance - min: -16.6555, max: 1.21776, avg: -6.32161
[INFO] acc: 0.995510662177
[INFO] trainer - clz 4
[INFO] trainer - distance - min: -14.193, max: 1.65427, avg: -5.90896
[INFO] acc: 0.991021324355
[INFO] trainer - clz 5
[INFO] trainer - distance - min: -14.1007, max: 2.01843, avg: -6.36086
[INFO] acc: 0.994388327722
[INFO] trainer - clz 6
[INFO] trainer - distance - min: -16.8953, max: 2.84421, avg: -8.43978
[INFO] acc: 0.995510662177
[INFO] trainer - clz 7
[INFO] trainer - distance - min: -16.6177, max: 3.49675, avg: -5.99822
[INFO] acc: 0.989898989899
[INFO] trainer - clz 8
[INFO] trainer - distance - min: -14.937, max: 3.38141, avg: -5.4424
[INFO] acc: 0.979797979798
[INFO] trainer - clz 9
[INFO] trainer - distance - min: -16.9519, max: 2.39112, avg: -5.93581
[INFO] acc: 0.985409652076
測試的MNIST分布:
輸出的Triplet Loss MNIST分布:
本例僅僅使用2個Epoch,也沒有特殊設(shè)置超參来候,實(shí)際效果仍有提升空間跷叉。
歡迎Follow我的GitHub:https://github.com/SpikeKing
By C. L. Wang
That's all! Enjoy it!