在深度學(xué)習(xí)最常用的卷積神經(jīng)網(wǎng)絡(luò)中,要求數(shù)據(jù)為具有空間局部性的多維矩陣或者說(shuō)張量扒寄。這與廣泛應(yīng)用的三維模型格式例如STL這種保存三角面片的存儲(chǔ)方式不一致瞧甩。因此,采用體素化的方式對(duì)輸入進(jìn)行處理灶伊。
以VTK為例疆前,在讀入了vtkPolyData后,采用vtkPolyDataToImageStencil(Example)的方式對(duì)三維模型進(jìn)行轉(zhuǎn)換聘萨,類似的轉(zhuǎn)換方法還有vtkVoxelModeller竹椒,但相比之下效率極低。
不過(guò)米辐,這樣的方法還是較為緩慢胸完,尤其是當(dāng)輸出體素模型規(guī)模較大時(shí)(如128x128x128)书释,在實(shí)際使用中,會(huì)使模型文件讀取占據(jù)了大量開銷赊窥。不過(guò)爆惧,由于這個(gè)轉(zhuǎn)換本身是可以重復(fù)利用的,因此在定義數(shù)據(jù)集時(shí)锨能,加入了cache模式扯再,PyTorch樣例代碼如下:
class Dataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None, cache=False):
self.frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
if cache:
self.cache = [None for i in range(len(self.frame))]
for i in range(len(self.frame)):
print('Caching record #%d\r' % (count))
self.cache[count] = self.read(i)
else:
self.cache = None
def __len__(self):
return len(self.frame)
def read(self, idx):
"""Read your data here."""
return sample
def __getitem__(self, idx):
if self.cache:
sample = self.cache[idx]
else:
sample = self.read(idx)
if self.transform:
sample = self.transform(sample)
return sample
實(shí)踐中發(fā)現(xiàn)這樣建立緩存還是存在讀取效率不足的問(wèn)題,因此再次改寫了一下址遇,變成多線程的形式熄阻。
def __init__(self, csv_file, root_dir, transform=None, cache=False, thread=4):
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
if cache:
self.cache = [None for i in range(len(self.landmarks_frame))]
pool = multiprocessing.Pool(processes=thread)
irange = range(len(self.landmarks_frame))
count = 0
for sample in pool.imap(self.read, irange):
print('Caching record #%d\r' % (count))
self.cache[count] = sample
count += 1
else:
self.cache = None
可惜的是,這樣的改寫并不能成功傲隶,因?yàn)樵趍ultiprocessing中傳遞結(jié)果時(shí)用到了pickle進(jìn)行數(shù)據(jù)的傳遞饺律,而vtkImageData作為比較特殊的對(duì)象無(wú)法被pickle序列化。為了解決這個(gè)問(wèn)題跺株,簡(jiǎn)單調(diào)用了vtk.util.numpy_support
里的一些方法复濒,完成vtkImageData與Numpy array之間的無(wú)損轉(zhuǎn)換。
def voxel2array(self, img):
# Up to support for 3 dimensions for this line
rows, cols, _ = img.GetDimensions()
sc = img.GetPointData().GetScalars()
arr = numpy_support.vtk_to_numpy(sc)
arr = array.reshape(rows, cols, -1)
spacing = img.GetSpacing()
origin = img.GetOrigin()
return arr, spacing, origin
def array2voxel(self, arr, spacing, origin):
vtk_data = numpy_support.numpy_to_vtk(
arr.ravel(), array_type=vtk.VTK_UNSIGNED_CHAR)
img = vtk.vtkImageData()
img.SetDimensions(array.shape)
img.SetSpacing(spacing)
img.SetOrigin(origin)
img.GetPointData().SetScalars(vtk_data)
return img
重點(diǎn)是vtkImageData中還留存著其體素的spacing信息和圖像的整體坐標(biāo)信息乒省。
突然想到巧颈,在體素化前利用一些三維模型降采樣方法對(duì)牙齒模型進(jìn)行降采樣,是否能夠大大加速體素化袖扛。