代碼智能:問題與解法
在基于預訓練大模型引發(fā)自然語言處理革命的今天八毯,代碼智能技術也在迅速跟進發(fā)展话速。
那么泊交,代碼智能主要在做一些什么樣的事情呢云石?可能很多同學會有比較科幻的想法汹忠,比如程序員要失業(yè)了之類的。
但是铅乡,其實很多工作并沒有那么神秘隆判,非常基礎咬腕。那么我們用代碼智能要解決什么問題呢纽帖?
- 判斷兩段代碼是不是實現(xiàn)相似的功能
- 搜索跟當前代碼段最相似的代碼
- 檢測代碼是否有bug
- 自動修復代碼中的bug
- 給一段代碼自動寫注釋
- 根據(jù)文本推薦最相似的代碼段
- 根據(jù)文本生成代碼
看了之后是不是覺得更玄幻了懊直?這么困難的問題怎么搞得定?
誠實地講融撞,這其中的每個子問題都很困難,就算是人類學習起來也很困難致扯。
不過急前,正像是人類也是一步一步學會的一樣刨摩,機器也在不斷地進步澡刹。我們需要的不一定是萬能的機器神陆赋,也是和我們一樣普通的機器人,它們有很大的局限灾锯,但是它們可以幫助我們減輕不少工作量顺饮。
而且吟逝,最后一節(jié)我們將揭曉澎办,處理這么多如此復雜問題的方法,卻非常簡單琅绅,一把梭哈,我們只用一個模型就能搞定澎羞。
下面我們就詳細看一看這些問題的細節(jié)敛苇。
問題:克隆檢測 Clone Detection
萬地高樓平地起妆绞,代碼智能任務首先從克隆檢測做起。
所謂克隆檢測枫攀,就是尋找寫法和功能上相似的代碼括饶。
不要小看代碼重復,它會顯著地降低代碼智能訓練的有效性来涨。
我們看下圖,訓練集中有重復蹦掐,測試集中有重復技羔,它們的交集中仍然有重復,在論文《The Adverse Effects of Code Duplication in Machine Learning Models of Code》中有詳細的分析卧抗。
預測兩段代碼是否相似
以下的例子來自BigCloneBench數(shù)據(jù)集. 論文地址在:https://arxiv.org/pdf/2002.08653.pdf
下面我們舉幾個例子來看什么算相似:
代碼1:
private StringBuffer encoder(String arg) {
if (arg == null) {
arg = "";
}
MessageDigest md5 = null;
try {
md5 = MessageDigest.getInstance("MD5");
md5.update(arg.getBytes(SysConstant.charset));
} catch (Exception e) {
e.printStackTrace();
}
return toHex(md5.digest());
}
代碼2:
public String kodetu(String testusoila) {
MessageDigest md = null;
try {
md = MessageDigest.getInstance("SHA");
md.update(testusoila.getBytes("UTF-8"));
} catch (NoSuchAlgorithmException e) {
new MezuLeiho("Ez da zifraketa algoritmoa aurkitu", "Ados", "Zifraketa Arazoa", JOptionPane.ERROR_MESSAGE);
e.printStackTrace();
} catch (UnsupportedEncodingException e) {
new MezuLeiho("Errorea kodetzerakoan", "Ados", "Kodeketa Errorea", JOptionPane.ERROR_MESSAGE);
e.printStackTrace();
}
byte raw[] = md.digest();
String hash = (new BASE64Encoder()).encode(raw);
return hash;
}
代碼2的字符串是用巴斯克語寫的藤滥。它們用的算法也有區(qū)別,判空和異常處理也有不同颗味,但是我們認為它們是很類似的超陆,屬于克隆識別認為相同或高度相似的。
我們再看一對例子:
代碼1:
public static void test(String args[]) {
int trace;
int bytes_read = 0;
int last_contentLenght = 0;
try {
BufferedReader reader;
URL url;
url = new URL(args[0]);
URLConnection istream = url.openConnection();
last_contentLenght = istream.getContentLength();
reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
System.out.println(url.toString());
String line;
trace = t2pNewTrace();
while ((line = reader.readLine()) != null) {
bytes_read = bytes_read + line.length() + 1;
t2pProcessLine(trace, line);
}
t2pHandleEventPairs(trace);
t2pSort(trace, 0);
t2pExportTrace(trace, new String("pngtest2.png"), 1000, 700, (float) 0, (float) 33);
t2pExportTrace(trace, new String("pngtest3.png"), 1000, 700, (float) 2.3, (float) 2.44);
System.out.println("Press any key to contiune read from stream !!!");
System.out.println(t2pGetProcessName(trace, 0));
System.in.read();
istream = url.openConnection();
if (last_contentLenght != istream.getContentLength()) {
istream = url.openConnection();
istream.setRequestProperty("Range", "bytes=" + Integer.toString(bytes_read) + "-");
System.out.println(Integer.toString(istream.getContentLength()));
reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
while ((line = reader.readLine()) != null) {
System.out.println(line);
t2pProcessLine(trace, line);
}
} else System.out.println("File not changed !");
t2pDeleteTrace(trace);
} catch (MalformedURLException e) {
System.out.println("MalformedURLException !!!");
} catch (IOException e) {
System.out.println("File not found " + args[0]);
}
;
}
代碼2:
private static String loadUrlToString(String a_url) throws IOException {
URL l_url1 = new URL(a_url);
BufferedReader br = new BufferedReader(new InputStreamReader(l_url1.openStream()));
String l_content = "";
String l_ligne = null;
l_content = br.readLine();
while ((l_ligne = br.readLine()) != null) {
l_content += AA.SL + l_ligne;
}
return l_content;
}
這個雖然沒有涉及小語種,但是明顯代碼長度差異巨大时呀。不過张漂,我們仍然認為它們是相似的。
我們看一對不相似的吧:
代碼1:
private void setNodekeyInJsonResponse(String service) throws Exception {
String filename = this.baseDirectory + service + ".json";
Scanner s = new Scanner(new File(filename));
PrintWriter fw = new PrintWriter(new File(filename + ".new"));
while (s.hasNextLine()) {
fw.println(s.nextLine().replaceAll("NODEKEY", this.key));
}
s.close();
fw.close();
(new File(filename + ".new")).renameTo(new File(filename));
}
代碼2:
public void transform(String style, String spec, OutputStream out) throws IOException {
URL url = new URL(rootURL, spec);
InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));
transform(style, in, out);
in.close();
}
不相似的就不解釋了谨娜。
BigCloneBench數(shù)據(jù)集航攒,就是提供了兩段代碼,以及它們是否相似的人工打標的結果趴梢。
數(shù)據(jù)分為train.txt, valid.txt, test.txt三個集合漠畜,它們的格式都是同樣的:
idx1 idx2 0/1
其中idx1和idx2是兩段代碼在data.jsonl中的索引值,最后一個是它們是否相似的人工打標的值坞靶。
代碼都保存在data.jsonl中憔狞,格式為:
{"func":"代碼","idx":"idx值"}
我們以訓練集train.txt為例,其前兩行是這樣的:
13988825 8660836 0
80378 18548122 1
13988825在data.jsonl中對應的結構是這樣的:
{"func": " private void setNodekeyInJsonResponse(String service) throws Exception {\n String filename = this.baseDirectory + service + \".json\";\n Scanner s = new Scanner(new File(filename));\n PrintWriter fw = new PrintWriter(new File(filename + \".new\"));\n while (s.hasNextLine()) {\n fw.println(s.nextLine().replaceAll(\"NODEKEY\", this.key));\n }\n s.close();\n fw.close();\n (new File(filename + \".new\")).renameTo(new File(filename));\n }\n", "idx": "13988825"}
8660836對應的是:
{"func": " public void transform(String style, String spec, OutputStream out) throws IOException {\n URL url = new URL(rootURL, spec);\n InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n transform(style, in, out);\n in.close();\n }\n", "idx": "8660836"}
而它們的結果是不相似彰阴。大家看到瘾敢,這個例子就是剛才上面我們寫的第三個例子。
搜索跟當前代碼段語義最相似的代碼段
這個我們使用北大李戈李師團隊的POJ-104數(shù)據(jù)集尿这。
這個數(shù)據(jù)集需要到https://drive.google.com/uc?id=0B2i-vWnOu7MxVlJwQXN6eVNONUU去下載簇抵。
每個代碼段用一個index來描述,然后code字段是完整的代碼射众。我們來看個例子:
{
"label":"1",
"index":"0",
"code":"
int f(int a,int x)
{
int count=1,i;
for(i=x;i<a;i++)
if(a%i==0)
count+=f(a/i,i);
if(i==a)
return count;
else
return 0;
}
void main()
{
int n,a;
scanf(\"%d\",&n);
for(;n>0;n--)
{
scanf(\"%d\",&a);
if(a==1||a==2)
printf(\"1\
\");
else
printf(\"%d\
\",f(a,2));
}
}
"
}
然后碟摆,這個任務的目的就是求出針對某一段代碼最相似的代碼段。以取top 2為例:輸出的樣例如下:
{"index": "0", "answers": ["3", "2"]}
{"index": "1", "answers": ["0", "4"]}
{"index": "2", "answers": ["0", "1"]}
{"index": "4", "answers": ["1", "5"]}
{"index": "3", "answers": ["4", "2"]}
{"index": "5", "answers": ["4", "3"]}
也就是說叨橱,針對于代碼index 0, 最相似的代碼段是 index 3和2.
index 3是這樣的:
void qut(int a,int b); //????
int num=0; //?????????
int main()
{
int i,n,g[1000]; //?????????
cin>>n;
for(i=0;i<n;i++) //??????
cin>>g[i];
for(i=0;i<n;i++)
{
qut(g[i],1); //????
cout<<num<<endl;
num=0;
}
return 0;
}
void qut(int a,int b)
{
int i;
if (a>=b)
{
num++;
if (b==1)
b++;
for (i=b;i<=a;i++)
{
if (a%i==0)
{
qut(a/i,i); //??a%i==0,??
}
}
}
}
問題:缺陷檢測
缺陷檢測的數(shù)據(jù)集非常簡單粗暴典蜕,就是一段打標的代碼,標識是不是有漏洞雏逾。
我們看個有漏洞的例子:
{
"project":"FFmpeg",
"commit_id":"aba232cfa9b193604ed98f3fa505378d006b1b3b",
"target":1,
"func":"static int r3d_read_rdvo(AVFormatContext *s, Atom *atom)
{
R3DContext *r3d = s->priv_data;
AVStream *st = s->streams[0];
int i;
r3d->video_offsets_count = (atom->size - 8) / 4;
r3d->video_offsets = av_malloc(atom->size);
if (!r3d->video_offsets)
return AVERROR(ENOMEM);
for (i = 0; i < r3d->video_offsets_count; i++) {
r3d->video_offsets[i] = avio_rb32(s->pb);
if (!r3d->video_offsets[i]) {
r3d->video_offsets_count = i;
break;
}
av_dlog(s, \"video offset %d: %#x\
\", i, r3d->video_offsets[i]);
}
if (st->r_frame_rate.num)
st->duration = av_rescale_q(r3d->video_offsets_count,
(AVRational){st->r_frame_rate.den,
st->r_frame_rate.num},
st->time_base);
av_dlog(s, \"duration %\"PRId64\"\
\", st->duration);
return 0;
}
",
"idx":5
}
信息就這么多嘉裤,至于哪行是什么問題郑临,訓練集中沒有栖博。
當然,數(shù)據(jù)集里大部分還是沒有漏洞的厢洞,比如第一條:
{"project": "FFmpeg", "commit_id": "973b1a6b9070e2bf17d17568cbaf4043ce931f51", "target": 0, "func": "static av_cold int vdadec_init(AVCodecContext *avctx)\n\n{\n\n VDADecoderContext *ctx = avctx->priv_data;\n\n struct vda_context *vda_ctx = &ctx->vda_ctx;\n\n OSStatus status;\n\n int ret;\n\n\n\n ctx->h264_initialized = 0;\n\n\n\n /* init pix_fmts of codec */\n\n if (!ff_h264_vda_decoder.pix_fmts) {\n\n if (kCFCoreFoundationVersionNumber < kCFCoreFoundationVersionNumber10_7)\n\n ff_h264_vda_decoder.pix_fmts = vda_pixfmts_prior_10_7;\n\n else\n\n ff_h264_vda_decoder.pix_fmts = vda_pixfmts;\n\n }\n\n\n\n /* init vda */\n\n memset(vda_ctx, 0, sizeof(struct vda_context));\n\n vda_ctx->width = avctx->width;\n\n vda_ctx->height = avctx->height;\n\n vda_ctx->format = 'avc1';\n\n vda_ctx->use_sync_decoding = 1;\n\n vda_ctx->use_ref_buffer = 1;\n\n ctx->pix_fmt = avctx->get_format(avctx, avctx->codec->pix_fmts);\n\n switch (ctx->pix_fmt) {\n\n case AV_PIX_FMT_UYVY422:\n\n vda_ctx->cv_pix_fmt_type = '2vuy';\n\n break;\n\n case AV_PIX_FMT_YUYV422:\n\n vda_ctx->cv_pix_fmt_type = 'yuvs';\n\n break;\n\n case AV_PIX_FMT_NV12:\n\n vda_ctx->cv_pix_fmt_type = '420v';\n\n break;\n\n case AV_PIX_FMT_YUV420P:\n\n vda_ctx->cv_pix_fmt_type = 'y420';\n\n break;\n\n default:\n\n av_log(avctx, AV_LOG_ERROR, \"Unsupported pixel format: %d\\n\", avctx->pix_fmt);\n\n goto failed;\n\n }\n\n status = ff_vda_create_decoder(vda_ctx,\n\n avctx->extradata, avctx->extradata_size);\n\n if (status != kVDADecoderNoErr) {\n\n av_log(avctx, AV_LOG_ERROR,\n\n \"Failed to init VDA decoder: %d.\\n\", status);\n\n goto failed;\n\n }\n\n avctx->hwaccel_context = vda_ctx;\n\n\n\n /* changes callback functions */\n\n avctx->get_format = get_format;\n\n avctx->get_buffer2 = get_buffer2;\n\n#if FF_API_GET_BUFFER\n\n // force the old get_buffer to be empty\n\n avctx->get_buffer = NULL;\n\n#endif\n\n\n\n /* init H.264 decoder */\n\n ret = ff_h264_decoder.init(avctx);\n\n if (ret < 0) {\n\n av_log(avctx, AV_LOG_ERROR, \"Failed to open H.264 decoder.\\n\");\n\n goto failed;\n\n }\n\n ctx->h264_initialized = 1;\n\n\n\n return 0;\n\n\n\nfailed:\n\n vdadec_close(avctx);\n\n return -1;\n\n}\n", "idx": 0}
推理搞起來也是十分省事了仇让,就是對應每個index給個0或1的結果:
0 0
1 1
2 1
3 0
4 0
問題:代碼自動修復
有了識別代碼漏洞的,更進一步就是學習自動修復代碼的了躺翻。
代碼自動修復的題目也很簡單丧叽,一段是有bug的代碼,另一段是修復之后的代碼公你。
我們來看一個例子:
有bug的代碼是這樣的:
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; }
修復之后是這樣子的:
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; }
也真難為算法了踊淳,人看起來都有點費事。
問題:代碼互譯
比如實現(xiàn)C#語言和Java語言的互譯。我們只要有一系列代碼的C#寫法和Java寫法迂尝,就可以進行學習進行互譯脱茉。
我們來看一對例子。
先看C#代碼:
public virtual ListSpeechSynthesisTasksResponse ListSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request){
var options = new InvokeOptions();
options.RequestMarshaller = ListSpeechSynthesisTasksRequestMarshaller.Instance;
options.ResponseUnmarshaller = ListSpeechSynthesisTasksResponseUnmarshaller.Instance;
return Invoke<ListSpeechSynthesisTasksResponse>(request, options);
}
對應的Java
public ListSpeechSynthesisTasksResult listSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request) {
request = beforeClientExecution(request);
return executeListSpeechSynthesisTasks(request);
}
問題:給代碼寫注釋
在訓練素材中垄开,有代碼和注釋琴许,這個任務的目的為新代碼寫注釋。評價指標是對于生成的注釋的語言準確度溉躲。
這個我們使用CodeSearchNet數(shù)據(jù)集榜田。
這個數(shù)據(jù)集中的每條記錄的格式如下:
- repo: 倉庫名
- path: 文件名
- func_name: 函數(shù)或方法名
- original_string: 未經處理的源字符串
- language: 編程語言
- code/function: 代碼信息
- code_tokens/function_tokens: 分詞之后的代碼結果
- docstring: 注釋字符串信息
- docstring_tokens: docstring分詞之后的結果
- url: 自然語言的唯一標識號
- idx: 代碼段的唯一標識號
我們來看個例子:
{"repo": "ciena-blueplanet/bunsen-core", "path": "src/reducer.js", "func_name": "", "original_string": "function
(state, action) {\n return _.defaults({\n isValidating: action.isValidating,\n lastAction: IS_VALIDA
TING\n }, state)\n }", "language": "javascript", "code": "function (state, action) {\n return _.defaults({
\n isValidating: action.isValidating,\n lastAction: IS_VALIDATING\n }, state)\n }", "code_tokens":
["function", "(", "state", ",", "action", ")", "{", "return", "_", ".", "defaults", "(", "{", "isValidating", ":"
, "action", ".", "isValidating", ",", "lastAction", ":", "IS_VALIDATING", "}", ",", "state", ")", "}"], "docstrin
g": "Update is validating result\n@param {State} state - state to update\n@param {Action} action - action\n@retur
ns {State} - updated state", "docstring_tokens": ["Update", "is", "validating", "result"], "sha": "993c67e314e2b7
5003a1ff4c2f0cb667715562b2", "url": "https://github.com/ciena-blueplanet/bunsen-core/blob/993c67e314e2b75003a1ff4
c2f0cb667715562b2/src/reducer.js#L394-L399", "partition": "train"}
對于生成的自然語言,我們采用《ORANGE: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation 》論文的方法進行評分锻梳。
問題:為自然語言文本匹配最合適的代碼段
我們仍然使用上一節(jié)的CodeSearchNet數(shù)據(jù)集箭券。
這個搜索的結果類似于下面這樣:
{"url": "url0", "answers": [10,11,12,13,14]}
{"url": "url1", "answers": [10,12,11,13,14]}
{"url": "url2", "answers": [13,11,12,10,14]}
{"url": "url3", "answers": [10,14,12,13,11]}
{"url": "url4", "answers": [10,11,12,13,14]}
配上UI,大致實現(xiàn)的效果是這樣的:
或者是這樣:
問題:根據(jù)自然語言生成代碼
這是終極任務疑枯,就是根據(jù)一段文本描述硬生生地生成一段代碼出來邦鲫。
格式非常簡單,就一段代碼和一段文本神汹。
我們來看個訓練樣本的例子:
{"code": "void function ( Binder arg0 ) { EventBus loc0 = new EventBus ( ) ; AmbariEventPublisher loc1 = new AmbariEventPublisher ( ) ; repla
ceEventBus ( AmbariEventPublisher . class , loc1 , loc0 ) ; arg0 . bind ( AmbariEventPublisher . class ) . toInstance ( loc1 ) ; }", "nl": "force the eventb us from ambarievent publisher to be serialand synchronous . concode_field_sep PlaceHolder placeHolder concode_field_sep void registerAlertListeners concode_elem_sep EventBus synchronizeAlertEventPublisher concode_elem_sep void replaceEventBus concode_elem_sep void registerAmbariListeners"}
這NL部分有點亂啊庆捺,沒辦法,為了增加數(shù)據(jù)量屁魏,沒有那么多人手打精確的標滔以。
我們再看一個:
{"code": "byte [ ] function ( Class < ? > arg0 , Configuration arg1 ) { return AuthenticationTokenSerializer . serialize ( org . apache . acc
umulo . core . client . mapreduce . lib . impl . ConfiguratorBase . getAuthenticationToken ( arg0 , arg1 ) ) ; }", "nl": "do n't use this . n
o , really , do n't use this . you already have an authenticationtoken with org.apache.accumulo.core.client.mapreduce.lib.impl.configuratorba
se #getauthenticationtoken class , configuration . you do n't need to construct it yourself . gets the password from the configuration . warn
ing : the password is stored in the configuration and shared with all mapreduce tasks ; it is base64 encoded to provide a charset safe conver
sion to a string , and is not intended to be secure . concode_field_sep PlaceHolder placeHolder concode_field_sep String getPrincipal concode
_elem_sep void setLogLevel concode_elem_sep Level getLogLevel concode_elem_sep Boolean isConnectorInfoSet concode_elem_sep String getTokenCla
ss concode_elem_sep void setZooKeeperInstance concode_elem_sep void setMockInstance concode_elem_sep Instance getInstance concode_elem_sep St
ring enumToConfKey concode_elem_sep void setConnectorInfo"}
是不是質量也沒好到哪兒去?這就是CONCODE數(shù)據(jù)集的樣子氓拼。
解法:基于大規(guī)模預訓練模型的多任務學習
402年前你画,當努爾哈赤面臨明朝多路大軍的圍困的時候,采取了“憑你幾路來桃漾,我只一路去”的戰(zhàn)術贏得了薩爾滸之戰(zhàn)的立國之戰(zhàn)坏匪。
我們同樣學習古人的智慧,任你數(shù)據(jù)集千變萬化撬统,我們的工具就只用一個 - 大規(guī)模預訓練模型适滓。
下面是預訓練模型的簡要發(fā)展史:
以開頭我們展示的微軟的codebert模型為例,我們要處理上面最復雜的代碼生成任務恋追,只要一條命令就可以搞定:
python -m torch.distributed.launch --nproc_per_node=$PER_NODE_GPU run.py \
--data_dir=$DATADIR \
--langs=$LANG \
--output_dir=$OUTPUTDIR \
--pretrain_dir=$PRETRAINDIR \
--log_file=$LOGFILE \
--model_type=gpt2 \
--block_size=512 \
--do_train \
--node_index 0 \
--gpu_per_node $PER_NODE_GPU \
--learning_rate=5e-5 \
--weight_decay=0.01 \
--evaluate_during_training \
--per_gpu_train_batch_size=6 \
--per_gpu_eval_batch_size=12 \
--gradient_accumulation_steps=2 \
--num_train_epochs=30 \
--logging_steps=100 \
--save_steps=5000 \
--overwrite_output_dir \
--seed=42
如果使用兩張2 NVIDIA P100 GPU卡的話凭迹,22小時左右就可以訓練完。
推理呢苦囱,也是一條語句就搞定:
python -u run.py \
--data_dir=$DATADIR \
--langs=$LANG \
--output_dir=$OUTPUTDIR \
--pretrain_dir=$PRETRAINDIR \
--log_file=$LOGFILE \
--model_type=gpt2 \
--block_size=512 \
--do_infer \
--logging_steps=100 \
--seed=42
只用一張P100卡嗅绸,大約40分鐘就可以搞定。
有了上面的基礎撕彤,我們就可以去打比賽啦鱼鸠。上面介紹的數(shù)據(jù)集,全都是比賽的賽題:
上面提到的數(shù)據(jù)集,可以在https://github.com/microsoft/CodeXGLUE下載到蚀狰。
歡迎來到代碼智能的世界漆弄!
附錄:快速上手指南
放翁云:紙上得來終覺淺,絕知此事要躬行造锅。
下面我們就落地下撼唾,將代碼智能模型的訓練和推理跑起來~~~
- 第一步:安裝transformers框架,因為codebert是基于這個框架的:
pip install transformers --user
- 第二步:安裝PyTorch或者Tensorflow作為Transformers的后端哥蔚,以2021年7月5日這個時間點倒谷,需要的PyTorch版本至少是1.5.0以上。驅動能搞定的話糙箍,索性就安裝最新的吧:
pip install torch torchvision torchtext torchaudio --user
- 第三步渤愁,下載微軟的數(shù)據(jù)集
git clone https://github.com/microsoft/CodeXGLUE
- 第四步,我們先玩玩BigCloneBench吧
到Code-Code/Clone-detection-BigCloneBench/code目錄下深夯,運行:
python run.py --output_dir=./saved_models --model_type=roberta --config_name=microsoft/codebert-base --model_name_or_path=microsoft/codebert-base --tokenizer_name=roberta-base --do_train --train_data_file=../dataset/train.txt --eval_data_file=../dataset/valid.txt --test_data_file=../dataset/test.txt --epoch 2 --block_size 400 --train_batch_size 16 --eval_batch_size 32 --learning_rate 5e-5 --max_grad_norm 1.0 --evaluate_during_training --seed 123456 2>&1| tee train.log
然后訓練就運行起來了:
07/05/2021 16:29:24 - INFO - __main__ - ***** Running training *****
07/05/2021 16:29:24 - INFO - __main__ - Num examples = 90102
07/05/2021 16:29:24 - INFO - __main__ - Num Epochs = 2
07/05/2021 16:29:24 - INFO - __main__ - Instantaneous batch size per GPU = 8
07/05/2021 16:29:24 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 16
07/05/2021 16:29:24 - INFO - __main__ - Gradient Accumulation steps = 1
07/05/2021 16:29:24 - INFO - __main__ - Total optimization steps = 11264
在兩張V100卡大約需要訓練40分鐘左右抖格。
訓練之后是驗證,會將目前最好的結果保存到checkpoint中以備推理時使用
07/05/2021 17:10:04 - INFO - __main__ - ***** Running evaluation ***** 40950/41541 [00:10<00:00, 2785.61it/s]
07/05/2021 17:10:04 - INFO - __main__ - Num examples = 41541
07/05/2021 17:10:04 - INFO - __main__ - Batch size = 32
07/05/2021 17:16:05 - INFO - __main__ - ***** Eval results *****
07/05/2021 17:16:05 - INFO - __main__ - eval_f1 = 0.9531
07/05/2021 17:16:05 - INFO - __main__ - eval_precision = 0.9579
07/05/2021 17:16:05 - INFO - __main__ - eval_recall = 0.9484
07/05/2021 17:16:05 - INFO - __main__ - eval_threshold = 0.97
07/05/2021 17:16:06 - INFO - __main__ - ********************
07/05/2021 17:16:06 - INFO - __main__ - Best f1:0.9531
07/05/2021 17:16:06 - INFO - __main__ - ********************
07/05/2021 17:16:08 - INFO - __main__ - Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
一次訓練兩輪咕晋,第二輪效果提升到0.97多:
07/05/2021 17:56:43 - INFO - __main__ - ***** Running evaluation ***** 40950/41541 [00:12<00:00, 3535.62it/s]
07/05/2021 17:56:43 - INFO - __main__ - Num examples = 41541
07/05/2021 17:56:43 - INFO - __main__ - Batch size = 32
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
07/05/2021 18:02:44 - INFO - __main__ - ***** Eval results *****
07/05/2021 18:02:44 - INFO - __main__ - eval_f1 = 0.9701
07/05/2021 18:02:44 - INFO - __main__ - eval_precision = 0.9772
07/05/2021 18:02:44 - INFO - __main__ - eval_recall = 0.9633
07/05/2021 18:02:44 - INFO - __main__ - eval_threshold = 0.97
07/05/2021 18:02:45 - INFO - __main__ - ********************
07/05/2021 18:02:45 - INFO - __main__ - Best f1:0.9701
07/05/2021 18:02:45 - INFO - __main__ - ********************
07/05/2021 18:02:47 - INFO - __main__ - Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
然后我們用訓好的模型進行推理吧:
python run.py \
--output_dir=./saved_models \
--model_type=roberta \
--config_name=microsoft/codebert-base \
--model_name_or_path=microsoft/codebert-base \
--tokenizer_name=roberta-base \
--do_eval \
--do_test \
--train_data_file=../dataset/train.txt \
--eval_data_file=../dataset/valid.txt \
--test_data_file=../dataset/test.txt \
--epoch 2 \
--block_size 400 \
--train_batch_size 16 \
--eval_batch_size 32 \
--learning_rate 5e-5 \
--max_grad_norm 1.0 \
--evaluate_during_training \
--seed 123456 2>&1| tee test.log
最后我們運行evaluator.py來查看測試結果:
python ../evaluator/evaluator.py -a ../dataset/test.txt -p saved_models/predictions.txt
輸出如下:
{'Recall': 0.9677421599288263, 'Prediction': 0.9557057904236594, 'F1': 0.9616080550111168}
準確率0.956, 召回率0.968雹拄,還不錯~
跟CodeXGLUE的排行榜比一比:
跟榜上的CodeBert的結果基本一致
GraphCodeBert
要提升性能,我們可以用GraphCodeBert來替換CodeBert.
我們先下載GraphCodeBert的代碼:
git clone https://github.com/microsoft/CodeBERT
然后轉到GraphCodeBERT/clonedetection目錄掌呜,解壓dataset.zip:
unzip dataset.zip
然后就可以像訓練codebert一樣訓練graphcodebert了:
mkdir saved_models
python run.py \
--output_dir=saved_models \
--config_name=microsoft/graphcodebert-base \
--model_name_or_path=microsoft/graphcodebert-base \
--tokenizer_name=microsoft/graphcodebert-base \
--do_train \
--train_data_file=dataset/train.txt \
--eval_data_file=dataset/valid.txt \
--test_data_file=dataset/test.txt \
--epoch 1 \
--code_length 512 \
--data_flow_length 128 \
--train_batch_size 16 \
--eval_batch_size 32 \
--learning_rate 2e-5 \
--max_grad_norm 1.0 \
--evaluate_during_training \
--seed 123456 2>&1| tee saved_models/train.log
上面的參數(shù)是按4個V100 GPU來調的滓玖,如果只有兩塊V100,可以將--code_length改成256.
CodeBert 40分鐘左右一輪质蕉,GraphCodeBert大約需要6個半小時一輪势篡。
然后我們進行推理:
python run.py --output_dir=saved_models --config_name=microsoft/graphcodebert-base --model_name_or_path=microsoft/graphcodebert-base --tokenizer_name=microsoft/graphcodebert-base --do_eval --do_test --train_data_file=dataset/train.txt --eval_data_file=dataset/valid.txt --test_data_file=dataset/test.txt --epoch 1 --code_length 256 --data_flow_length 128 --train_batch_size 16 --eval_batch_size 32 --learning_rate 2e-5 --max_grad_norm 1.0 --evaluate_during_training --seed 123456 2>&1| tee saved_models/test.log
最后我們解讀一下結果吧:
python evaluator/evaluator.py -a dataset/test.txt -p saved_models/predictions.txt 2>&1| tee saved_models/score.log
結果如下:
{'Recall': 0.9589415798936043, 'Prediction': 0.962620653900429, 'F1': 0.9607703728051462}