scDataset 類(lèi)
from collections import Counter
from typing import Optional
import anndata
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import scanpy
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
class scDataset(Dataset):
"""A class that represent a single cell dataset."""
def __init__(self, X, Y, study=None):
self.X = X
self.Y = Y
self.study = study
self.classes = set(self.Y)
def __len__(self):
return len(self.Y)
def __getitem__(self, idx):
# data, label, study
return self.X[idx].A, self.Y[idx], self.study[idx]
這段代碼定義了 scDataset 類(lèi),它是一個(gè)用于表示單細(xì)胞數(shù)據(jù)集的 Python 類(lèi)悴品。這個(gè)類(lèi)繼承自 torch.utils.data.Dataset苔严,使其兼容于 PyTorch 的數(shù)據(jù)加載和處理機(jī)制。下面我將詳細(xì)解釋這個(gè)類(lèi)及其方法的功能:
導(dǎo)入的模塊
collections.Counter: 用于計(jì)數(shù)不同元素的出現(xiàn)次數(shù)欠窒。
typing.Optional: 用于類(lèi)型注解岖妄,表示參數(shù)可以為 None寂祥。
anndata: 用于處理單細(xì)胞數(shù)據(jù)的庫(kù)。
pytorch_lightning: 簡(jiǎn)化 PyTorch 模型訓(xùn)練的庫(kù)福扬。
scanpy: 用于單細(xì)胞數(shù)據(jù)分析的庫(kù)铛碑。
構(gòu)造函數(shù) init(self, X, Y, study=None)
參數(shù):
X: 數(shù)據(jù)矩陣虽界,通常是一個(gè)稀疏矩陣莉御,包含細(xì)胞的基因表達(dá)數(shù)據(jù)。
Y: 標(biāo)簽數(shù)組牍颈,包含與 X 中每個(gè)樣本相對(duì)應(yīng)的標(biāo)簽(例如晴圾,細(xì)胞類(lèi)型)死姚。
study: 可選參數(shù)勤篮,包含與 X 中每個(gè)樣本相對(duì)應(yīng)的研究或?qū)嶒?yàn)信息。
屬性初始化:
self.X: 存儲(chǔ)傳入的數(shù)據(jù)矩陣 X账劲。
self.Y: 存儲(chǔ)傳入的標(biāo)簽 Y瀑焦。
self.study: 存儲(chǔ)傳入的研究信息 study。
self.classes: 從 Y 中提取的唯一標(biāo)簽集合铺董,表示數(shù)據(jù)中包含的不同類(lèi)別精续。
len(self) 方法
返回?cái)?shù)據(jù)集中樣本的總數(shù)粹懒。這是通過(guò)計(jì)算 Y(標(biāo)簽數(shù)組)的長(zhǎng)度來(lái)實(shí)現(xiàn)的。
getitem(self, idx) 方法
參數(shù): idx - 請(qǐng)求的樣本索引确垫。
返回: 三元組 (data, label, study)森爽,其中:
data: 索引 idx 處的樣本數(shù)據(jù)(從 X 中提认怠)。.A 用于將稀疏矩陣轉(zhuǎn)換為常規(guī)數(shù)組付呕。
label: 索引 idx 處的樣本標(biāo)簽(從 Y 中提然罩啊)佩厚。
study: 索引 idx 處的樣本對(duì)應(yīng)的研究信息(從 study 中提取)潮瓶。
DataLoader
在 PyTorch Lightning 框架中毯辅,MetricLearningDataModule 類(lèi)繼承自 pl.LightningDataModule煞额。
DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
sampler=self.get_sampler_weights(self.train_dataset),
collate_fn=self.collate)
DataLoader 是一個(gè)極其重要的工具,用于批量加載數(shù)據(jù)并為訓(xùn)練提供必要的輸入胀莹。為了使數(shù)據(jù)能夠被 DataLoader 正確讀取和處理,需要遵循特定的格式和協(xié)議涩僻。這就是為什么 scDataset 類(lèi)按照特定方式實(shí)現(xiàn)的原因逆日。
如何 scDataset 被 DataLoader 讀取:
- 繼承自 Dataset:
scDataset 繼承自 PyTorch 的 Dataset 類(lèi)萄凤,這意味著它需要實(shí)現(xiàn)兩個(gè)方法:len 和 getitem靡努。
len 返回?cái)?shù)據(jù)集中的樣本數(shù)。
getitem 根據(jù)給定的索引返回相應(yīng)的樣本兽泄。
- 數(shù)據(jù)格式:
getitem 返回一個(gè)包含數(shù)據(jù)病梢、標(biāo)簽和研究信息的元組梁肿。這是標(biāo)準(zhǔn)的 PyTorch 數(shù)據(jù)格式,允許 DataLoader 以一致的方式處理不同類(lèi)型的數(shù)據(jù)集钮热。
get_sampler_weights
其中隧期, get_sampler_weights(self, dataset)
這個(gè)函數(shù)用于根據(jù)數(shù)據(jù)集生成加權(quán)隨機(jī)采樣器赘娄。加權(quán)隨機(jī)采樣器在數(shù)據(jù)不平衡的情況下非常有用,它可以確保在訓(xùn)練過(guò)程中各類(lèi)別被均等地表示鸵闪。
def get_sampler_weights(self, dataset: scDataset) -> WeightedRandomSampler:
"""Get weighted random sampler.
WeightedRandomSampler
A WeightedRandomSampler object.
"""
if dataset.study is None:
class_sample_count = Counter(dataset.Y)
sample_weights = torch.Tensor(
[1.0 / class_sample_count[t] for t in dataset.Y]
)
else:
class_sample_count = Counter(dataset.Y)
study_sample_count = Counter(dataset.study)
sample_weights = torch.Tensor(
[
1.0
/ class_sample_count[dataset.Y[i]]
/ np.log(study_sample_count[dataset.study[i]])
for i in range(len(dataset.Y))
]
)
return WeightedRandomSampler(sample_weights, len(sample_weights))
實(shí)現(xiàn)邏輯
參數(shù):dataset 是一個(gè) scDataset 類(lèi)的實(shí)例蚌讼,包含數(shù)據(jù)集的特征个榕、標(biāo)簽和其他信息。
處理:
如果 dataset 沒(méi)有提供 study 信息凰萨,則根據(jù)類(lèi)別標(biāo)簽 Y 計(jì)算每個(gè)類(lèi)別的樣本計(jì)數(shù)胖眷。然后霹崎,為每個(gè)樣本計(jì)算權(quán)重,權(quán)重為類(lèi)別的倒數(shù)境析。
如果提供了 study 信息劳淆,則同時(shí)考慮類(lèi)別和研究的影響默赂。在這種情況下,樣本的權(quán)重是類(lèi)別和研究的頻率的對(duì)數(shù)的倒數(shù)谒臼。
返回:返回一個(gè) WeightedRandomSampler 對(duì)象耀里,用于在數(shù)據(jù)加載過(guò)程中按照計(jì)算的權(quán)重隨機(jī)選擇樣本冯挎。
collate
def collate(self, batch):
"""Collate tensors.
Parameters
----------
batch:
Batch to collate.
Returns
-------
tuple
A Tuple[torch.Tensor, torch.Tensor, list] containing information
on the collated tensors.
"""
profiles, labels, studies = tuple(
map(list, zip(*batch))
) # tuple([list(t) for t in zip(*batch)])
return (
torch.squeeze(torch.Tensor(np.vstack(profiles))),
torch.Tensor(labels),
studies)
DataLoader 通過(guò) collate_fn 參數(shù)接收一個(gè)函數(shù)房官,該函數(shù)定義了如何將多個(gè)樣本組合成一個(gè)批次。這在處理不規(guī)則大小或不同類(lèi)型的數(shù)據(jù)時(shí)尤其重要孵奶。
collate 函數(shù)的工作流程如下:
輸入:batch蜡峰,一個(gè)包含多個(gè)從 scDataset.getitem 返回的元組的列表。
處理:
使用 zip(*batch) 將批次中的元素分解為單獨(dú)的列表(profiles, labels, studies)载绿。
將每個(gè)列表轉(zhuǎn)換為適當(dāng)?shù)?PyTorch 張量或保持為列表(如研究信息)。
對(duì)于數(shù)據(jù) profiles怀浆,使用 np.vstack 將它們垂直堆疊成一個(gè) NumPy 數(shù)組执赡,然后轉(zhuǎn)換為一個(gè) PyTorch 張量函筋。
返回:一個(gè)包含處理后的數(shù)據(jù)張量、標(biāo)簽張量和研究信息列表的元組
為什么這里需要一個(gè)collate_fn灌诅?
在 PyTorch 中含末,collate_fn 用于在數(shù)據(jù)加載過(guò)程中將多個(gè)樣本組合成一個(gè)批次佣盒。通常,如果你的數(shù)據(jù)集返回的每個(gè)樣本是一個(gè)簡(jiǎn)單的張量(比如圖片或標(biāo)簽)盯仪,你不需要提供一個(gè)自定義的 collate_fn蜜葱,因?yàn)?PyTorch 的默認(rèn) collate_fn 已經(jīng)可以處理這種情況。
然而爸黄,如果你的數(shù)據(jù)集返回的是復(fù)雜的數(shù)據(jù)結(jié)構(gòu)或需要特殊處理(比如不同的數(shù)據(jù)類(lèi)型組合炕贵、不規(guī)則的張量形狀等)野崇,那么你可能需要提供一個(gè)自定義的 collate_fn 來(lái)正確地處理這些數(shù)據(jù)。
為什么 scDataset 類(lèi)需要 collate_fn鳖轰?
scDataset 類(lèi)返回三種不同類(lèi)型的數(shù)據(jù):self.X(數(shù)據(jù)矩陣),self.Y(標(biāo)簽),和 self.study(研究)狈惫。這些數(shù)據(jù)可能需要特殊處理才能合并為一個(gè)批次,尤其是當(dāng)它們包含不同類(lèi)型的數(shù)據(jù)時(shí)忆肾。例如:
數(shù)據(jù)轉(zhuǎn)換:self.X 可能是一個(gè)稀疏矩陣客冈,需要轉(zhuǎn)換為密集張量稳强。
數(shù)據(jù)維度對(duì)齊:如果 self.X 中的樣本有不同的形狀,可能需要進(jìn)行填充或裁剪以確保它們可以合并渠缕。
額外信息合并:self.Y 和 self.study 可能需要特殊處理才能與 self.X 正確對(duì)應(yīng)褒繁。
什么時(shí)候不需要寫(xiě) collate_fn棒坏?
如果你的數(shù)據(jù)集返回的每個(gè)樣本已經(jīng)是一個(gè)規(guī)則的張量,且不需要任何特殊的預(yù)處理或后處理徒探,那么就不需要提供自定義的 collate_fn喂窟。在這種情況下,PyTorch 的默認(rèn) collate_fn 足以應(yīng)對(duì)大多數(shù)情況偷溺,它會(huì)自動(dòng)將多個(gè)樣本堆疊成一個(gè)批次挫掏。例如秩命,如果你的數(shù)據(jù)集只返回一組圖片和對(duì)應(yīng)的標(biāo)簽褒傅,而且所有圖片都有相同的形狀殿托,那么默認(rèn)的 collate_fn 就足夠了剧蚣。