(代碼實(shí)現(xiàn)的坑待填...降淮,日更太難了...)
文章名稱
Reducing Selection Bias in Counterfactual Reasoning for Individual Treatment Effects Estimation
核心要點(diǎn)
文章仍然關(guān)注binary treatment情境下的CATE估計(jì)翰铡。作者通過AE結(jié)合利用Pearson Correlation Coefficient的正則化,鼓勵(lì)模型對(duì)covariates進(jìn)行分解进陡,從而學(xué)習(xí)兩組不同的變量,一組只和outcome的treatment assignment相關(guān)(group A),另一組與selection bias和outcome prediction都相關(guān)(group BC),最終用group BC來同時(shí)平衡selection bias并預(yù)測(cè)outcome准验。
方法細(xì)節(jié)
問題引入
文章來自NeurIPS 2019 CausalML Workshop。相比于通過balancing with representation learning廷没,其實(shí)很多時(shí)候糊饱,我們把一些只影響potential outcome估計(jì)的covariates也當(dāng)做是confounder來做adjustment,導(dǎo)致在學(xué)習(xí)樣本平衡的時(shí)候存在噪聲颠黎,因果效應(yīng)的估計(jì)能力變差另锋。從因果圖的角度我們可以把confounder分為3類,第一類是只影響treatment assignment的狭归;第二類是confounder夭坪,不但印象treatment assignment,也影響outcome过椎;第三類則只影響outcome室梅,具體因果圖,如下圖所示疚宇。作者期望把第一類和第二亡鼠、三類covariates區(qū)分開,從而減少第一類covariates對(duì)potential outcome預(yù)測(cè)帶來的噪聲(因?yàn)槲覀儾魂P(guān)心是不是哪些雖然影響策略分配敷待,但完全不影響outcome的特征间涵,他們不會(huì)帶來偏差)。這個(gè)因果分解的思路最先出現(xiàn)在引用文章[1]里(后面會(huì)講榜揖,其實(shí)這個(gè)思路還不完整勾哩,后續(xù)會(huì)介紹更完善的covariates分解的相關(guān)文章),不同的是這篇文章把BC合并在了一起举哟,并且使用了不同的正則化方法 -- Pearson Correlation Coefficient思劳。
具體做法
實(shí)際的網(wǎng)絡(luò)結(jié)構(gòu)如圖所示。首先妨猩,通過一個(gè)AE潜叛,學(xué)習(xí)樣本表示,樣本表示由兩部分向量組成册赛。隨后钠导,利用學(xué)到的傳遞給outcome預(yù)測(cè)網(wǎng)絡(luò)震嫉,進(jìn)行不同counterfactual的預(yù)測(cè)森瘪。不知道有沒有同學(xué)有似曾相識(shí)的感覺。大概還是自監(jiān)督學(xué)習(xí)還沒有興起的時(shí)候(約2018-19年)票堵,曾經(jīng)流行用AE在大量的無標(biāo)簽樣本上進(jìn)行重構(gòu)損失的訓(xùn)練扼睬,然后利用訓(xùn)練的得到的隱向量,也就是這里的,來輔助做downstream的無監(jiān)督學(xué)習(xí)(表示學(xué)習(xí))窗宇。這種類型無監(jiān)督結(jié)合有監(jiān)督的方法在NLP措伐,CV都有使用,比如做文本分類军俊。后來還延伸出了很多方法侥加,諸如先做無監(jiān)督主題模型,學(xué)到的主題向量做文本分類(扯遠(yuǎn)了粪躬,回到正題...)担败。本質(zhì)是通過引入covariates de-correlation的輔助任務(wù),來消除selection bias镰官,只是這個(gè)輔助任務(wù)比其他的任務(wù)要聰明提前,因?yàn)椴坏m正了偏差,同時(shí)減少了噪聲泳唠,同時(shí)符合因果圖的理念(后邊會(huì)看到更精妙的狈网,比如去掉無意間引入的collidor)。
然而拓哺,僅憑這樣的網(wǎng)絡(luò),是不可能達(dá)到很好估計(jì)causal effect的效果的扇雕,不然不就沒有causal什么事兒了... 回想拓售,causal inference的兩個(gè)主要問題,1)missing counterfactual镶奉;2)selection bias础淤。這兩個(gè)問題還是需要通過loss function來解決。方法的整體loss如下哨苛,其中鸽凶,是無監(jiān)督表示學(xué)習(xí)的重構(gòu)損失;沒啥好說的建峭,是factual的估計(jì)損失(也就是觀測(cè)數(shù)據(jù)預(yù)測(cè)的準(zhǔn)不準(zhǔn))玻侥;是分布距離損失,用來度量不同treatment下covariates分布的差異性亿蒸,這個(gè)在之前介紹BNN的那篇完章里有些(理論證明的坑還沒有填上...凑兰,容證明再飛一會(huì)兒...);而就是文章的核心要點(diǎn)Pearson Correlation Coefficient边锁。
重構(gòu)損失姑食,是標(biāo)準(zhǔn)的損失,度量covariates的重構(gòu)能力茅坛,保證AE能夠充分學(xué)習(xí)(這里也許可以采用其他的AE音半,當(dāng)然已經(jīng)有用VAE做的了)。
預(yù)估損失,是BNN中提到的加權(quán)損失曹鸠。
分布差異損失煌茬,也是BNN中的Integral Probability Metric Loss。
de-correlation損失函數(shù)彻桃,是利用兩個(gè)不同向量組(A和BC)的皮爾遜相關(guān)系數(shù)作為損失函數(shù)坛善,當(dāng)這個(gè)損失達(dá)到最小的時(shí)候,兩個(gè)向量組線性無關(guān)邻眷。其中浑吟,指的是向量中的第個(gè)元素。是指第個(gè)樣本的隱向量表示耗溜,是所有樣本的平均组力,其他同理。
代碼實(shí)現(xiàn)
文章偽代碼參見下圖(實(shí)際代碼的坑后續(xù)再填...)抖拴。
心得體會(huì)
unsupervised assassinated supervised learning
文章用到的類似無監(jiān)督輔助有監(jiān)督學(xué)習(xí)的思路燎字,來幫助更準(zhǔn)確的估計(jì)potential outcome。本質(zhì)是尋找了更多的內(nèi)在信息或結(jié)構(gòu)阿宅,來引導(dǎo)potential outcome不要走偏(消除selection bias)候衍。這個(gè)和自監(jiān)督中尋找相關(guān)性的思路很吻合,也許自監(jiān)督與causal inference結(jié)合的方法已經(jīng)在路上了洒放。
linear independent
文章雖然通過PCC讓兩個(gè)向量組A和BC線性無關(guān)蛉鹿,但是在現(xiàn)實(shí)世界里covariates之間的非線性關(guān)系是存在的,也是神經(jīng)網(wǎng)絡(luò)的優(yōu)勢(shì)之一往湿。所以妖异,這種損失的de-correlation性能可能比較有限。
文章引用
[1] Negar Hassanpour and Russell Greiner. Counterfactual regression with importance sampling weights. In Proceedings of the Twenty-Eighth International Joint Conference on Artificial Intelligence, IJCAI-19, pages 5880–5887, 7 2019.
[2] Shalit, U., Johansson, F.D., & Sontag, D. (2017). Estimating individual treatment effect: generalization bounds and algorithms. ICML.