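The concrete use case: converting a BERT checkpoint (CKPT) into weights that a PyTorch model can use. The script reads every variable from the TensorFlow checkpoint, transposes weight matrices into PyTorch's layout, and saves the result to an HDF5 file with deepdish.
Save the code below as transfer.py.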
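For example, take a trained checkpoint saved at step 2543 (prefix model.ckpt-2543).
Run: python transfer.py model.ckpt-2543 model/rf-bert
The first argument is the checkpoint from training step 2543. Pass the shared prefix, without a suffix such as .index or .data-00000-of-00001. The second argument is the output path, so the converted weights are written to model/rf-bert.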
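A TF1 checkpoint is stored as several files that share one prefix, and transfer.py expects that prefix. If a training directory holds many checkpoints, a small helper can pick the newest one. The sketch below uses a hypothetical function, latest_ckpt_prefix, which is not part of the original post. It assumes the usual model.ckpt-<step>.index file naming.

```python
import re

# Hypothetical helper (not from the original post): choose the newest checkpoint
# prefix from a list of file names, e.g. the result of os.listdir(ckpt_dir).
# Assumes TF1-style names such as model.ckpt-2543.index; the matching
# .data-* and .meta files share the same prefix.
_CKPT_INDEX = re.compile(r'^(?P<prefix>.+\.ckpt-(?P<step>\d+))\.index$')


def latest_ckpt_prefix(filenames):
    """Return the prefix with the highest step number, or None if there is none."""
    best = None  # (step, prefix)
    for name in filenames:
        m = _CKPT_INDEX.match(name)
        # compare steps numerically, so step 100 beats step 99
        if m and (best is None or int(m.group('step')) > best[0]):
            best = (int(m.group('step')), m.group('prefix'))
    return best[1] if best else None


if __name__ == '__main__':
    files = ['checkpoint',
             'model.ckpt-1000.index', 'model.ckpt-1000.meta',
             'model.ckpt-2543.index', 'model.ckpt-2543.meta',
             'model.ckpt-2543.data-00000-of-00001']
    print(latest_ckpt_prefix(files))  # model.ckpt-2543
```

In practice you would call it on os.listdir(ckpt_dir) and join the result with ckpt_dir. TensorFlow's own tf.train.latest_checkpoint does the same job, but it depends on the `checkpoint` bookkeeping file being present in the directory.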
transfer.py:
import tensorflow as tf
import deepdish as dd
import argparse
import os
import numpy as np
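def tr(name, v):
    # Convert TensorFlow weight layouts to PyTorch weight layouts.
    if v.ndim == 4:
        # conv kernel: TF [H, W, in, out] -> PyTorch [out, in, H, W]
        return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
    elif v.ndim == 2 and name.endswith('kernel'):
        # dense kernel: TF [in, out] -> PyTorch nn.Linear weight [out, in].
        # Other 2-D BERT variables, e.g. the embedding tables [vocab, hidden]
        # and the classifier output_weights [num_labels, hidden], already
        # match PyTorch's layout and must not be transposed.
        return np.ascontiguousarray(v.transpose())
    return v

def read_ckpt(ckpt):
    # https://github.com/tensorflow/tensorflow/issues/1823
    # tf.train.load_checkpoint is also available in TF 2.x, where
    # tf.train.NewCheckpointReader only exists as tf.compat.v1.train.NewCheckpointReader.
    reader = tf.train.load_checkpoint(ckpt)
    weights = {n: reader.get_tensor(n) for n in reader.get_variable_to_shape_map()}
    pyweights = {k: tr(k, v) for (k, v) in weights.items()}
    return pyweights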
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
    parser.add_argument("infile", type=str,
                        help="Path to the ckpt prefix, e.g. model.ckpt-22177.")
    parser.add_argument("outfile", type=str, nargs='?', default='',
                        help="Output file (inferred if missing).")
    args = parser.parse_args()
    if args.outfile == '':
        # note: splitext treats ".ckpt-22177" as the extension,
        # so model.ckpt-22177 becomes model.h5
        args.outfile = os.path.splitext(args.infile)[0] + '.h5'
    outdir = os.path.dirname(args.outfile)
    # dirname is '' for a bare file name, and os.makedirs('') would raise
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    weights = read_ckpt(args.infile)
    dd.io.save(args.outfile, weights)
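The transposes in tr() follow from how each framework stores weights. Here is a numpy-only sketch that checks this. It assumes the standard layouts: TF dense kernel [in, out], TF conv kernel [H, W, in, out], PyTorch nn.Linear weight [out, in], and PyTorch nn.Conv2d weight [out, in, H, W]. The function names are illustrative only. The sketch confirms two things: a transposed dense kernel gives the same layer output, and each conv element lands at the expected index.

```python
import numpy as np


def dense_tf_to_torch(kernel):
    # TF dense kernel [in, out] -> PyTorch nn.Linear weight [out, in]
    return np.ascontiguousarray(kernel.T)


def conv_tf_to_torch(kernel):
    # TF conv kernel [H, W, in, out] -> PyTorch nn.Conv2d weight [out, in, H, W]
    return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))


rng = np.random.default_rng(0)

# Dense layer: TF computes x @ kernel, PyTorch computes x @ weight.T,
# so both must give the same output for the same input.
x = rng.standard_normal((4, 8))               # batch of 4, 8 input features
tf_kernel = rng.standard_normal((8, 3))       # TF layout [in=8, out=3]
torch_weight = dense_tf_to_torch(tf_kernel)   # PyTorch layout [out=3, in=8]
assert torch_weight.shape == (3, 8)
assert np.allclose(x @ tf_kernel, x @ torch_weight.T)

# Conv kernel: element [h, w, i, o] in TF must land at [o, i, h, w] in PyTorch.
tf_conv = rng.standard_normal((3, 5, 2, 7))   # [H=3, W=5, in=2, out=7]
torch_conv = conv_tf_to_torch(tf_conv)
assert torch_conv.shape == (7, 2, 3, 5)
assert torch_conv[6, 1, 2, 4] == tf_conv[2, 4, 1, 6]
print("layout checks passed")
```

Embedding tables are a different case. A BERT word-embedding table is [vocab_size, hidden] in TF, which is already the layout of PyTorch's nn.Embedding weight, so it needs no transpose.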
————————————————
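Copyright notice: This is an original article by CSDN blogger 楊楊, published under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.
Original article: https://blog.csdn.net/weixin_42699651/article/details/88932670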