現(xiàn)象
使用 Hugging Face Trainer
在單機(jī)多卡環(huán)境下對(duì) LLAMA2-7B 進(jìn)行 LoRA finetuning 時(shí)力图,在第一次保存 checkpoint 時(shí),程序 assert out毡庆,關(guān)鍵 error trace log 如下
[1711969312.876608051] rank2.python: Reading from remote process' memory failed. Disabling CMA support
[1711969312.876606027] rank4.python: Reading from remote process' memory failed. Disabling CMA support
[1711969312.876618213] rank5.python: Reading from remote process' memory failed. Disabling CMA support
rank5: Assertion failure at psm3/ptl_am/ptl.c:196: nbytes == req->req_data.recv_msglen
rank2: Assertion failure at psm3/ptl_am/ptl.c:196: nbytes == req->req_data.recv_msglen
根因
順藤摸瓜
-
accelerate
的 FSDP 在保存 checkpoint 時(shí),會(huì)調(diào)用其自己的save_fsdp_optimizer
方法烙如,該方法首先調(diào)用了 PyTorch 的FSDP.optim_state_dict
方法以獲取并確保每個(gè)rank
上都有其需要的最新的optimizer
的state_dict
么抗,然后根據(jù)相應(yīng)的fsdp_state_dict_type
設(shè)置將其保存。Assert out 就發(fā)生在FSDP.optim_state_dict
調(diào)用中厅翔。 - 找到 PyTorch
FSDP.optim_state_dict
的實(shí)現(xiàn)乖坠,發(fā)現(xiàn) assert out 發(fā)生在調(diào)用FullyShardedDataParallel._optim_state_dict_impl
時(shí)。 - 再轉(zhuǎn)至
FullyShardedDataParallel._optim_state_dict_impl
的實(shí)現(xiàn)刀闷,發(fā)現(xiàn) assert out 發(fā)生在其調(diào)用_optim_state_dict
時(shí)熊泵。 - 繼續(xù)轉(zhuǎn)至
_optim_state_dict
的實(shí)現(xiàn),發(fā)現(xiàn) assert out 發(fā)生在其調(diào)用_map_param_key_to_optim_keys
時(shí)甸昏。 - 繼續(xù)轉(zhuǎn)至
_map_param_key_to_optim_keys
的實(shí)現(xiàn)顽分,發(fā)現(xiàn) assert out 發(fā)生在調(diào)用dist.broadcast_object_list
。
至此施蜜,瓜已得卒蘸,需分析 dist.broadcast_object_list
。
抽絲剝繭
-
首先需要分析
dist.broadcast_object_list
broadcast 了啥翻默,看下代碼:key_obj_list: List[Optional[List[_OptimStateKey]]] = ( [all_optim_state_keys] if rank == 0 else [None]) dist.broadcast_object_list(key_obj_list, src=0, group=group)
由代碼可知缸沃,broad cast 的是一堆
_OptimStateKey
,而_OptimStateKey
是一個(gè)字符串組成的tuple
修械,每個(gè)字符串里放的是optimizer 中每個(gè)模型參數(shù)的狀態(tài)(即 momentum, variance 等)的 unflat 的 fully qualified name趾牧。這些東西是在 CPU 上的,需要由 rank 0 廣播到其余 rank肯污,以對(duì)齊參數(shù)名翘单。 -
好了,知道數(shù)據(jù)是在 CPU 上的蹦渣,那就知道為啥在 checkpointing 之前是好的了哄芜,因?yàn)榇饲岸际巧婕暗?GPU 上 tensor 的 collective communication,那塊看來是好的柬唯。Intel CPU 和 GPU 平臺(tái)的 collective communication 后端走的是 oneCCL认臊,其中 CPU 上數(shù)據(jù)的單機(jī)多卡 broadcast 走的是什么方案呢?再去看一眼 log:
[1711969312.876608051] rank2.python: Reading from remote process' memory failed. Disabling CMA support
從 log 中大致可以猜出通信方案是 shared memory(SHM)锄奢,不然不會(huì)有
Reading from remote process' memory failed
失晴,這很合理冤议,因?yàn)槭菃螜C(jī)多卡;且采用的 SHM 方案是 CMA(Cross Memory Attach)师坎,這是 Linux 內(nèi)核實(shí)現(xiàn)的一種 kernel assisted zero copy SHM 機(jī)制,示意如下(摘自此論文):
那就是 CMA 出了啥問題堪滨。以上只是猜想胯陋,猜想只是起點(diǎn),總是要實(shí)證袱箱。既然 oneCCL 是集合通信后端遏乔,我們就要分析一下它。從這兒可以知道:oneCCL 有兩個(gè) transport 后端发笔, 即 OFI 和 MPI盟萨。從這兒又可以知道,intel MPI 的實(shí)現(xiàn)現(xiàn)在也基于 OFI 了了讨,而 OFI 的實(shí)現(xiàn)是 libfabric捻激, 如下:
那么我們就去 libfabric 的代碼庫中找找有沒有以下 log 相關(guān)的代碼:[1711969312.876608051] rank2.python: Reading from remote process' memory failed. Disabling CMA support rank2: Assertion failure at psm3/ptl_am/ptl.c:196: nbytes == req->req_data.recv_msglen
然后就從這里到了如下代碼:
size_t nbytes = psm3_cma_get(pid, (void *)req->rts_sbuf, req->req_data.buf, req->req_data.recv_msglen); if (nbytes == -1) { ptl->psmi_kassist_mode = PSMI_KASSIST_OFF; _HFI_ERROR("Reading from remote process' memory failed. Disabling CMA support\n"); } else { psmi_assert_always(nbytes == req->req_data.recv_msglen); cma_succeed = 1; } psmi_assert_always(nbytes == req->req_data.recv_msglen);
從代碼可以看到,
psm3_cma_get
調(diào)用返回錯(cuò)誤,首先觸發(fā)了Reading from remote process' memory failed. Disabling CMA support\n
錯(cuò)誤信息打印前计,隨后又通過psmi_assert_always
assert out 了胞谭,與我們看到的 log 完全一樣。至此男杈,絲抽完了丈屹,已經(jīng)找到問題發(fā)生的地方了。 轉(zhuǎn)到
psm3_cma_get
的實(shí)現(xiàn)代碼伶棒,可知是process_vm_readv
返回錯(cuò)誤了旺垒。查看process_vm_readv
的手冊(cè)可以看到如下表述:
Permission to read from or write to another process is governed by a ptrace access mode PTRACE_MODE_ATTACH_REALCREDS check; see ptrace(2).
因?yàn)?CMA 涉及到進(jìn)程訪問別的進(jìn)程的內(nèi)存,一個(gè)有可能的合理懷疑就是當(dāng)前進(jìn)程沒有權(quán)限訪問另一個(gè)進(jìn)程的內(nèi)存肤无,這個(gè)也通過 CMA patch the commit message 得到了印證先蒋,其中寫道:
Currently mem_read allows only processes who are currently ptrace'ing the target and are still able to ptrace the target to read from the target.
那就上谷歌搜一下 cma ptrace
看下 CMA
需要怎樣的 ptrace
設(shè)置,果然首個(gè)鏈接就找到了答案舅锄。
Same issue... Try the following:
https://groups.io/g/OpenHPC-users/topic/openmpi_and_shared_memory/16489081?p=,,,20,0,0,0::recentpostdate%2Fsticky,,,20,2,0,16489081
$ echo 0 > /proc/sys/kernel/yama/ptrace_scope
or
$ sudo echo 0 > /proc/sys/kernel/yama/ptrace_scope
or
$ echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope
嘗試一下鞭达,搞定!
解法
$ echo 0 > /proc/sys/kernel/yama/ptrace_scope
ptrace_scope
說明見此皇忿。所測(cè)系統(tǒng)之前 ptrace_scope
值是 1
畴蹭。也就是說,非 rank 0 的進(jìn)程要讀 rank 0 的 SHM鳍烁,必須滿足 rank 0 是它們的后代進(jìn)程才行叨襟,這顯然不符合當(dāng)前工作負(fù)載的實(shí)情。所以需要設(shè)成 0
幔荒,以使得主要這些進(jìn)程是同一個(gè) uid
下的就可以讀 SHM糊闽。
最后的話
當(dāng)前來看梳玫,結(jié)果很重要;長遠(yuǎn)來看右犹,過程很重要提澎!這是工程的真諦。