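As shown in the figure, we have a multi-class classification task where each folder represents one class and all images are in NIfTI format. How do we load this data into MONAI for training?
Before answering, let's look at how MONAI datasets load data:
A MONAI dataset accepts its (image, label) input in one of two forms: as arrays or as dicts.
A quick look at the difference:
Loading data in array form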
import numpy as np
import torch
from monai.data import DataLoader, ImageDataset

images = [
    "IXI314-IOP-0889-T1.nii.gz",
    "IXI249-Guys-1072-T1.nii.gz",
    "IXI609-HH-2600-T1.nii.gz",
    "IXI173-HH-1590-T1.nii.gz",
    "IXI020-Guys-0700-T1.nii.gz",
]
# One label per image, in the same order as `images`
labels = np.array([0, 0, 0, 1, 0], dtype=np.int64)
# train_transforms is assumed to be defined elsewhere
train_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
As the code shows, images and labels are both arrays and are passed directly to ImageDataset as arguments.
Loading data in dict form
import monai
import numpy as np
import torch
from monai.data import DataLoader

images = [
    "IXI314-IOP-0889-T1.nii.gz",
    "IXI249-Guys-1072-T1.nii.gz",
    "IXI609-HH-2600-T1.nii.gz",
    "IXI173-HH-1590-T1.nii.gz",
    "IXI020-Guys-0700-T1.nii.gz",
]
# One label per image, in the same order as `images`
labels = np.array([0, 0, 0, 1, 0], dtype=np.int64)
# Pair each image path with its label in one dict per sample
train_files = [{"img": img, "label": label} for img, label in zip(images, labels)]
# train_transforms is assumed to be defined elsewhere (dict-based transforms)
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
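Here images and labels are still arrays, but they are packed into a list of dicts so that each sample's image and label stay paired. That list is then passed to Dataset.
So, back to the original question: whether we use the array form or the dict form, we need to build images/labels lists, where images holds the path of each image. For a classification problem, labels holds the class of each image; for a segmentation problem, it holds the path of each ground-truth file.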
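To make the classification/segmentation distinction concrete, here is a minimal sketch of the two dict layouts. The file names and the `_seg` suffix are hypothetical and used only for illustration; the "img" and "label" keys follow the dict example above.

```python
# Hypothetical file names, for illustration only.
images = ["case_001.nii.gz", "case_002.nii.gz"]

# Classification: "label" is an integer class index per image.
class_labels = [0, 1]
cls_files = [{"img": img, "label": lab} for img, lab in zip(images, class_labels)]

# Segmentation: "label" is the path of the ground-truth mask per image.
mask_paths = ["case_001_seg.nii.gz", "case_002_seg.nii.gz"]
seg_files = [{"img": img, "label": seg} for img, seg in zip(images, mask_paths)]

# zip() silently truncates to the shorter input, so check the lengths explicitly.
assert len(images) == len(class_labels) == len(mask_paths)

print(cls_files[0])  # {'img': 'case_001.nii.gz', 'label': 0}
print(seg_files[0])  # {'img': 'case_001.nii.gz', 'label': 'case_001_seg.nii.gz'}
```

The explicit length check is worth keeping in real code, because a mismatch between images and labels would otherwise drop samples without any error.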
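The next question is: how do we assign a label to every image in the folders?
Of course, torchvision has a class, ImageFolder, that handles this easily!
But it has a drawback: by default it does not pick up files with the gz extension, while most medical images are 3D volumes stored as nii.gz. What now?
We can borrow its approach and write our own version that supports .gz files.
Let's get to it.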
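Option 1: modify the source directly
Looking at the source, the main reason it does not support gz is that it restricts the accepted extensions to the following:
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
Because gz is not in this tuple, such files are not supported.
The source is in torchvision/datasets/folder.py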
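The effect of that tuple on a NIfTI file name can be sketched in plain Python. This is a minimal illustration, not torchvision's code path: the helper name `has_allowed_extension` is made up here, and it applies the same `str.endswith` check that appears in the complete code at the end of this post.

```python
# Extension tuple as quoted above.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def has_allowed_extension(filename, extensions):
    """Return True if the lowercased file name ends with any allowed extension."""
    return filename.lower().endswith(extensions)


nifti_name = "IXI314-IOP-0889-T1.nii.gz"

# With the default tuple the file is skipped.
rejected = has_allowed_extension(nifti_name, IMG_EXTENSIONS)
# With '.gz' appended the same file is accepted.
accepted = has_allowed_extension(nifti_name, IMG_EXTENSIONS + ('.gz',))

print(rejected, accepted)  # False True
```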
So one quick-and-dirty approach is to edit IMG_EXTENSIONS directly and append '.gz' to it, after which ImageFolder can be used.
Usage example:
from torchvision.datasets import ImageFolder

data_root = '/dataset'
dataset = ImageFolder(root=data_root)
classes = dataset.classes  # class names (the folder names)
class_to_idx = dataset.class_to_idx  # mapping from class name to index (label)
images_labels = dataset.imgs  # list of (path, class_index) tuples
images = [tup[0] for tup in images_labels]  # for the array form
labels = [tup[1] for tup in images_labels]  # for the array form
# for the dict form
train_files = [{'image': tup[0], 'label': tup[1]} for tup in images_labels]
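These can then be passed to either of the two dataset types above. Problem solved.
However, this approach modifies the library source, which makes the code hard to port. Quick and dirty as it is, it is not recommended!
Instead, we can follow the same idea and write our own.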
Option 2: build your own ImageFolder
Approach:
- step 1: use the folder names as classes and assign each one a label.
def find_classes(directory: str):
    """Finds the class folders in a dataset.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
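To see what find_classes returns, here is a small self-contained run against a temporary directory. The class folder names (tumor, healthy, cyst) are hypothetical; the function body is a condensed copy of the one above so the snippet runs on its own.

```python
import os
import tempfile


def find_classes(directory):
    """Return sorted sub-folder names and a name -> index mapping."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    return classes, {cls_name: i for i, cls_name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    # Hypothetical class folders, created in non-alphabetical order.
    for name in ("tumor", "healthy", "cyst"):
        os.mkdir(os.path.join(root, name))
    # A stray file at the top level is not a class folder and is ignored.
    open(os.path.join(root, "README.txt"), "w").close()

    classes, class_to_idx = find_classes(root)

print(classes)       # ['cyst', 'healthy', 'tumor']
print(class_to_idx)  # {'cyst': 0, 'healthy': 1, 'tumor': 2}
```

Because the names are sorted before indices are assigned, the label of each class depends only on the folder names, not on the order in which the file system lists them.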
- step 2: walk the folders and assign a label to each image.
In this step, we also check the extension of each file.
img_label_dict = []
imgs = []
labels = []
for target_class in sorted(class_to_idx.keys()):
    class_index = class_to_idx[target_class]
    target_dir = os.path.join(directory, target_class)
    if not os.path.isdir(target_dir):
        continue
    for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
        for fname in sorted(fnames):
            if is_valid_file(fname):  # check whether the extension is allowed
                path = os.path.join(root, fname)
                item = {'img': path, 'label': class_index}
                img_label_dict.append(item)
                imgs.append(path)
                labels.append(class_index)
This is only the key part of the code, not the complete version.
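One caveat with a bare '.gz' check is that it also matches unrelated archives such as .tar.gz. A stricter validator could look like the sketch below. The names `NIFTI_EXTENSIONS` and `is_nifti_file` are hypothetical, and restricting the match to .nii and .nii.gz is an assumption about what the dataset contains.

```python
# Assumed set of NIfTI file extensions (uncompressed and gzip-compressed).
NIFTI_EXTENSIONS = ('.nii', '.nii.gz')


def is_nifti_file(filename: str) -> bool:
    """Return True only for names ending in .nii or .nii.gz (case-insensitive)."""
    return filename.lower().endswith(NIFTI_EXTENSIONS)


for name in ("IXI314-IOP-0889-T1.nii.gz", "scan.NII", "backup.tar.gz", "slice.png"):
    print(name, is_nifti_file(name))
```

A function like this can be used wherever the loop above calls is_valid_file. The make_dataset function in the complete code accepts an is_valid_file argument in place of extensions, so it could be passed there as well.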
Finally, here is the complete code:
import os
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
# Find the class folders under the data root and assign each one a label automatically
def find_classes(directory: str):
    """Finds the class folders in a dataset.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
# Check whether the file's extension is one of the allowed extensions
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.
    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)
    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)
# Collect images and their class labels from the root directory.
# Returns the image-label dicts, image paths, labels, class names, and class-to-index mapping.
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> Tuple[List[Dict[str, Any]], List[str], List[int], List[str], Dict[str, int]]:
    """Generates samples as {'img': path, 'label': class_index} dicts,
    together with parallel lists of paths and labels.
    """
    directory = os.path.expanduser(directory)
    if class_to_idx is None:
        classes, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
    else:
        # class_to_idx was supplied by the caller, so derive the class names from it
        classes = sorted(class_to_idx.keys())
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    img_label_dict = []
    imgs = []
    labels = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                if is_valid_file(fname):
                    path = os.path.join(root, fname)
                    item = {'img': path, 'label': class_index}
                    img_label_dict.append(item)
                    imgs.append(path)
                    labels.append(class_index)
                    # Record that this class has at least one valid file
                    available_classes.add(target_class)
    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)
    return img_label_dict, imgs, labels, classes, class_to_idx
if __name__ == '__main__':
    data_root = 'dataset'
    # classes, class_to_idx = find_classes(data_root)
    # allowed extensions
    extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp', '.gz')
    img_label_dict, imgs, labels, classes, class_to_idx = make_dataset(data_root, extensions=extensions)
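Before passing these lists to a MONAI dataset, it can help to check how many samples each class received. The sketch below uses hand-written stand-ins for the values make_dataset returns; the class names and labels are hypothetical.

```python
from collections import Counter

# Hypothetical stand-ins for the values returned by make_dataset.
class_to_idx = {"classA": 0, "classB": 1}
labels = [0, 0, 1, 0, 1]

# Invert the mapping so that counts can be reported by class name.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
counts = Counter(labels)
per_class = {idx_to_class[idx]: counts.get(idx, 0) for idx in sorted(idx_to_class)}

print(per_class)  # {'classA': 3, 'classB': 2}
```

A strongly skewed count is easy to miss when labels come from folder names, so a check like this is a cheap safeguard before training.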
That's all~