[OpenKE] Knowledge Embedding PyTorch版本

安裝

1.  Install PyTorch
2.  Clone the OpenKE-PyTorch branch:
    $ git clone -b OpenKE-PyTorch [https://github.com/thunlp/OpenKE](https://github.com/thunlp/OpenKE)
    $ cd OpenKE
3.  Compile C++ files
    $ bash make.sh

訓(xùn)練

需要三個文件

  • train2id.txt 訓(xùn)練文件卸奉,第一行是三元組的數(shù)量钝诚,接下來的數(shù)據(jù)格式是(e1,e2,rel),需要注意的是,e1,和e2是實(shí)體的編號榄棵,rel是關(guān)系的編號凝颇,其對應(yīng)關(guān)系存放在文件entity2id.txtrelation2id.txt
  • entity2id.txt 第一行是實(shí)體的數(shù)目疹鳄,接下來的每一行是實(shí)體對應(yīng)相關(guān)的id拧略。
  • relation2id.txt 第一行是關(guān)系的數(shù)目,接下來的每一行是關(guān)系對應(yīng)的id瘪弓。

測試

需要5個文件垫蛆,除過上面的三個,還需要

  • test2id.txt 第一行是測試三元組的數(shù)目腺怯,接下來的是(e1,e2,rel).
  • valid2id.txt 驗(yàn)證數(shù)據(jù)集袱饭,第一行是驗(yàn)證三元組的個數(shù),接下來的行是驗(yàn)證數(shù)據(jù)(e1,e2,rel)
  • type_constrain.txt 類型限制文件呛占,表示關(guān)系只能和特定類型的頭實(shí)體和尾實(shí)體結(jié)合虑乖,文件第一行是關(guān)系的數(shù)目,接下來的行是每一種關(guān)系的類型限制晾虑,例如: 某個relation的id是1200 疹味,他的頭實(shí)體(head entities)有四種類型 3123 1034,58,5733 這個relation同時又有4種類型的尾實(shí)體12123,4388,11087,11088 帜篇,這種n對n的關(guān)系可以通過through n-n.py in folder benchmarks/FB15K來查看糙捺。

Quick Start

import config
import models
import json
import numpy as np


con = config.Config()
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")
con.set_work_threads(4)
con.set_train_times(500)
con.set_nbatches(100)
con.set_alpha(0.001)
con.set_margin(1.0)
con.set_bern(0)
con.set_dimension(50)
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("SGD")

#Models will be exported via tf.Saver() automatically.
con.set_export_files("./res/model.vec.tf", 0)
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/embedding.vec.json")
#Initialize experimental settings.
con.init()
#Set the knowledge embedding model
con.set_model(models.TransE)
#Train the model.
con.run()   
步驟1 加載數(shù)據(jù)

這個文件夾下面有三個文件train2id.txt,entity2id.txt,relation2id.txt

con.set_in_path("benchmarks/FB15K/")

可以分配幾個threads進(jìn)行采樣sample positive and negative cases.

con.set_work_threads(8)
步驟2, 設(shè)置訓(xùn)練參數(shù)

最大訓(xùn)練輪數(shù)笙隙,batchSeize继找,實(shí)體和關(guān)系的維數(shù),

con.set_train_times(500)
con.set_nbatches(100)
con.set_alpha(0.5)
con.set_dimension(200)
con.set_margin(1)

對于負(fù)采樣逃沿,我們可以把正常實(shí)體和關(guān)系拆分來構(gòu)造negative triples婴渡, set_bern(0)是傳統(tǒng)的采樣方法,set_bern(1)表示使用 (Wang et al. 2014) denoted as "bern"提出的構(gòu)造方法凯亮,set_ent_neg_rate是設(shè)置實(shí)體的負(fù)采樣率边臼,set_rel_neg_rate是設(shè)置關(guān)系的負(fù)采樣率。

con.set_bern(0)
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)

設(shè)置優(yōu)化方法

con.set_optimizer("SGD")
步驟3假消,輸出結(jié)果

模型參數(shù)每隔幾輪就會使用torch.save()自動的保存下來柠并,同時最終的結(jié)果會保存成json 文件的形式。

con.set_export_files("./res/model.vec.pt")
con.set_out_files("./res/embedding.vec.json")
步驟4 訓(xùn)練模型
con.init()
con.set_model(models.TransE)
con.run()
步驟5 測試
測試任務(wù)

link prediction任務(wù):用于預(yù)測三元組中缺失的關(guān)系或者尾實(shí)體富拗,對于測試的三元組臼予,我們replace掉了head/tail 實(shí)體,并以降序的順序給出預(yù)測出實(shí)體的得分啃沪。平均的指標(biāo)有:

  • MR:正確實(shí)體的平均rank粘拾。
  • MRR: the average of the reciprocal ranks of correct entities。
  • Hit@N:正確實(shí)體在top-N的比率

三元組分類任務(wù):判斷一個三元組(h,r,t)是否正確创千,這是一個二分類問題缰雇。
預(yù)測頭實(shí)體任務(wù): 預(yù)測topk個可能的頭實(shí)體,所有的頭實(shí)體用id表示

def predict_head_entity(self, t, r, k):
    r'''This mothod predicts the top k head entities given tail entity and relation.
    
    Args: 
        t (int): tail entity id
        r (int): relation id
        k (int): top k head entities
    
    Returns:
        list: k possible head entity ids        
    '''
    self.init_link_prediction()
    if self.importName != None:
        self.restore_pytorch()
    test_h = np.array(range(self.entTotal))
    test_r = np.array([r] * self.entTotal)
    test_t = np.array([t] * self.entTotal)
    res = self.trainModel.predict(test_h, test_t, test_r).data.numpy().reshape(-1).argsort()[:k]
    print(res)
    return res

預(yù)測尾實(shí)體:與預(yù)測頭實(shí)體相似追驴。
預(yù)測關(guān)系

def predict_relation(self, h, t, k):
    r'''This methods predict the relation id given head entity and tail entity.
    
    Args:
        h (int): head entity id
        t (int): tail entity id
        k (int): top k relations
    
    Returns:
        list: k possible relation ids
    '''
    self.init_link_prediction()
    if self.importName != None:
        self.restore_pytorch()
    test_h = np.array([h] * self.relTotal)
    test_r = np.array(range(self.relTotal))
    test_t = np.array([t] * self.relTotal)
    res = self.trainModel.predict(test_h, test_t, test_r).data.numpy().reshape(-1).argsort()[:k]
    print(res)
    return res

預(yù)測三元組:給一個三元組械哟,這個函數(shù)告訴我們是否這個三元組是否正確,如果threshold沒有給出殿雪,那么函數(shù)從驗(yàn)證集中計算出這個關(guān)系的threshold暇咆。

def predict_triple(self, h, t, r, thresh = None):
    r'''This method tells you whether the given triple (h, t, r) is correct of wrong

    Args:
        h (int): head entity id
        t (int): tail entity id
        r (int): relation id
        thresh (fload): threshold for the triple
    '''
    self.init_triple_classification()
    if self.importName != None:
        self.restore_pytorch()  
    res = self.trainModel.predict(np.array([h]), np.array([t]), np.array([r])).data.numpy()
    if thresh != None:
        if res < thresh:
                        print("triple (%d,%d,%d) is correct" % (h, t, r))
                else:
                        print("triple (%d,%d,%d) is wrong" % (h, t, r)) 
        return
    self.lib.getValidBatch(self.valid_pos_h_addr, self.valid_pos_t_addr, self.valid_pos_r_addr, self.valid_neg_h_addr, self.valid_neg_t_addr, self.valid_neg_r_addr)
    res_pos = self.trainModel.predict(self.valid_pos_h, self.valid_pos_t, self.valid_pos_r)
    res_neg = self.trainModel.predict(self.valid_neg_h, self.valid_neg_t, self.valid_neg_r)
    self.lib.getBestThreshold(self.relThresh_addr, res_pos.data.numpy().__array_interface__['data'][0], res_neg.data.numpy().__array_interface__['data'][0])
    if res < self.relThresh[r]:
        print("triple (%d,%d,%d) is correct" % (h, t, r))
    else: 
        print("triple (%d,%d,%d) is wrong" % (h, t, r))
具體實(shí)施

第一步是導(dǎo)入數(shù)據(jù)集,然后配置參數(shù)丙曙,然后設(shè)置模型參數(shù)和測試的模型爸业,例如我們需要測試TransE:有三種方法來測試模型。

  1. 設(shè)置導(dǎo)入文件河泳,OpenKE-PyTorch會自動的通過torch.load()來加載模型沃呢。
import config
import models
import numpy as np
import json

con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.test_link_prediction(True)
con.test_triple_classification(True)
con.set_work_threads(4)
con.set_dimension(100)
con.set_import_files("./res/model.vec.pt")
con.init()
con.set_model(models.TransE)
con.test()
  1. 從json文件中讀取模型參數(shù),手動的加載參數(shù)。
import config
import models
import numpy as np
import json

con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.test_link_prediction(True)
con.test_triple_classification(True)
con.set_work_threads(4)
con.set_dimension(100)
con.init()
con.set_model(models.TransE)
f = open("./res/embedding.vec.json", "r")
content = json.loads(f.read())
f.close()
con.set_parameters(content)
con.test()
  1. 使用torch.load()手動的加載模型拆挥。
import config
import models
import numpy as np
import json

con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.test_link_prediction(True)
con.test_triple_classification(True)
con.set_work_threads(4)
con.set_dimension(100)
con.init()
con.set_model(models.TransE)
con.import_variables("./res/model.vec.pt")
con.test()
獲取embedding 矩陣

有四種方式來獲取embedding矩陣

  1. 設(shè)置import 文件那么OpenKE 會自動的使用torch.load()加載模型.
    使用con.get_parameters()函數(shù)來獲得list類型的embedding矩陣薄霜,可以通過con.get_parameters("numpy")獲得numpy類型的參數(shù)。
import json
import numpy as py
import config
import models
con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.test_link_prediction(True)
con.test_triple_classification(True)
con.set_work_threads(4)
con.set_dimension(100)
con.set_import_files("./res/model.vec.pt")
con.init()
con.set_model(models.TransE)
# Get the embeddings (numpy.array)
embeddings = con.get_parameters("numpy")
# Get the embeddings (python list)
embeddings = con.get_parameters()
  1. 從json文件中獲取
import json
import numpy as py
import config
import models
con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.test_link_prediction(True)
con.test_triple_classification(True)
con.set_work_threads(4)
con.set_dimension(100)
con.init()
con.set_model(models.TransE)
f = open("./res/embedding.vec.json", "r")
embeddings = json.loads(f.read())
f.close()
  1. 手動的使用torch.load()加載模型纸兔,但是獲取embedding的方式還是一樣的惰瓜。
con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.test_link_prediction(True)
con.test_triple_classification(True)
con.set_work_threads(4)
con.set_dimension(100)
con.init()
con.set_model(models.TransE)
con.import_variables("./res/model.vec.pt")
# Get the embeddings (numpy.array)
embeddings = con.get_parameters("numpy")
# Get the embeddings (python list)
embeddings = con.get_parameters()
  1. 從一個訓(xùn)練好的模型中立即拿到embedding。
#Models will be exported via tf.Saver() automatically.
con.set_export_files("./res/model.vec.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/embedding.vec.json")
#Initialize experimental settings.
con.init()
#Set the knowledge embedding model
con.set_model(models.TransE)
#Train the model.
con.run()
#Get the embeddings (numpy.array)
embeddings = con.get_parameters("numpy")
#Get the embeddings (python list)
embeddings = con.get_parameters()

接口

Config的接口
  • def set_alpha(alpha = 0.001) 設(shè)置學(xué)習(xí)率
  • def set_lmbda(lmbda = 0.0)汉矿,設(shè)置正則化前面的系數(shù)崎坊。

To set the degree of the regularization on the parameters

  • set_train_times(self, times) 相當(dāng)于設(shè)置epoch
  • def sampling() 從正樣本和負(fù)樣本中采樣一個batch
  • def set_in_path(self, path) 讀取數(shù)據(jù)
  • def set_out_files(self, path) 當(dāng)訓(xùn)練結(jié)束的時候?qū)⒛P偷膮?shù)變成json文件存儲下來。
  • def set_import_files(self, path) 模型所有的參數(shù)都可以用這個文件夾里面恢復(fù)洲拇。
  • def set_export_steps(self, steps) 每隔多少步存儲一次文件
  • def save_pytorch(self) 使用torch.save來保存模型
  • def import_variables(self, path = None) 恢復(fù) tensorflow模型奈揍,相當(dāng)于restore_tensorflow()
  • def set_parameters(self, lists) 從jaon 文件中加載parameters
  • def get_parameters(self, mode = "numpy") 獲取模型參數(shù)也就是每個實(shí)體的embedding
  • def set_model(model) 表示使用什么model進(jìn)行knowledge embedding
  • def set_log_on(flag = 1) 如果設(shè)置為1那么表示會打印loss函數(shù)

The framework will print loss values during training if flag = 1

class Config(object):
        
    #To set the learning rate
    def set_alpha(alpha = 0.001)
    
    #To set the degree of the regularization on the parameters
    def set_lmbda(lmbda = 0.0)
    
    #To set the gradient descent optimization algorithm (SGD, Adagrad, Adadelta, Adam)
    def set_optimizer(optimizer = "SGD")
    
    #To set the data traversing rounds
    def set_train_times(self, times)
    
    #To split the training triples into several batches, nbatches is the number of batches
    def set_nbatches(nbatches = 100)
    
    #To set the margin for the loss function
    def set_margin(margin = 1.0)
    
    #To set the dimensions of the entities and relations at the same time
    def set_dimension(dim)
    
    #To set the dimensions of the entities
    def set_ent_dimension(self, dim)
    
    #To set the dimensions of the relations
    def set_rel_dimension(self, dim)
    
    #To allocate threads for each batch sampling
    def set_work_threads(threads = 1)
    
    #To set negative sampling algorithms, unif (bern = 0) or bern (bern = 1)
    def set_bern(bern = 1)
    
    #For each positive triple, we construct rate negative triples by corrupt the entity
    def set_ent_neg_rate(rate = 1)
    
    #For each positive triple, we construct rate negative triples by corrupt the relation
    def set_rel_neg_rate(rate = 0)
    
    #To sample a batch of training triples, including positive and negative ones.
    def sampling()

    #To import dataset from the benchmark folder
    def set_in_path(self, path)
    
    #To export model parameters to json files when training completed
    def set_out_files(self, path)
    
    #To set the import files, all parameters can be restored from the import files
    def set_import_files(self, path)
    
    #To set the export file of model paramters, and export results every few rounds
    def set_export_files(self, path, steps = 0)

    #To export results every few rounds
    def set_export_steps(self, steps)

    #To save model via torch.save()
    def save_pytorch(self)

    #To restore model via torch.load()
    def restore_pytorch(self)

    #To export model paramters, when path is none, equivalent to save_tensorflow()
    def export_variables(self, path = None)

    #To import model paramters, when path is none, equivalent to restore_tensorflow()
    def import_variables(self, path = None)
    
    #To export model paramters to designated path
    def save_parameters(self, path = None)

    #To manually load parameters which are read from json files
    def set_parameters(self, lists)
    
    #To get model paramters, if using mode "numpy", you can get np.array , else you can get python lists
    def get_parameters(self, mode = "numpy")

    #To set the knowledge embedding model
    def set_model(model)
    
    #The framework will print loss values during training if flag = 1
    def set_log_on(flag = 1)

    #This is essential when testing
    def test_link_prediction(True)
    def test_triple_classification(True)
模型的接口
class Model(object)

    # in_batch = True, return [positive_head, positive_tail, positive_relation]
    # The shape of positive_head is [batch_size, 1]
    # in_batch = False, return [positive_head, positive_tail, positive_relation]
    # The shape of positive_head is [batch_size]
    get_positive_instance(in_batch = True)
    
    # in_batch = True, return [negative_head, negative_tail, negative_relation]
    # The shape of positive_head is [batch_size, negative_ent_rate + negative_rel_rate]
    # in_batch = False, return [negative_head, negative_tail, negative_relation]
    # The shape of positive_head is [(negative_ent_rate + negative_rel_rate) * batch_size]      
    get_negative_instance(in_batch = True)

    # in_batch = True, return all training instances with the shape [batch_size, (1 + negative_ent_rate + negative_rel_rate)]
    # in_batch = False, return all training instances with the shape [(negative_ent_rate + negative_rel_rate + 1) * batch_size]
    def get_all_instance(in_batch = False)

    # in_batch = True, return all training labels with the shape [batch_size, (1 + negative_ent_rate + negative_rel_rate)]
    # in_batch = False, return all training labels with the shape [(negative_ent_rate + negative_rel_rate + 1) * batch_size]
    # The positive triples are labeled as 1, and the negative triples are labeled as -1
    def get_all_labels(in_batch = False)
    
    #To calulate the loss
    def forward(self)

    # To define loss functions for knowledge embedding models
    def loss_func()
    
    # To define the prediction functions for knowledge embedding models
    def predict(self)

    def __init__(config)

#The implementation for TransE
class TransE(Model)

#The implementation for TransH  
class TransH(Model)

#The implementation for TransR
class TransR(Model)

#The implementation for TransD
class TransD(Model)

#The implementation for RESCAL
class RESCAL(Model)

#The implementation for DistMult
class DistMult(Model)                   

#The implementation for ComplEx
class ComplEx(Model)
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末曲尸,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子男翰,更是在濱河造成了極大的恐慌另患,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,941評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件蛾绎,死亡現(xiàn)場離奇詭異昆箕,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)租冠,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,397評論 3 395
  • 文/潘曉璐 我一進(jìn)店門鹏倘,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事≡で眩” “怎么了斋陪?”我有些...
    開封第一講書人閱讀 165,345評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經(jīng)常有香客問我,道長,這世上最難降的妖魔是什么涉馅? 我笑而不...
    開封第一講書人閱讀 58,851評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮黄虱,結(jié)果婚禮上稚矿,老公的妹妹穿的比我還像新娘。我一直安慰自己捻浦,他們只是感情好晤揣,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,868評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著朱灿,像睡著了一般昧识。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上盗扒,一...
    開封第一講書人閱讀 51,688評論 1 305
  • 那天跪楞,我揣著相機(jī)與錄音,去河邊找鬼侣灶。 笑死甸祭,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的褥影。 我是一名探鬼主播池户,決...
    沈念sama閱讀 40,414評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了校焦?” 一聲冷哼從身側(cè)響起赊抖,我...
    開封第一講書人閱讀 39,319評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎寨典,沒想到半個月后熏迹,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,775評論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡凝赛,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,945評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了坛缕。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片墓猎。...
    茶點(diǎn)故事閱讀 40,096評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖赚楚,靈堂內(nèi)的尸體忽然破棺而出毙沾,到底是詐尸還是另有隱情,我是刑警寧澤宠页,帶...
    沈念sama閱讀 35,789評論 5 346
  • 正文 年R本政府宣布左胞,位于F島的核電站,受9級特大地震影響举户,放射性物質(zhì)發(fā)生泄漏烤宙。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,437評論 3 331
  • 文/蒙蒙 一俭嘁、第九天 我趴在偏房一處隱蔽的房頂上張望躺枕。 院中可真熱鬧,春花似錦供填、人聲如沸拐云。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,993評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽叉瘩。三九已至,卻和暖如春粘捎,著一層夾襖步出監(jiān)牢的瞬間薇缅,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,107評論 1 271
  • 我被黑心中介騙來泰國打工晌端, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留捅暴,地道東北人。 一個月前我還...
    沈念sama閱讀 48,308評論 3 372
  • 正文 我出身青樓咧纠,卻偏偏與公主長得像蓬痒,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子漆羔,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,037評論 2 355

推薦閱讀更多精彩內(nèi)容

  • 原文 機(jī)器學(xué)習(xí)(Machine Learning, ML)是一門多領(lǐng)域交叉學(xué)科梧奢,涉及概率論狱掂、統(tǒng)計學(xué)、逼近論亲轨、凸分析...
    readilen閱讀 3,898評論 0 41
  • 該文章為轉(zhuǎn)載文章趋惨,作者簡介:汪劍,現(xiàn)在在出門問問負(fù)責(zé)推薦與個性化惦蚊。曾在微軟雅虎工作器虾,從事過搜索和推薦相關(guān)工作。 T...
    名字真的不重要閱讀 5,269評論 0 3
  • 前面的文章主要從理論的角度介紹了自然語言人機(jī)對話系統(tǒng)所可能涉及到的多個領(lǐng)域的經(jīng)典模型和基礎(chǔ)知識蹦锋。這篇文章兆沙,甚至之后...
    我偏笑_NSNirvana閱讀 13,913評論 2 64
  • 剛上知乎看了一個討論舉報作弊正確與否的問題。 很想寫個評論或者答案莉掂,沒有寫葛圃。因?yàn)檫@個問題答案太多了,別人根本看不到...
    LackingDopamine閱讀 493評論 0 3
  • 明天就要開學(xué)了憎妙,思緒萬千库正,翻看著課程安排,心里是著急的厘唾,一下子加了這么多科目褥符,孩子們能適應(yīng)嗎?盡管暑假布置了...
    皮_小皮閱讀 223評論 3 6