文章名稱
CounterFactual Regression with Importance Sampling Weights
核心要點(diǎn)
文章主要針對binary treatment的場景巾乳,能夠用來估計CATE(當(dāng)然也可以估計ATE)茵肃。作者基于CFR[1],提出利用上下文感知的重要性采樣來取代CFR的固定權(quán)重,來平衡selection bias。相比于BNN和CFR利用頻率統(tǒng)計得到的樣本權(quán)重,文章提出的方法能夠?qū)崿F(xiàn)selection bias的平衡,彌補(bǔ)IPM loss較小平衡能力不足的問題。CFR-IS采用兩階段交替學(xué)習(xí)班巩。首先,利用給定權(quán)重嘶炭,訓(xùn)練類似BNN和CFR的loss抱慌。隨后,通過最小化NLL得到更優(yōu)的權(quán)重旱物。
方法細(xì)節(jié)
問題引入
BNN和CFR主要利用IPM來平衡不同treatment下的分布差異遥缕,具體loss如下圖所示。但是由于這種平衡是建立在的聯(lián)合分布上的宵呛,
的影響可能會被忽略单匣,而且高維特征會導(dǎo)致有treatment引起的分布距離比較小,不能夠提供足夠的loss宝穗,來進(jìn)行selection bias的平衡户秤。
同時,BNN和CFR在構(gòu)建factual loss(估計樣本實(shí)際輸出)的時候逮矛,采用了頻率統(tǒng)計得到的權(quán)重鸡号,即圖中的
而經(jīng)過loss的改寫汞窗,發(fā)現(xiàn)這部分權(quán)重的目標(biāo)是平衡樣本不均(參見引用[1])姓赤,并不能起到balancing當(dāng)中的re-weigthing的作用。因此仲吏,總體作者認(rèn)為對selection bias的矯正是不充分的不铆。所以,提出利用重要性采樣的方法來學(xué)習(xí)樣本權(quán)重實(shí)現(xiàn)不同treatment下的covariates均衡(大家都是這條路裹唆,做法不同而已)誓斥。
具體做法
因此,作者把兩個不同的treatment下的分布许帐,看做是兩個不同分布的采樣劳坑。為了對齊兩個分布的學(xué)習(xí)效果,我們把counterfactual的covariates分布當(dāng)做是目標(biāo)分布
舞吭,把實(shí)際觀測到的樣本分布
當(dāng)做采樣分布
泡垃。例如析珊,當(dāng)我們處理
的數(shù)據(jù)是羡鸥,
的covariates分布就是采樣分布请契,而
是目標(biāo)分布辙诞。
當(dāng)控制住
因此衷旅,得到不同treatment下
為了能夠在觀測數(shù)據(jù)上也表現(xiàn)得好(也就是預(yù)測好factual)柿顶,作者在權(quán)重上加1,表示采樣分布和目標(biāo)分布是同一個操软。
但是嘁锯,我們發(fā)現(xiàn)直接估計這個weight不現(xiàn)實(shí),因?yàn)槭且烙嬕粋€隱向量在不同treatment下出現(xiàn)的概率的比值聂薪。無論是直接估計概率密度函數(shù)家乘,還是用高斯建模概率的密度函數(shù)要么計算量大,要么假設(shè)太強(qiáng)藏澳,不準(zhǔn)確仁锯。所以作者采用貝葉斯法則轉(zhuǎn)化了weight的估計方式,如下圖所示翔悠。其中业崖,
學(xué)習(xí)propensity的loss就是簡單的NLL双炕。作者采用交替優(yōu)化CFR loss和propensity loss的方法進(jìn)行學(xué)(也許可以一起學(xué)复罐,類似Dragnnet)。
具體的網(wǎng)絡(luò)結(jié)構(gòu)如圖所示雄家,
代碼實(shí)現(xiàn)
(留坑待填...)
心得體會
why IS work?
個人理解效诅,IS就是把眼分布的數(shù)據(jù)用來換到目標(biāo)分布來估計目標(biāo)結(jié)果。這里weight是用在factual loss的那個部分趟济,也就是說乱投,我們假設(shè)樣本可能來自counterfactual分布,在這種情況下還用觀測結(jié)果作為事實(shí)來代表counterfactual的值顷编,就需要用IS戚炫。并且IS之后,就可以把估計factual loss當(dāng)做是在估計counterfactual loss媳纬。
add 1 to weight
在權(quán)重上+1双肤,就把一個樣本分成了兩個。因?yàn)椋?img class="math-inline" src="https://math.jianshu.com/math?formula=(1%2Bw_%7Bi%7D)%20x%20%3D%20x%20%2B%20w_i%20x" alt="(1+w_{i}) x = x + w_i x" mathimg="1">钮惠。本質(zhì)是表示如果這個樣本實(shí)際就是從觀測分布來的茅糜,那么就不需要加權(quán),但需要被用來估計factual素挽。
文章引用
[1] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.