要使用純 PyTorch 訓(xùn)練模型炮姨,而不是依賴于 PyTorch Lightning捌刮,需要手動實(shí)現(xiàn)訓(xùn)練循環(huán)、驗(yàn)證舒岸、測試步驟以及優(yōu)化器的配置绅作。
- 準(zhǔn)備數(shù)據(jù)集和數(shù)據(jù)加載器(Data Loaders)
使用 scDataset 類來創(chuàng)建訓(xùn)練、驗(yàn)證和測試數(shù)據(jù)集蛾派。創(chuàng)建 PyTorch 的 DataLoader 實(shí)例俄认,用于加載數(shù)據(jù)。 - 構(gòu)建模型
初始化模型(如 Encoder 和 Decoder)洪乍。 - 定義優(yōu)化器和損失函數(shù)
設(shè)置優(yōu)化器眯杏,例如 Adam 或 SGD。定義損失函數(shù)壳澳,例如三元組損失和均方誤差損失岂贩。 - 訓(xùn)練循環(huán)
對數(shù)據(jù)進(jìn)行迭代,執(zhí)行正向傳播钾埂、計(jì)算損失河闰、進(jìn)行反向傳播和優(yōu)化器步驟。 - 驗(yàn)證和測試
在訓(xùn)練過程中或之后褥紫,對驗(yàn)證和測試數(shù)據(jù)集進(jìn)行評估姜性。
以下是如何在 PyTorch 中實(shí)現(xiàn)這些步驟的示例代碼:
import torch
from torch.utils.data import DataLoader
# 1. 準(zhǔn)備數(shù)據(jù)集和數(shù)據(jù)加載器
train_dataset = scDataset(...) # 使用適當(dāng)?shù)膮?shù)填充
train_loader = DataLoader(train_dataset, batch_size=..., num_workers=..., sampler=...)
val_dataset = scDataset(...)
val_loader = DataLoader(val_dataset, batch_size=..., num_workers=...)
# 2. 構(gòu)建模型
encoder = Encoder(...)
decoder = Decoder(...)
# 3. 定義優(yōu)化器和損失函數(shù)
optimizer = torch.optim.Adam([...], lr=...)
triplet_loss_fn = TripletLoss(...)
mse_loss_fn = torch.nn.MSELoss()
# 4. 訓(xùn)練循環(huán)
for epoch in range(num_epochs):
for batch in train_loader:
cells, labels, _ = batch
optimizer.zero_grad()
embeddings = encoder(cells)
reconstructions = decoder(embeddings)
triplet_loss = triplet_loss_fn(embeddings, labels, ...)
reconstruction_loss = mse_loss_fn(cells, reconstructions)
loss = ... # 根據(jù)需要組合損失
loss.backward()
optimizer.step()
# 5. 驗(yàn)證步驟
with torch.no_grad():
for batch in val_loader:
cells, labels, _ = batch
embeddings = encoder(cells)
reconstructions = decoder(embeddings)
# 計(jì)算和記錄驗(yàn)證損失
...
# 測試步驟類似于驗(yàn)證步驟,只是使用測試數(shù)據(jù)集