Hands-on RAG: Fine-tuning the moka-ai/m3e Model with DeepSpeed and Contrastive Learning

1. Environment Setup

pip install open-retrievals

2. Using the M3E Model

from retrievals import AutoModelForEmbedding

embedder = AutoModelForEmbedding.from_pretrained('moka-ai/m3e-base', pooling_method='mean')
embedder


sentences = [
    '* Moka 此文本嵌入模型由 MokaAI 訓(xùn)練并開源,訓(xùn)練腳本使用 uniem',
    '* Massive 此文本嵌入模型通過**千萬級**的中文句對數(shù)據(jù)集進(jìn)行訓(xùn)練',
    '* Mixed 此文本嵌入模型支持中英雙語的同質(zhì)文本相似度計算,異質(zhì)文本檢索等功能,未來還會支持代碼檢索,ALL in one'
]

embeddings = embedder.encode(sentences)

for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

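The embeddings can be compared directly, for example with cosine similarity. The snippet below is a minimal sketch, assuming the output of encode can be converted to a NumPy array:

import numpy as np

vectors = np.asarray(embeddings)
# L2-normalize so that the dot product equals cosine similarity
vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
similarity_matrix = vectors @ vectors.T
print(similarity_matrix)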

3. Fine-tuning M3E with DeepSpeed

The training data is still the t2-ranking dataset introduced earlier.
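
Each line of the jsonl file pairs a query with its positive and hard-negative passages. The schema below is only an illustration matching the --positive_key and --negative_key arguments used in the launch script further down; the "query" field name is an assumption:

{"query": "...", "positive": ["..."], "negative": ["...", "..."]}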

  • Save the DeepSpeed configuration as ds_zero2_no_offload.json
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 100,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1e-10
    },

    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 1e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 1e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

The code below is slightly adapted from open-retrievals: the relative imports are changed to package imports. Save the file as embed.py.

"""Embedding fine tune pipeline"""

import logging
import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed

from retrievals import (
    EncodeCollator,
    EncodeDataset,
    PairCollator,
    RetrievalTrainDataset,
    TripletCollator,
)
from retrievals.losses import AutoLoss, InfoNCE, SimCSE, TripletLoss
from retrievals.models.embedding_auto import AutoModelForEmbedding
from retrievals.trainer import RetrievalTrainer

# os.environ["WANDB_LOG_MODEL"] = "false"
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    causal_lm: bool = field(default=False, metadata={'help': "Whether the model is a causal lm or not"})
    lora_path: Optional[str] = field(default=None, metadata={'help': "Lora adapter save path"})


@dataclass
class DataArguments:
    data_name_or_path: str = field(default=None, metadata={"help": "Path to train data"})
    train_group_size: int = field(default=2)
    unfold_each_positive: bool = field(default=False)
    query_max_length: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    document_max_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    query_instruction: str = field(default=None, metadata={"help": "instruction for query"})
    document_instruction: str = field(default=None, metadata={"help": "instruction for document"})
    query_key: str = field(default=None)
    positive_key: str = field(default='positive')
    negative_key: str = field(default='negative')
    is_query: bool = field(default=False)
    encoding_save_file: str = field(default='embed.pkl')

    def __post_init__(self):
        # self.data_name_or_path = 'json'
        self.dataset_split = 'train'
        self.dataset_language = 'default'

        if self.data_name_or_path is not None:
            if not os.path.isfile(self.data_name_or_path) and not os.path.isdir(self.data_name_or_path):
                info = self.data_name_or_path.split('/')
                self.dataset_split = info[-1] if len(info) == 3 else 'train'
                self.data_name_or_path = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
                self.dataset_language = 'default'
                if ':' in self.data_name_or_path:
                    self.data_name_or_path, self.dataset_language = self.data_name_or_path.split(':')


@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    train_type: str = field(default='pairwise', metadata={'help': "train type of point, pair, or list"})
    negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(
        default=False, metadata={"help": "Freeze the parameters of position embeddings"}
    )
    pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
    normalized: bool = field(default=True)
    loss_fn: str = field(default='infonce')
    use_inbatch_negative: bool = field(default=True, metadata={"help": "use documents in the same batch as negatives"})
    remove_unused_columns: bool = field(default=False)
    use_lora: bool = field(default=False)
    use_bnb_config: bool = field(default=False)
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    report_to: Optional[List[str]] = field(
        default="none", metadata={"help": "The list of integrations to report the results and logs to."}
    )


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )
    if training_args.use_bnb_config:
        from transformers import BitsAndBytesConfig

        logger.info('Use quantization bnb config')
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        quantization_config = None

    if training_args.do_train:
        model = AutoModelForEmbedding.from_pretrained(
            model_name_or_path=model_args.model_name_or_path,
            pooling_method=training_args.pooling_method,
            use_lora=training_args.use_lora,
            quantization_config=quantization_config,
        )

        loss_fn = AutoLoss(
            loss_name=training_args.loss_fn,
            loss_kwargs={
                'use_inbatch_negative': training_args.use_inbatch_negative,
                'temperature': training_args.temperature,
            },
        )

        model = model.set_train_type(
            "pairwise",
            loss_fn=loss_fn,
        )

        train_dataset = RetrievalTrainDataset(
            args=data_args,
            tokenizer=tokenizer,
            positive_key=data_args.positive_key,
            negative_key=data_args.negative_key,
        )
        logger.info(f"Total training examples: {len(train_dataset)}")

        trainer = RetrievalTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=TripletCollator(
                tokenizer,
                query_max_length=data_args.query_max_length,
                document_max_length=data_args.document_max_length,
                positive_key=data_args.positive_key,
                negative_key=data_args.negative_key,
            ),
        )

        Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

        trainer.train()
        # trainer.save_model(training_args.output_dir)
        model.save_pretrained(training_args.output_dir)

        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    if training_args.do_encode:
        model = AutoModelForEmbedding.from_pretrained(
            model_name_or_path=model_args.model_name_or_path,
            pooling_method=training_args.pooling_method,
            use_lora=training_args.use_lora,
            quantization_config=quantization_config,
            lora_path=model_args.lora_path,
        )

        max_length = data_args.query_max_length if data_args.is_query else data_args.document_max_length
        logger.info(f'Encoding will be saved in {training_args.output_dir}')

        encode_dataset = EncodeDataset(args=data_args, tokenizer=tokenizer, max_length=max_length, text_key='text')
        logger.info(f"Number of train samples: {len(encode_dataset)}, max_length: {max_length}")

        encode_loader = DataLoader(
            encode_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=EncodeCollator(tokenizer, max_length=max_length, padding='max_length'),
            shuffle=False,
            drop_last=False,
            num_workers=training_args.dataloader_num_workers,
        )

        embeddings = model.encode(encode_loader, show_progress_bar=True, convert_to_numpy=True)
        lookup_indices = list(range(len(encode_dataset)))

        with open(os.path.join(training_args.output_dir, data_args.encoding_save_file), 'wb') as f:
            pickle.dump((embeddings, lookup_indices), f)


if __name__ == "__main__":
    main()

  • Finally, the launch shell script, run.sh
MODEL_NAME="moka-ai/m3e-base"

TRAIN_DATA="/root/kag101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kag101/src/open-retrievals/t2/ft_out"


# loss_fn: infonce, simcse

deepspeed --include localhost:0 embed.py \
  --deepspeed ds_zero2_no_offload.json \
  --output_dir $OUTPUT_DIR \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --do_train \
  --data_name_or_path $TRAIN_DATA \
  --positive_key positive \
  --negative_key negative \
  --pooling_method mean \
  --loss_fn infonce \
  --use_lora False \
  --query_instruction "" \
  --document_instruction "" \
  --learning_rate 3e-5 \
  --fp16 \
  --num_train_epochs 5 \
  --per_device_train_batch_size 32 \
  --dataloader_drop_last True \
  --query_max_length 64 \
  --document_max_length 256 \
  --train_group_size 4 \
  --logging_steps 100 \
  --temperature 0.02 \
  --save_total_limit 1 \
  --use_inbatch_negative false
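
After training, the fine-tuned weights saved in OUTPUT_DIR can be loaded in the same way as the base model in section 2. A minimal sketch, reusing the OUTPUT_DIR path from the script above:

from retrievals import AutoModelForEmbedding

embedder = AutoModelForEmbedding.from_pretrained(
    '/root/kag101/src/open-retrievals/t2/ft_out',  # OUTPUT_DIR from run.sh
    pooling_method='mean',
)
embeddings = embedder.encode(['此文本嵌入模型支持中英雙語的文本相似度計算'])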


4. Evaluation

Performance before fine-tuning (C-MTEB T2Reranking score): MAP 0.654, MRR 0.754

Performance after fine-tuning: MAP 0.692, MRR 0.805

The InfoNCE loss is used without in-batch negatives, focusing instead on hard negatives. After fine-tuning, MAP improves from 0.654 to 0.692 and MRR from 0.754 to 0.805.
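
For reference, the scores come from the C-MTEB T2Reranking task. Below is a rough sketch of running such an evaluation, assuming a recent mteb release that registers the Chinese T2Reranking task and accepts any model object exposing an encode method (like AutoModelForEmbedding):

from mteb import MTEB
from retrievals import AutoModelForEmbedding

# Load the fine-tuned embedding model and evaluate it on T2Reranking
embedder = AutoModelForEmbedding.from_pretrained(
    '/root/kag101/src/open-retrievals/t2/ft_out', pooling_method='mean'
)
evaluation = MTEB(tasks=['T2Reranking'])
evaluation.run(embedder, output_folder='t2_reranking_results')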

The latest updates are available at https://github.com/LongxingTan/open-retrievals
