源代碼為清華大學的KB2E(https://github.com/thunlp/KB2E),但沒有注釋。本文加入了在下的注釋。
本文代碼的具體源文件:https://github.com/thunlp/KB2E/blob/master/TransE/Train_TransE.cpp
#include<iostream>
#include<cstring>
#include<cstdio>
#include<map>
#include<vector>
#include<string>
#include<ctime>
#include<cmath>
#include<cstdlib>
using namespace std;
#define pi 3.1415926535897932384626433832795

// Distance used by TransE: true = L1 norm, false = squared L2 norm.
bool L1_flag = true;
// Suffix of the output files: "bern" or "unif", chosen in main from -method.
string version;
// buf is the scratch buffer for fscanf("%s") reads; buf1 is not used in the visible code.
char buf[100000];
char buf1[100000];
// Number of relations / entities loaded by prepare().
int relation_num;
int entity_num;
// name -> id and id -> name lookup tables.
map<string, int> relation2id;
map<string, int> entity2id;
map<int, string> id2entity;
map<int, string> id2relation;
// left_entity[r][h]: number of training triples with relation r and head h
//                    (i.e. how many tails h has under r).
// right_entity[r][t]: number of training triples with relation r and tail t
//                    (i.e. how many heads t has under r).
// Both feed the Bernoulli sampling probability computed in prepare().
map<int, map<int, int> > left_entity;
map<int, map<int, int> > right_entity;
// Per relation id: left_num = average tails per head (tph, TransH),
//                  right_num = average heads per tail (hpt, TransH).
map<int, double> left_num;
map<int, double> right_num;
// Uniform random double in [min, max), derived from std::rand().
double rand(double min, double max)
{
    const double span = max - min;
    const double scaled = span * rand();
    return min + scaled / (RAND_MAX + 1.0);
}
double normal(double x, double miu, double sigma)
{// Gaussian probability density N(miu, sigma^2) evaluated at x.
    const double coeff = 1.0 / sqrt(2 * pi) / sigma;
    const double dev = x - miu;
    const double expo = -1 * dev * dev / (2 * sigma * sigma);
    return coeff * exp(expo);
}
double randn(double miu, double sigma, double min, double max)
{// Normal(miu, sigma) sample truncated to [min, max), via rejection sampling:
 // draw x uniformly from [min, max) and keep it when a uniform draw below the
 // density peak falls under normal(x).
    const double peak = normal(miu, miu, sigma);
    for (;;)
    {
        double x = rand(min, max);
        double y = normal(x, miu, sigma);
        if (rand(0.0, peak) <= y)
            return x;
    }
}
// Square of x.
double sqr(double x)
{
    const double squared = x * x;
    return squared;
}
// Euclidean (L2) norm of a. Taken by const reference: the vector is only read,
// and a const ref also accepts temporaries.
double vec_len(const vector<double>& a)
{
    double res = 0;
    for (double v : a)
        res += v * v;
    return sqrt(res);
}
/// TransE trainer. Holds the training triples and the entity/relation
/// embeddings, and minimises the margin ranking loss
///     sum max(0, margin + d(h + r, t) - d(h' + r, t'))
/// with mini-batch SGD. Here (h', r, t') is a copy of a training triple whose
/// head or tail has been replaced by a random entity.
class Train{
public:
    // Known training triples: ok[{head, relation}][tail] == 1 for every triple
    // passed to add(). A corrupted triple that is found here is re-sampled.
    map<pair<int, int>, map<int, int>> ok;

    /// Register one training triple, given by entity/relation ids.
    void add(int headid, int tailid, int relationid)
    {
        fb_h.push_back(headid);
        fb_t.push_back(tailid);
        fb_r.push_back(relationid);
        ok[make_pair(headid, relationid)][tailid] = 1;
    }

    /// Initialise the embeddings and train them. The embeddings are saved after every epoch.
    /// @param n_in        embedding dimension
    /// @param rate_in     SGD learning rate
    /// @param margin_in   margin of the hinge loss
    /// @param method_in   0 = uniform negative sampling, otherwise Bernoulli
    /// @param nepoch_in   number of epochs (default 1000)
    /// @param nbatches_in mini-batches per epoch (default 100; values < 1 become 1)
    void run(int n_in, double rate_in, double margin_in, int method_in,
             int nepoch_in = 1000, int nbatches_in = 100)
    {
        n = n_in; rate = rate_in; margin = margin_in; method = method_in;
        nepoch = nepoch_in;
        nbatches = nbatches_in > 0 ? nbatches_in : 1;
        // Each coordinate is drawn by randn with sigma = 1/n, truncated to
        // [-6/sqrt(n), 6/sqrt(n)]. Relations first, then entities, in id order.
        const double bound = 6 / sqrt(n);
        relation_vec.assign(relation_num, vector<double>(n));
        for (vector<double>& rel : relation_vec)
            for (double& v : rel)
                v = randn(0, 1.0 / n, -bound, bound);
        entity_vec.assign(entity_num, vector<double>(n));
        for (vector<double>& ent : entity_vec)
        {
            for (double& v : ent)
                v = randn(0, 1.0 / n, -bound, bound);
            norm(ent); // keep every entity inside the unit L2 ball (paper, Algorithm 1 line 5)
        }
        relation_tmp = relation_vec;
        entity_tmp = entity_vec;
        bfgs();
    }
private:
    int n = 0;          // embedding dimension
    int method = 0;     // 0 = uniform sampling, otherwise Bernoulli
    int nepoch = 1000;  // training epochs
    int nbatches = 100; // mini-batches per epoch
    double rate = 0;    // learning rate
    double margin = 0;  // hinge-loss margin
    double res = 0;     // hinge loss accumulated over the current epoch
    vector<int> fb_h, fb_t, fb_r; // head / tail / relation id of each training triple
    // Embeddings indexed [id][dimension]. *_vec hold the committed values, which
    // are used for scores and gradients. *_tmp receive the updates of the current
    // mini-batch and are copied back into *_vec when the batch ends.
    vector<vector<double> > relation_vec, entity_vec;
    vector<vector<double> > relation_tmp, entity_tmp;

    // Project a onto the unit L2 ball: rescale only if its norm exceeds 1.
    void norm(vector<double>& a)
    {
        double mo = vec_len(a);
        if (mo > 1)
            for (double& v : a)
                v = v / mo;
    }

    // Random integer in [0, x), x > 0. Two rand() draws are combined in
    // unsigned 64-bit arithmetic. The former int expression rand()*rand() could
    // overflow, which is undefined behaviour.
    int rand_max(int x)
    {
        unsigned long long r =
            static_cast<unsigned long long>(rand()) * (static_cast<unsigned long long>(RAND_MAX) + 1ULL)
            + static_cast<unsigned long long>(rand());
        return static_cast<int>(r % static_cast<unsigned long long>(x));
    }

    // True if (h, r, t) is a training triple. find() avoids inserting empty
    // entries into ok for every unseen (h, r) pair that gets sampled.
    bool is_known(int h, int r, int t) const
    {
        map<pair<int, int>, map<int, int>>::const_iterator it = ok.find(make_pair(h, r));
        return it != ok.end() && it->second.count(t) > 0;
    }

    // Mini-batch SGD over nepoch epochs (despite the name, no BFGS is involved).
    // For each batch:
    //   - draw batchsize training triples with replacement;
    //   - corrupt each one and apply the hinge update to the *_tmp copies;
    //   - re-project the touched entities onto the unit ball;
    //   - commit *_tmp to *_vec.
    // The embeddings are written to disk after every epoch.
    void bfgs()
    {
        res = 0;
        if (fb_h.empty() || entity_num <= 0)
        {
            cout << "no training triples or entities; nothing to train" << endl;
            return;
        }
        int batchsize = static_cast<int>(fb_h.size()) / nbatches;
        if (batchsize < 1)
            batchsize = 1; // fewer triples than batches: still make progress
        for (int epoch = 0; epoch < nepoch; epoch++)
        {
            res = 0;
            for (int batch = 0; batch < nbatches; batch++)
            {
                relation_tmp = relation_vec;
                entity_tmp = entity_vec;
                for (int k = 0; k < batchsize; k++)
                {
                    int i = rand_max(static_cast<int>(fb_h.size()));
                    int j = rand_max(entity_num); // candidate replacement entity
                    int h = fb_h[i], t = fb_t[i], r = fb_r[i];
                    // Bernoulli sampling (TransH): corrupt the tail with probability
                    // hpt / (tph + hpt), scaled here to [0, 1000]. Many-to-one
                    // relations therefore get their tail replaced more often.
                    double pr = 1000 * right_num[r] / (right_num[r] + left_num[r]);
                    if (method == 0)
                        pr = 500; // uniform: head or tail with equal probability
                    if (rand() % 1000 < pr)
                    {   // corrupt the tail; re-draw while (h, r, j) is a training triple
                        while (is_known(h, r, j))
                            j = rand_max(entity_num);
                        train_kb(h, t, r, h, j, r);
                    }
                    else
                    {   // corrupt the head; re-draw while (j, r, t) is a training triple
                        while (is_known(j, r, t))
                            j = rand_max(entity_num);
                        train_kb(h, t, r, j, t, r);
                    }
                    // Relations are intentionally not re-normalised here; only the
                    // entities touched by this update are projected back.
                    norm(entity_tmp[h]);
                    norm(entity_tmp[t]);
                    norm(entity_tmp[j]);
                }
                relation_vec = relation_tmp;
                entity_vec = entity_tmp;
            }
            cout << "epoch:" << epoch << ' ' << res << endl;
            save();
        }
    }

    // Write the committed embeddings to relation2vec.<version> and
    // entity2vec.<version>. Each file has one tab-separated row per id, in id order.
    // Files are overwritten on each call.
    void save() const
    {
        FILE* f2 = fopen(("relation2vec." + version).c_str(), "w");
        FILE* f3 = fopen(("entity2vec." + version).c_str(), "w");
        if (f2 == NULL || f3 == NULL)
        {
            cout << "cannot open output files for version " << version << endl;
            if (f2) fclose(f2);
            if (f3) fclose(f3);
            return;
        }
        for (const vector<double>& row : relation_vec)
        {
            for (double v : row)
                fprintf(f2, "%.6lf\t", v);
            fprintf(f2, "\n");
        }
        for (const vector<double>& row : entity_vec)
        {
            for (double v : row)
                fprintf(f3, "%.6lf\t", v);
            fprintf(f3, "\n");
        }
        fclose(f2);
        fclose(f3);
    }

    // TransE dissimilarity d(h + r, t). This is the L1 norm of h + r - t, or its
    // squared L2 norm when L1_flag is false. It reads the committed embeddings;
    // the relation comes from relation_vec, not entity_vec.
    double calc_sum(int h, int t, int r)
    {
        double sum = 0;
        for (int i = 0; i < n; i++)
        {
            double d = entity_vec[h][i] + relation_vec[r][i] - entity_vec[t][i];
            sum += L1_flag ? fabs(d) : sqr(d);
        }
        return sum;
    }

    // One SGD step on margin + d(a) - d(b) (a = positive, b = negative triple).
    // With x = d/d(h) of the distance, the gradient is +x for h and r and -x for t.
    // The positive triple descends and the negative triple ascends. Gradients use
    // the committed *_vec values; updates go into *_tmp.
    void gradient(int h_a, int t_a, int r_a, int h_b, int t_b, int r_b)
    {
        for (int i = 0; i < n; i++)
        {
            // positive triple: derivative of (h+r-t)^2 is 2(h+r-t); of |h+r-t| its sign
            double x = 2 * (entity_vec[h_a][i] + relation_vec[r_a][i] - entity_vec[t_a][i]);
            if (L1_flag)
                x = x > 0 ? 1 : -1;
            relation_tmp[r_a][i] -= rate * x;
            entity_tmp[h_a][i] -= rate * x;
            entity_tmp[t_a][i] += rate * x;
            // negative triple: same derivative, opposite direction
            x = 2 * (entity_vec[h_b][i] + relation_vec[r_b][i] - entity_vec[t_b][i]);
            if (L1_flag)
                x = x > 0 ? 1 : -1;
            relation_tmp[r_b][i] += rate * x;
            entity_tmp[h_b][i] += rate * x;
            entity_tmp[t_b][i] -= rate * x;
        }
    }

    // Hinge loss for one (positive a, negative b) pair. If
    // margin + d(a) - d(b) > 0, add it to res and take a gradient step;
    // otherwise the pair contributes nothing.
    void train_kb(int h_a, int t_a, int r_a, int h_b, int t_b, int r_b)
    {
        double posLoss = calc_sum(h_a, t_a, r_a);
        double negLoss = calc_sum(h_b, t_b, r_b);
        if (posLoss + margin - negLoss > 0)
        {
            res += margin + posLoss - negLoss;
            gradient(h_a, t_a, r_a, h_b, t_b, r_b);
        }
    }
};
// Global trainer instance: prepare() registers the training triples with it,
// and main() trains it via train.run().
Train train;

// Average number of triples per distinct key of counts. Returns 0 for an empty
// map instead of the 0/0 = NaN produced by an unguarded division.
static double average_per_key(const map<int, int>& counts)
{
    if (counts.empty())
        return 0;
    double total = 0;
    for (map<int, int>::const_iterator it = counts.begin(); it != counts.end(); ++it)
        total += it->second;
    return total / counts.size();
}

/// Load the training data.
/// - Reads entity2id.txt, relation2id.txt and train.txt from data_dir.
/// - Fills the global id maps and registers every training triple with train.
/// - Computes the per-relation tph (left_num) and hpt (right_num) used for
///   Bernoulli negative sampling.
/// Exits the process if any of the three files cannot be opened.
/// @param data_dir directory holding the data files, including the trailing '/'.
void prepare(const string& data_dir = "../data/FB15k/")
{
    FILE* f1 = fopen((data_dir + "entity2id.txt").c_str(), "r");
    FILE* f2 = fopen((data_dir + "relation2id.txt").c_str(), "r");
    FILE* f_kb = fopen((data_dir + "train.txt").c_str(), "r");
    if (f1 == NULL || f2 == NULL || f_kb == NULL)
    {
        cout << "cannot open data files in " << data_dir << endl;
        exit(1);
    }
    int mycount = 0; // progress counter for the messages below
    int x;
    while (fscanf(f1, "%s%d", buf, &x) == 2) // 2 = both fields converted
    {
        mycount++;
        if (mycount % 200 == 0)
            cout << "讀取第" << mycount << "個entity" << endl;
        string st = buf;
        entity2id[st] = x;
        id2entity[x] = st;
        entity_num++;
    }
    mycount = 0;
    while (fscanf(f2, "%s%d", buf, &x) == 2)
    {
        mycount++;
        if (mycount % 200 == 0)
            cout << "讀取第" << mycount << "個relation" << endl;
        string st = buf;
        relation2id[st] = x;
        id2relation[x] = st;
        relation_num++;
    }
    fclose(f1);
    fclose(f2);
    mycount = 0;
    // Each record of train.txt is "head tail relation" (names, not ids).
    while (fscanf(f_kb, "%s", buf) == 1)
    {
        mycount++;
        if (mycount % 1000 == 0)
            cout << "讀取第" << mycount << "個train樣本" << endl;
        string s1 = buf; // head
        if (fscanf(f_kb, "%s", buf) != 1)
            break;       // truncated record
        string s2 = buf; // tail
        if (fscanf(f_kb, "%s", buf) != 1)
            break;
        string s3 = buf; // relation
        bool missing = false;
        if (entity2id.count(s1) == 0)
        {
            cout << "miss entity:" << s1 << endl;
            missing = true;
        }
        if (entity2id.count(s2) == 0)
        {
            cout << "miss entity:" << s2 << endl;
            missing = true;
        }
        // An unknown entity has no embedding row. Skip the triple rather than
        // silently mapping the name to id 0 via operator[].
        if (missing)
            continue;
        if (relation2id.count(s3) == 0)
        {   // relation absent from relation2id.txt: give it the next free id
            relation2id[s3] = relation_num;
            id2relation[relation_num] = s3;
            relation_num++;
        }
        int h = entity2id[s1];
        int t = entity2id[s2];
        int r = relation2id[s3];
        left_entity[r][h]++;
        right_entity[r][t]++;
        train.add(h, t, r);
    }
    fclose(f_kb);
    for (int i = 0; i < relation_num; i++)
    {
        left_num[i] = average_per_key(left_entity[i]);   // tph: tails per head
        right_num[i] = average_per_key(right_entity[i]); // hpt: heads per tail
    }
    cout << "relation_num=" << relation_num << endl;
    cout << "entity_num=" << entity_num << endl;
}
// Position of the command-line flag `str` in argv[1..argc-1], or -1 if absent.
// The flag must be followed by a value; if it is the last argument, an error is
// printed and the process exits with status 1.
int ArgPos(char *str, int argc, char **argv)
{
    int pos = 1;
    while (pos < argc && strcmp(str, argv[pos]) != 0)
        ++pos;
    if (pos >= argc)
        return -1;
    if (pos == argc - 1)
    {
        cout << "Argument missing for " << str << endl;
        exit(1);
    }
    return pos;
}
int main(int argc, char **argv)
{
    // Usage example:
    //   kb2e.exe -size 100 -margin 1 -method 1 -rate 0.001
    // -size: embedding dimension, -margin: hinge margin (real number),
    // -method: 0 = uniform, otherwise Bernoulli negative sampling,
    // -rate: learning rate (real number).
    srand((unsigned)time(NULL)); // seed the RNG used for initialisation and sampling
    int method = 1;      // 1 = Bernoulli sampling
    int n = 100;         // embedding dimension
    double rate = 0.001; // learning rate
    double margin = 1;
    int i;
    if ((i = ArgPos(const_cast<char *>("-size"), argc, argv)) > 0) n = atoi(argv[i + 1]);
    // margin and rate are real numbers: atof, not atoi (atoi truncated "0.5" to 0)
    if ((i = ArgPos(const_cast<char *>("-margin"), argc, argv)) > 0) margin = atof(argv[i + 1]);
    if ((i = ArgPos(const_cast<char *>("-rate"), argc, argv)) > 0) rate = atof(argv[i + 1]);
    if ((i = ArgPos(const_cast<char *>("-method"), argc, argv)) > 0) method = atoi(argv[i + 1]);
    if (n <= 0)
    {
        cout << "-size must be a positive integer" << endl;
        return 1;
    }
    cout << "dim=" << n << "margin=" << margin;
    if (method)
        version = "bern"; // Bernoulli sampling (see TransH)
    else
        version = "unif"; // uniform sampling
    cout << "method = " << version << endl;
    prepare();
    train.run(n, rate, margin, method);
    return 0;
}
先讀取entity2id.txt和relation2id.txt文件處理數據,構造entity2id、relation2id;
再讀取train.txt文件將h、r、t聯繫起來,同時計算tph和hpt。
訓練時,對每個epoch、每個batch、每個樣本:
計算正樣本loss和負樣本loss,梯度下降,更新embedding;
每個epoch結束後將embedding寫入文件(後一次覆蓋前一次)。
TestTransE(測試代碼)
#include<iostream>
#include<cstring>
#include<cstdio>
#include<map>
#include<vector>
#include<string>
#include<ctime>
#include<algorithm>
#include<cmath>
#include<cstdlib>
using namespace std;
// Distance used by TransE: true = L1 norm, false = squared L2 norm.
bool L1_flag = true;
// Suffix of the embedding files to evaluate (relation2vec.<version>, entity2vec.<version>).
string version;
// buf is the scratch buffer for fscanf("%s") reads; buf1 is not used in the visible code.
char buf[100000];
char buf1[100000];
// Number of entities / relations loaded by prepare().
int entity_num;
int relation_num;
// Embedding dimension: number of values read per embedding row.
int n = 100;
// name -> id and id -> name lookup tables.
map<string, int> relation2id;
map<string, int> entity2id;
map<int, string> id2relation;
map<int, string> id2entity;
// Euclidean (L2) norm of a. Taken by const reference: the by-value parameter
// copied the whole vector on every call.
double vec_len(const vector<double>& a)
{
    double res = 0;
    for (double v : a)
        res += v * v;
    return sqrt(res);
}
// Square of x.
double sqr(double x)
{
    const double squared = x * x;
    return squared;
}
// Strict-weak-ordering comparator for std::sort: ascending by score (.second).
// Returns bool, as the Compare requirement expects; the old double return was
// only implicitly converted.
bool cmp(const pair<int, double>& a, const pair<int, double>& b)
{
    return a.second < b.second;
}
class Test{
vector<vector<double>> relation_vec, entity_vec;
//vector<int>h, r, t;
vector<int>fb_h, fb_r, fb_t;
map<pair<int, int>, map<int, int> >ok;
double res;
public:
void add(int h, int t, int r, bool flag)
{
if (flag)
{
fb_h.push_back(h);
fb_t.push_back(t);
fb_r.push_back(r);
ok[make_pair(h, r)][t] = 1;
}
}
double cal_sum(int h, int t, int r)
{//KB2E計算的負值,想不通
double sum = 0;
if (L1_flag)//L1
for (int i = 0; i < n; i++)
sum += fabs(entity_vec[h][i] + relation_vec[r][i] - entity_vec[t][i]);
else
for (int i = 0; i < n; i++)
sum += sqr(entity_vec[h][i] + relation_vec[r][i] - entity_vec[t][i]);
return sum;
}
int rand_max(int x)
{
int res = (rand()*rand()) % x;
if (res<0)
res += x;
return res;
}
void run()
{
FILE* f1 = fopen(("relation2vec." + version).c_str(), "r");
FILE* f3 = fopen(("entity2vec." + version).c_str(), "r");
cout <<"relation_num="<< relation_num << ', ' << "entity_num="<<entity_num << endl;
relation_vec.resize(relation_num);//relation_num應(yīng)該比relation_vec的size要大吧
for (int i = 0; i < relation_num; i++)
{//讀取文件中的relation embedding尺铣,保存到relation_vec中
relation_vec[i].resize(n);
for (int j = 0; j < n; j++)
fscanf(f1, "%lf", &relation_vec[i][j]);
}
entity_vec.resize(entity_num);
for (int i = 0; i < entity_num; i++)
{//讀取文件中的entity embedding拴曲,保存到entity_vec中
entity_vec[i].resize(n);
for (int j = 0; j < n; j++)
fscanf(f3, "%lf", &entity_vec[i][j]);
if (vec_len(entity_vec[i]) - 1>1e-3)
cout << "wrong_entity" << i << ' ' << vec_len(entity_vec[i]) << endl;
}
fclose(f1); fclose(f3);
//map<int, int> rel_num;//relationid,number
double hrank = 0,hrank_filter=0;//替換頭實體排名
double hrank10num = 0,hrank10numfilter=0;//替換頭實體rank10
double m = 0;//正確樣本個數(shù) 用于filter
for (int testid = 0; testid < fb_h.size(); testid++)
{//對test.txt中的每一行
int h = fb_h[testid];//head_entity id
int t = fb_t[testid];
int rel = fb_r[testid];
//rel_num[rel] += 1;
vector<pair<int, double>>a;//head_entityid,score
for (int i = 0; i < entity_num; i++)
{
double score=cal_sum(i, t, rel);//頭實體被每個實體替代
a.push_back(make_pair(i, score));
}
sort(a.begin(), a.end(), cmp);//升序排序
m = 0;
for (int i = a.size() - 1; i >= 0; i--)
{
if (ok[make_pair(a[i].first, rel)].count(t) > 0)//存在正確樣本
m++;
if (a[i].first == h)//正確樣本
{
hrank += a.size() - i;//raw 排名
hrank_filter += a.size() - i - m;
if (a.size() - i < 10)
hrank10num += 1;
if (a.size() - i - m < 10)
hrank10numfilter += 1;
}
}
a.clear();
}
cout << "替換頭實體raw mean rank=" << hrank / fb_h.size();
cout << "替換頭實體raw rank10=" << hrank10num / fb_h.size();
cout << "替換頭實體filter mean rank=" << hrank_filter / fb_h.size();
cout << "替換頭實體filter rank10=" << hrank10numfilter / fb_h.size();
}
};
Test test;
void prepare()
{
FILE* f1 = fopen("../data/FB15k/entity2id.txt","r");
FILE* f2 = fopen("../data/FB15k/relation2id.txt", "r");
int x;
while (fscanf(f1, "%s%d", buf, &x) == 2)
{
string s = buf;
entity2id[s] = x;
id2entity[x] = s;
entity_num++;
}
while (fscanf(f2, "%s%d", buf, &x) == 2)
{
string s = buf;
relation2id[s] = x;
id2entity[x] = s;
relation_num++;
}
FILE* f_kb = fopen("../data/FB15k/test.txt", "r");
while (fscanf(f_kb, "%s", buf) == 1)
{
string s1 = buf;//h
fscanf(f_kb, "%s", buf);
string s2 = buf;//t
fscanf(f_kb, "%s", buf);
string s3 = buf;//r
if (entity2id.count(s1)==0)
cout << "miss entity:" << s1 << endl;
if (entity2id.count(s2) == 0)
cout << "miss entity:" << s2 << endl;
if (relation2id.count(s3) == 0)
{
cout << "miss relation:" << s3 << endl;
relation2id[s3] = relation_num;
relation_num++;
}
test.add(entity2id[s1],entity2id[s2],entity2id[s3],true);
}
fclose(f_kb);
FILE* f_kb1 = fopen("../data/FB15k/train.txt", "r");
while (fscanf(f_kb1, "%s", buf) == 1)
{
string s1 = buf;
fscanf(f_kb1, "%s", buf);
string s2 = buf;
fscanf(f_kb1, "%s", buf);
string s3 = buf;
if (entity2id.count(s1) == 0)
cout << "miss entity:" << s1 << endl;
if (entity2id.count(s2) == 0)
cout << "miss entity:" << s2 << endl;
if (relation2id.count(s3) == 0)
{
relation2id[s3] = relation_num;
relation_num++;
}
test.add(entity2id[s1], entity2id[s2], entity2id[s3], true);
//應(yīng)該為true,論文中提到移除train迄埃,val疗韵,test“錯誤”三元組
}
fclose(f_kb1);
FILE* f_kb2 = fopen("../data/FB15k/valid.txt", "r");
while (fscanf(f_kb2, "%s", buf) == 1)
{
string s1 = buf;
fscanf(f_kb2, "%s", buf);
string s2 = buf;
fscanf(f_kb2, "%s", buf);
string s3 = buf;
if (entity2id.count(s1) == 0)
{
cout << "miss entity:" << s1 << endl;
}
if (entity2id.count(s2) == 0)
{
cout << "miss entity:" << s2 << endl;
}
if (relation2id.count(s3) == 0)
{
relation2id[s3] = relation_num;
relation_num++;
}
test.add(entity2id[s1], entity2id[s2], relation2id[s3], true);
//應(yīng)該為true,論文中提到移除train侄非,val蕉汪,test“錯誤”三元組
}
fclose(f_kb2);
}
// Usage: <program> <version>
// <version> is the suffix of the embedding files to evaluate, e.g. "bern" or
// "unif" as written by the trainer (relation2vec.<version>, entity2vec.<version>).
int main(int argc, char** argv)
{
    if (argc < 2)
    {
        cout << "usage: <program> <version>   (e.g. bern or unif)" << endl;
        return 1; // missing argument is an error, not a silent success
    }
    version = argv[1];
    prepare();
    test.run();
    return 0;
}