The laptop is a Redmi G 2021 with a just-passable configuration: a 5800H, 16 GB of RAM, and a 3050 Ti Laptop GPU (4 GB VRAM). A pretty ordinary setup! However, since NVIDIA updated its driver to let system RAM serve as overflow for video memory, the range of models I can train and run has grown a lot. Of course, this swap-like fallback costs some performance, but it still beats not being able to run at all!
Data preparation
I had been wanting to do this for ages. Sometimes I get curious about which region someone's dialect is from but feel too awkward to ask, so the idea of training a model to tell where a speaker comes from was born. Of course, this is far simpler than recognizing what the dialect actually means. I do admire the "national team" here: China Telecom has open-sourced a large language model covering 30 dialects.
So I decided to train a classification model on an open dataset just for fun. I originally wanted to use R torch, but found I would have to hand-roll far too much myself; deep learning in R is simply not mainstream. Then I stumbled on an out-of-the-box project, and the tinkering began: yeyupiaoling/AudioClassification-Pytorch
Code preparation
I followed the author's Python version and other requirements to the letter, no exceptions. Version compatibility in open-source software is a long, painful story, so match the versions exactly whenever you can! I am on Windows 11, but operations like git were done from WSL2. I generally use WSL2 to work on directories under Windows; read/write performance suffers a lot that way, but I do not want deleted files to keep holding on to disk space (WSL2 is a special kind of virtual machine).
# install the software
conda create -n macls python=3.8  # the environment name is arbitrary; "macls" is just the choice used here
# PyTorch etc.
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
# macls
python -m pip install macls -U -i https://pypi.tuna.tsinghua.edu.cn/simple
# the repo used
git clone https://github.com/yeyupiaoling/AudioClassification-Pytorch.git
cd AudioClassification-Pytorch/
pip install .
# all set!
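Before touching the repo, it is worth a quick check that this environment actually sees the 3050 Ti and its 4 GB of VRAM. A minimal sanity check using nothing but PyTorch itself:
import torch

# confirm the CUDA build of PyTorch is installed and the GPU is visible
print(torch.__version__, torch.version.cuda)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name, f'{props.total_memory / 1024 ** 3:.1f} GB VRAM')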
Downloading the data
After some comparison, I ended up using KeSpeech. I later found that the repo I am using also offers a dialect dataset, 3dspeaker_data, but one download had already eaten a few hundred GB of disk space and I did not want to pull a second one! Besides, given my hardware I could not use that much data anyway, so I only took a fraction of it for training.
Splitting out the training and test sets
My data preparation was pretty simple-minded: with some basic AI-assisted coding I wrote two scripts. In short, they read two text files, build two dictionaries, match them against each other, take roughly the first 1,200 entries per class for the training set, and another 200-odd for the test set. The data is clearly on the small side.
# take roughly the first 1200 utterances of every class
dic = {}
fout = open('D:/Projects/dialect/KeSpeech/Metadata/train_audio_path', 'w')
with open('D:/Projects/dialect/KeSpeech/Metadata/phase1.utt2subdialect', 'r') as f:
    for line in f:
        dialect = line.split('\t')[1]              # each line is utt_id<TAB>subdialect
        if dialect not in dic:
            dic[dialect] = 1
        else:
            dic[dialect] += 1
        if 250 < dic[dialect] < 1200:              # entries 251-1199 of each class go to the training list
            fout.write(line)
fout.close()

# map wav file names to their full paths
dic_file = dict()
with open('D:/Projects/dialect/KeSpeech/file.txt') as f1:
    for line in f1:
        if line.endswith('wav\n'):
            fi = line.strip().split('/')[3]        # the file name is the fourth path component
            dic_file[fi] = line.strip()
            # print(fi, dic_file[fi])

# the train and test lists were produced by the same script with the paths renamed;
# admittedly this should have been a function
fout2 = open('D:/Projects/dialect/KeSpeech/Metadata/test_list.txt', 'a')
with open('D:/Projects/dialect/KeSpeech/Metadata/test_audio_path', 'r') as f2:
    for line in f2:
        file_name = line.split('\t')[0] + '.wav'
        if file_name in dic_file:
            fout2.write(dic_file[file_name] + '\t' + line.strip().split('\t')[1] + '\n')
        else:
            print(file_name)
            # break
fout2.close()

# generate the label file
labels_dict = {0: 'Mandarin', 1: 'Northeastern', 2: 'Jiang-Huai',
               3: 'Southwestern', 4: 'Jiao-Liao', 5: 'Beijing', 6: 'Zhongyuan',
               7: 'Ji-Lu', 8: 'Lan-Yin'}
with open('D:/Projects/dialect/AudioClassification-Pytorch/dataset/label_list.txt', 'w', encoding='utf-8') as f:
    for i in range(len(labels_dict)):
        f.write(f'{labels_dict[i]}\n')
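As the comment in the script above admits, the two list-building passes really belong in a single function. Below is a minimal sketch of that refactor, assuming the same utt_id<TAB>subdialect metadata and the dic_file mapping (file name to full path) already built above; build_list and its arguments are hypothetical, not something provided by the repo.
def build_list(meta_path, path_index, out_path, low, high):
    # write "<wav_path>\t<label>" for entries numbered (low, high) of every class
    counts = {}
    with open(meta_path, 'r') as meta, open(out_path, 'w') as out:
        for line in meta:
            utt_id, dialect = line.rstrip('\n').split('\t')
            counts[dialect] = counts.get(dialect, 0) + 1
            if low < counts[dialect] < high:
                wav_path = path_index.get(utt_id + '.wav')
                if wav_path:
                    out.write(f'{wav_path}\t{dialect}\n')

# e.g. entries 251-1199 of each class for training, mirroring the bounds used above
build_list('D:/Projects/dialect/KeSpeech/Metadata/phase1.utt2subdialect',
           dic_file, 'D:/Projects/dialect/KeSpeech/Metadata/train_list.txt', 250, 1200)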
The label names also had to be substituted here, because the labels in the list files can only be numbers.
sed -i s/'Mandarin'/0/g test_list.txt
sed -i s/'Northeastern'/1/g test_list.txt
sed -i s/'Jiang-Huai'/2/g test_list.txt
sed -i s/'Southwestern'/3/g test_list.txt
sed -i s/'Jiao-Liao'/4/g test_list.txt
sed -i s/'Beijing'/5/g test_list.txt
sed -i s/'Zhongyuan'/6/g test_list.txt
sed -i s/'Ji-Lu'/7/g test_list.txt
sed -i s/'Lan-Yin'/8/g test_list.txt
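For the record, the same substitution can be done in one pass with Python instead of one sed call per label, reusing labels_dict from the script above (the test_list.txt path and its path<TAB>label layout are assumed to match the sed commands):
# map each label name to its numeric index and rewrite the list file in place
label_to_id = {name: str(idx) for idx, name in labels_dict.items()}
with open('test_list.txt', 'r', encoding='utf-8') as f:
    rows = [line.rstrip('\n').split('\t') for line in f]
with open('test_list.txt', 'w', encoding='utf-8') as f:
    for path, label in rows:
        f.write(f'{path}\t{label_to_id[label]}\n')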
Once this is done, drop the files into a newly created dataset folder and the training can happily begin. It even supports resuming from checkpoints. Kudos!
Training
Feature extraction
This is the first step, and it does not take long:
python extract_features.py --configs=configs/cam++.yml --save_dir=dataset/features
The long training run
All told it took about three days, training two to three hours at a time. It, too, is a one-command affair!
python train.py
Evaluating the results
Although this first attempt did not turn out well, I at least successfully obtained my first-ever audio classification model, which made me very happy. Thanks to the author!
python eval.py --configs=configs/cam++.yml
[2024-06-05 20:41:03.545426 INFO ] utils:print_arguments:14 - ----------- 額外配置參數(shù) -----------
[2024-06-05 20:41:03.545426 INFO ] utils:print_arguments:16 - configs: configs/cam++.yml
[2024-06-05 20:41:03.545426 INFO ] utils:print_arguments:16 - resume_model: models/CAMPPlus_Fbank/best_model/
[2024-06-05 20:41:03.545426 INFO ] utils:print_arguments:16 - save_matrix_path: output/images/
[2024-06-05 20:41:03.545426 INFO ] utils:print_arguments:16 - use_gpu: True
[2024-06-05 20:41:03.545426 INFO ] utils:print_arguments:17 - ------------------------------------------------
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:19 - ----------- 配置文件參數(shù) -----------
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:22 - dataset_conf:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:25 - aug_conf:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - noise_aug_prob: 0.2
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - noise_dir: dataset/noise
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - speed_perturb: True
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - volume_aug_prob: 0.2
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - volume_perturb: False
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:25 - dataLoader:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - batch_size: 39
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - drop_last: True
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - num_workers: 4
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - do_vad: False
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:25 - eval_conf:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - batch_size: 39
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - max_duration: 10
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - label_list_path: dataset/label_list.txt
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - max_duration: 3
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - min_duration: 0.5
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - sample_rate: 16000
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:25 - spec_aug_args:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - freq_mask_width: [0, 8]
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - time_mask_width: [0, 10]
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - target_dB: -20
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - test_list: dataset/test_list.txt
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - train_list: dataset/train_list.txt
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - use_dB_normalization: True
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - use_spec_aug: True
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:22 - model_conf:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - num_class: None
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:22 - optimizer_conf:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - learning_rate: 0.001
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - optimizer: Adam
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - scheduler: WarmupCosineSchedulerLR
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:25 - scheduler_args:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - max_lr: 0.001
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - min_lr: 1e-05
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - warmup_epoch: 5
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - weight_decay: 1e-06
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:22 - preprocess_conf:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - feature_method: Fbank
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:25 - method_args:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - num_mel_bins: 80
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:27 - sample_frequency: 16000
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:22 - train_conf:
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - enable_amp: False
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - log_interval: 10
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - loss_weight: None
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - max_epoch: 60
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:29 - use_compile: False
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:31 - use_model: CAMPPlus
[2024-06-05 20:41:03.576924 INFO ] utils:print_arguments:32 - ------------------------------------------------
[2024-06-05 20:41:03.576924 WARNING] trainer:__init__:74 - Windows系統(tǒng)不支持多線程讀取數(shù)據(jù),已自動關閉脚祟!
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
CAMPPlus [1, 9] --
├─FCM: 1-1 [1, 320, 98] --
│ └─Conv2d: 2-1 [1, 32, 80, 98] 288
│ └─BatchNorm2d: 2-2 [1, 32, 80, 98] 64
│ └─Sequential: 2-3 [1, 32, 40, 98] --
│ │ └─BasicResBlock: 3-1 [1, 32, 40, 98] 19,648
│ │ └─BasicResBlock: 3-2 [1, 32, 40, 98] 18,560
│ └─Sequential: 2-4 [1, 32, 20, 98] --
│ │ └─BasicResBlock: 3-3 [1, 32, 20, 98] 19,648
│ │ └─BasicResBlock: 3-4 [1, 32, 20, 98] 18,560
│ └─Conv2d: 2-5 [1, 32, 10, 98] 9,216
│ └─BatchNorm2d: 2-6 [1, 32, 10, 98] 64
├─Sequential: 1-2 [1, 512] --
│ └─TDNNLayer: 2-7 [1, 128, 49] --
│ │ └─Conv1d: 3-5 [1, 128, 49] 204,800
│ │ └─Sequential: 3-6 [1, 128, 49] 256
│ └─CAMDenseTDNNBlock: 2-8 [1, 512, 49] --
│ │ └─CAMDenseTDNNLayer: 3-7 [1, 32, 49] 39,520
│ │ └─CAMDenseTDNNLayer: 3-8 [1, 32, 49] 43,680
│ │ └─CAMDenseTDNNLayer: 3-9 [1, 32, 49] 47,840
│ │ └─CAMDenseTDNNLayer: 3-10 [1, 32, 49] 52,000
│ │ └─CAMDenseTDNNLayer: 3-11 [1, 32, 49] 56,160
│ │ └─CAMDenseTDNNLayer: 3-12 [1, 32, 49] 60,320
│ │ └─CAMDenseTDNNLayer: 3-13 [1, 32, 49] 64,480
│ │ └─CAMDenseTDNNLayer: 3-14 [1, 32, 49] 68,640
│ │ └─CAMDenseTDNNLayer: 3-15 [1, 32, 49] 72,800
│ │ └─CAMDenseTDNNLayer: 3-16 [1, 32, 49] 76,960
│ │ └─CAMDenseTDNNLayer: 3-17 [1, 32, 49] 81,120
│ │ └─CAMDenseTDNNLayer: 3-18 [1, 32, 49] 85,280
│ └─TransitLayer: 2-9 [1, 256, 49] --
│ │ └─Sequential: 3-19 [1, 512, 49] 1,024
│ │ └─Conv1d: 3-20 [1, 256, 49] 131,072
│ └─CAMDenseTDNNBlock: 2-10 [1, 1024, 49] --
│ │ └─CAMDenseTDNNLayer: 3-21 [1, 32, 49] 56,160
│ │ └─CAMDenseTDNNLayer: 3-22 [1, 32, 49] 60,320
│ │ └─CAMDenseTDNNLayer: 3-23 [1, 32, 49] 64,480
│ │ └─CAMDenseTDNNLayer: 3-24 [1, 32, 49] 68,640
│ │ └─CAMDenseTDNNLayer: 3-25 [1, 32, 49] 72,800
│ │ └─CAMDenseTDNNLayer: 3-26 [1, 32, 49] 76,960
│ │ └─CAMDenseTDNNLayer: 3-27 [1, 32, 49] 81,120
│ │ └─CAMDenseTDNNLayer: 3-28 [1, 32, 49] 85,280
│ │ └─CAMDenseTDNNLayer: 3-29 [1, 32, 49] 89,440
│ │ └─CAMDenseTDNNLayer: 3-30 [1, 32, 49] 93,600
│ │ └─CAMDenseTDNNLayer: 3-31 [1, 32, 49] 97,760
│ │ └─CAMDenseTDNNLayer: 3-32 [1, 32, 49] 101,920
│ │ └─CAMDenseTDNNLayer: 3-33 [1, 32, 49] 106,080
│ │ └─CAMDenseTDNNLayer: 3-34 [1, 32, 49] 110,240
│ │ └─CAMDenseTDNNLayer: 3-35 [1, 32, 49] 114,400
│ │ └─CAMDenseTDNNLayer: 3-36 [1, 32, 49] 118,560
│ │ └─CAMDenseTDNNLayer: 3-37 [1, 32, 49] 122,720
│ │ └─CAMDenseTDNNLayer: 3-38 [1, 32, 49] 126,880
│ │ └─CAMDenseTDNNLayer: 3-39 [1, 32, 49] 131,040
│ │ └─CAMDenseTDNNLayer: 3-40 [1, 32, 49] 135,200
│ │ └─CAMDenseTDNNLayer: 3-41 [1, 32, 49] 139,360
│ │ └─CAMDenseTDNNLayer: 3-42 [1, 32, 49] 143,520
│ │ └─CAMDenseTDNNLayer: 3-43 [1, 32, 49] 147,680
│ │ └─CAMDenseTDNNLayer: 3-44 [1, 32, 49] 151,840
│ └─TransitLayer: 2-11 [1, 512, 49] --
│ │ └─Sequential: 3-45 [1, 1024, 49] 2,048
│ │ └─Conv1d: 3-46 [1, 512, 49] 524,288
│ └─CAMDenseTDNNBlock: 2-12 [1, 1024, 49] --
│ │ └─CAMDenseTDNNLayer: 3-47 [1, 32, 49] 89,440
│ │ └─CAMDenseTDNNLayer: 3-48 [1, 32, 49] 93,600
│ │ └─CAMDenseTDNNLayer: 3-49 [1, 32, 49] 97,760
│ │ └─CAMDenseTDNNLayer: 3-50 [1, 32, 49] 101,920
│ │ └─CAMDenseTDNNLayer: 3-51 [1, 32, 49] 106,080
│ │ └─CAMDenseTDNNLayer: 3-52 [1, 32, 49] 110,240
│ │ └─CAMDenseTDNNLayer: 3-53 [1, 32, 49] 114,400
│ │ └─CAMDenseTDNNLayer: 3-54 [1, 32, 49] 118,560
│ │ └─CAMDenseTDNNLayer: 3-55 [1, 32, 49] 122,720
│ │ └─CAMDenseTDNNLayer: 3-56 [1, 32, 49] 126,880
│ │ └─CAMDenseTDNNLayer: 3-57 [1, 32, 49] 131,040
│ │ └─CAMDenseTDNNLayer: 3-58 [1, 32, 49] 135,200
│ │ └─CAMDenseTDNNLayer: 3-59 [1, 32, 49] 139,360
│ │ └─CAMDenseTDNNLayer: 3-60 [1, 32, 49] 143,520
│ │ └─CAMDenseTDNNLayer: 3-61 [1, 32, 49] 147,680
│ │ └─CAMDenseTDNNLayer: 3-62 [1, 32, 49] 151,840
│ └─TransitLayer: 2-13 [1, 512, 49] --
│ │ └─Sequential: 3-63 [1, 1024, 49] 2,048
│ │ └─Conv1d: 3-64 [1, 512, 49] 524,288
│ └─Sequential: 2-14 [1, 512, 49] --
│ │ └─BatchNorm1d: 3-65 [1, 512, 49] 1,024
│ │ └─ReLU: 3-66 [1, 512, 49] --
│ └─StatsPool: 2-15 [1, 1024] --
│ └─DenseLayer: 2-16 [1, 512] --
│ │ └─Conv1d: 3-67 [1, 512, 1] 524,288
│ │ └─Sequential: 3-68 [1, 512] --
├─Linear: 1-3 [1, 9] 4,617
===============================================================================================
Total params: 7,180,841
Trainable params: 7,180,841
Non-trainable params: 0
Total mult-adds (M): 552.44
===============================================================================================
Input size (MB): 0.03
Forward/backward pass size (MB): 41.22
Params size (MB): 28.72
Estimated Total Size (MB): 69.98
===============================================================================================
[2024-06-05 20:41:06.114035 INFO ] trainer:evaluate:476 - 成功加載模型:models/CAMPPlus_Fbank/best_model/model.pth
100%|██████████████████████████████████████████████████████████████████████████████████| 58/58 [00:25<00:00, 2.24it/s]
評估消耗時間:28s,loss:3.31887,accuracy:0.26412
Comparing the training accuracy with a quick run on the test set, the overfitting is really, really severe. The main reason is that there is not enough training data; on top of that, the class proportions probably do not match reality. The former is the dominant factor.
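Since too little (and possibly skewed) data is the prime suspect, a cheap first check is to count how many utterances each label actually ended up with in the generated lists. A small sketch, with the list paths taken from the config printed above:
from collections import Counter

# per-class utterance counts in the train/test lists (path<TAB>label per line)
for list_path in ('dataset/train_list.txt', 'dataset/test_list.txt'):
    with open(list_path, encoding='utf-8') as f:
        counts = Counter(line.rstrip('\n').split('\t')[1] for line in f)
    print(list_path, dict(sorted(counts.items())))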