1. Environment setup
```shell
pip install open-retrievals
```
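A quick import check confirms the installation; the package installs under the import name `retrievals`, which is used throughout the rest of this post:

```python
# Verify that open-retrievals is importable after installation.
import retrievals

print("open-retrievals is ready")
```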
2. Using the M3E model
```python
from retrievals import AutoModelForEmbedding

embedder = AutoModelForEmbedding.from_pretrained('moka-ai/m3e-base', pooling_method='mean')
embedder  # in a notebook, this displays the wrapped transformer and pooling configuration
```
```python
sentences = [
    '* Moka 此文本嵌入模型由 MokaAI 训练并开源,训练脚本使用 uniem',
    '* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',
    '* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算,异质文本检索等功能,未来还会支持代码检索,ALL in one'
]
embeddings = embedder.encode(sentences)
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")
```
3. Fine-tuning the M3E model with DeepSpeed
The training data is still the t2-ranking dataset introduced in the earlier post.
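Each training line pairs a query with positive and hard-negative passages. A hypothetical illustration of one record is shown below; the `positive`/`negative` key names match the `--positive_key`/`--negative_key` flags used in `run.sh` further down, while the `query` field name is an assumption:

```python
# Hypothetical single record of t2_ranking.jsonl (illustration only, not real data).
record = {
    "query": "文本嵌入有什么用途?",
    "positive": ["文本嵌入将句子映射为稠密向量,可用于相似度计算和检索。"],
    "negative": ["今天天气很好,适合出门散步。"],
}
```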
- Save the following DeepSpeed configuration as ds_zero2_no_offload.json:
```json
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 100,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1e-10
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 1e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 1e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
```
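The "auto" entries are filled in by the HuggingFace Trainer from its own training arguments when the file is passed via `--deepspeed`. A quick parse check of the saved config:

```python
import json

# The "auto" fields (batch size, gradient accumulation, clipping) are resolved
# by the HuggingFace Trainer at launch time; here we only confirm the JSON parses.
with open("ds_zero2_no_offload.json") as f:
    ds_config = json.load(f)
print(ds_config["zero_optimization"]["stage"])  # expected: 2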
The code below is adapted slightly from open-retrievals: the relative imports are changed to package imports so the script can be launched standalone. Save it as embed.py:
"""Embedding fine tune pipeline"""
import logging
import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
from retrievals import (
EncodeCollator,
EncodeDataset,
PairCollator,
RetrievalTrainDataset,
TripletCollator,
)
from retrievals.losses import AutoLoss, InfoNCE, SimCSE, TripletLoss
from retrievals.models.embedding_auto import AutoModelForEmbedding
from retrievals.trainer import RetrievalTrainer
# os.environ["WANDB_LOG_MODEL"] = "false"
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
causal_lm: bool = field(default=False, metadata={'help': "Whether the model is a causal lm or not"})
lora_path: Optional[str] = field(default=None, metadata={'help': "Lora adapter save path"})
@dataclass
class DataArguments:
data_name_or_path: str = field(default=None, metadata={"help": "Path to train data"})
train_group_size: int = field(default=2)
unfold_each_positive: bool = field(default=False)
query_max_length: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
document_max_length: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
query_instruction: str = field(default=None, metadata={"help": "instruction for query"})
document_instruction: str = field(default=None, metadata={"help": "instruction for document"})
query_key: str = field(default=None)
positive_key: str = field(default='positive')
negative_key: str = field(default='negative')
is_query: bool = field(default=False)
encoding_save_file: str = field(default='embed.pkl')
def __post_init__(self):
# self.data_name_or_path = 'json'
self.dataset_split = 'train'
self.dataset_language = 'default'
if self.data_name_or_path is not None:
if not os.path.isfile(self.data_name_or_path) and not os.path.isdir(self.data_name_or_path):
info = self.data_name_or_path.split('/')
self.dataset_split = info[-1] if len(info) == 3 else 'train'
self.data_name_or_path = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
self.dataset_language = 'default'
if ':' in self.data_name_or_path:
self.data_name_or_path, self.dataset_language = self.data_name_or_path.split(':')
@dataclass
class RetrieverTrainingArguments(TrainingArguments):
train_type: str = field(default='pairwise', metadata={'help': "train type of point, pair, or list"})
negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
temperature: Optional[float] = field(default=0.02)
fix_position_embedding: bool = field(
default=False, metadata={"help": "Freeze the parameters of position embeddings"}
)
pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
normalized: bool = field(default=True)
loss_fn: str = field(default='infonce')
use_inbatch_negative: bool = field(default=True, metadata={"help": "use documents in the same batch as negatives"})
remove_unused_columns: bool = field(default=False)
use_lora: bool = field(default=False)
use_bnb_config: bool = field(default=False)
do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
report_to: Optional[List[str]] = field(
default="none", metadata={"help": "The list of integrations to report the results and logs to."}
)
def main():
parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("Model parameters %s", model_args)
logger.info("Data parameters %s", data_args)
set_seed(training_args.seed)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False,
)
if training_args.use_bnb_config:
from transformers import BitsAndBytesConfig
logger.info('Use quantization bnb config')
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
else:
quantization_config = None
if training_args.do_train:
model = AutoModelForEmbedding.from_pretrained(
model_name_or_path=model_args.model_name_or_path,
pooling_method=training_args.pooling_method,
use_lora=training_args.use_lora,
quantization_config=quantization_config,
)
loss_fn = AutoLoss(
loss_name=training_args.loss_fn,
loss_kwargs={
'use_inbatch_negative': training_args.use_inbatch_negative,
'temperature': training_args.temperature,
},
)
model = model.set_train_type(
"pairwise",
loss_fn=loss_fn,
)
train_dataset = RetrievalTrainDataset(
args=data_args,
tokenizer=tokenizer,
positive_key=data_args.positive_key,
negative_key=data_args.negative_key,
)
logger.info(f"Total training examples: {len(train_dataset)}")
trainer = RetrievalTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=TripletCollator(
tokenizer,
query_max_length=data_args.query_max_length,
document_max_length=data_args.document_max_length,
positive_key=data_args.positive_key,
negative_key=data_args.negative_key,
),
)
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
trainer.train()
# trainer.save_model(training_args.output_dir)
model.save_pretrained(training_args.output_dir)
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)
if training_args.do_encode:
model = AutoModelForEmbedding.from_pretrained(
model_name_or_path=model_args.model_name_or_path,
pooling_method=training_args.pooling_method,
use_lora=training_args.use_lora,
quantization_config=quantization_config,
lora_path=model_args.lora_path,
)
max_length = data_args.query_max_length if data_args.is_query else data_args.document_max_length
logger.info(f'Encoding will be saved in {training_args.output_dir}')
encode_dataset = EncodeDataset(args=data_args, tokenizer=tokenizer, max_length=max_length, text_key='text')
logger.info(f"Number of train samples: {len(encode_dataset)}, max_length: {max_length}")
encode_loader = DataLoader(
encode_dataset,
batch_size=training_args.per_device_eval_batch_size,
collate_fn=EncodeCollator(tokenizer, max_length=max_length, padding='max_length'),
shuffle=False,
drop_last=False,
num_workers=training_args.dataloader_num_workers,
)
embeddings = model.encode(encode_loader, show_progress_bar=True, convert_to_numpy=True)
lookup_indices = list(range(len(encode_dataset)))
with open(os.path.join(training_args.output_dir, data_args.encoding_save_file), 'wb') as f:
pickle.dump((embeddings, lookup_indices), f)
if __name__ == "__main__":
main()
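When the script is launched with --do_encode, the embeddings are pickled as an (embeddings, lookup_indices) tuple named embed.pkl in the output directory. A minimal sketch of reading them back (the directory below assumes the OUTPUT_DIR used in run.sh):

```python
import os
import pickle

# Load the embeddings written by the --do_encode branch of embed.py above.
output_dir = "/root/kag101/src/open-retrievals/t2/ft_out"  # assumed OUTPUT_DIR
with open(os.path.join(output_dir, "embed.pkl"), "rb") as f:
    embeddings, lookup_indices = pickle.load(f)
print(embeddings.shape, len(lookup_indices))
```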
- Finally, save the launch script below as run.sh:
```shell
MODEL_NAME="moka-ai/m3e-base"
TRAIN_DATA="/root/kag101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kag101/src/open-retrievals/t2/ft_out"
# loss_fn: infonce, simcse
deepspeed --include localhost:0 embed.py \
--deepspeed ds_zero2_no_offload.json \
--output_dir $OUTPUT_DIR \
--overwrite_output_dir \
--model_name_or_path $MODEL_NAME \
--do_train \
--data_name_or_path $TRAIN_DATA \
--positive_key positive \
--negative_key negative \
--pooling_method mean \
--loss_fn infonce \
--use_lora False \
--query_instruction "" \
--document_instruction "" \
--learning_rate 3e-5 \
--fp16 \
--num_train_epochs 5 \
--per_device_train_batch_size 32 \
--dataloader_drop_last True \
--query_max_length 64 \
--document_max_length 256 \
--train_group_size 4 \
--logging_steps 100 \
--temperature 0.02 \
--save_total_limit 1 \
--use_inbatch_negative false
```
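After training finishes, the fine-tuned weights saved in OUTPUT_DIR can be loaded back with the same API as in section 2 for a quick sanity check (the query text below is an arbitrary example):

```python
from retrievals import AutoModelForEmbedding

# Load the fine-tuned checkpoint saved by model.save_pretrained above.
ft_embedder = AutoModelForEmbedding.from_pretrained(
    "/root/kag101/src/open-retrievals/t2/ft_out", pooling_method="mean"
)
ft_embeddings = ft_embedder.encode(["文本嵌入可以用来做什么?"])
print(ft_embeddings[0][:5])  # first few dimensions of the query embedding
```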
4. Evaluation
The model is evaluated with the C-MTEB t2-ranking score before and after fine-tuning. Fine-tuning uses the InfoNCE loss without in-batch negatives, relying instead on hard negatives; after fine-tuning, MAP improves from 0.654 to 0.692 and MRR from 0.754 to 0.805.
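The scores can be reproduced with the MTEB harness; a rough sketch, assuming the mteb package is installed and accepts the model's encode interface (T2Reranking is the C-MTEB task name, which reports MAP and MRR):

```python
from mteb import MTEB
from retrievals import AutoModelForEmbedding

# Evaluate the fine-tuned embedder on the C-MTEB T2Reranking task.
model = AutoModelForEmbedding.from_pretrained(
    "/root/kag101/src/open-retrievals/t2/ft_out", pooling_method="mean"
)
evaluation = MTEB(tasks=["T2Reranking"])
evaluation.run(model, output_folder="t2_reranking_results")
```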
For the latest updates, see https://github.com/LongxingTan/open-retrievals.