下面是具體的參數(shù):
1. pos_weight:
-
處理樣本不均衡問(wèn)題
torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)
- 其中* pos_weight (Tensor*, *optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.
- pos_weight里是一個(gè)tensor列表,需要和標(biāo)簽個(gè)數(shù)相同,比如現(xiàn)在有一個(gè)多標(biāo)簽分類肄鸽,類別有200個(gè),那么 pos_weight 就是為每個(gè)類別賦予的權(quán)重值逼龟,長(zhǎng)度為200,官方給出的例子是:
target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5) # A prediction (logit)
pos_weight = torch.ones([64]) # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(1.5))
- 如果現(xiàn)在是二分類追葡,只需要將正樣本loss的權(quán)重寫上即可腺律,比如我們有正負(fù)兩類樣本,正樣本數(shù)量為100個(gè)宜肉,負(fù)樣本為400個(gè)匀钧,我們想要對(duì)正負(fù)樣本的loss進(jìn)行加權(quán)處理,將正樣本的loss權(quán)重放大4倍谬返,通過(guò)這樣的方式緩解樣本不均衡問(wèn)題:
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
-- pos_weight (Tensor, optional): a weight of positive examples.
--Must be a vector with length equal to the number of classes.