Preface
To anyone who finds this article: it must be fortune that brought us together.
The way things are these days, getting hold of even a single GPU is hard, let alone doing multi-GPU training...
A few points to note
- The model creation part must use tf.distribute.MirroredStrategy.
- To split each batch evenly across the GPUs' memory, feed the data through tf.data.Dataset.from_generator so that it is loaded from an iterator, and explicitly turn AutoShardPolicy off. If you skip this step, memory allocation can go wrong: not only may GPU memory blow up, the validation loss computed during training may also come out wrong.
- After the steps above, to avoid triggering TensorFlow 2's bug in metric computation during training, you must do the following. This is the real pain point; without careful tracing it is very hard to find!
  Always set metrics to binary_accuracy or sparse_categorical_accuracy.
  Do not simply set it to acc.
  Otherwise you will later get the error: as_list() is not defined on an unknown TensorShape.
- The reason for generating training data dynamically with a generator is not only to avoid loading all the training data at once and blowing up memory, but also to apply augmentation and transforms to the training data on the fly, which makes the model more robust (see the flip sketch right after this list).
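The augmentation itself happens inside get_img_label_transform, which is not shown in this post. As a purely illustrative sketch (the helper augment_flip below is hypothetical and not part of the original code), an on-the-fly augmentation step can be as simple as a random flip applied to the image and its label mask together:

import numpy as np

def augment_flip(img, label_img):
    # Hypothetical on-the-fly augmentation: random horizontal / vertical flip.
    # Flipping image and label together keeps the pixel-wise labels aligned.
    if np.random.rand() < 0.5:
        img = img[:, ::-1, :]            # horizontal flip
        label_img = label_img[:, ::-1, :]
    if np.random.rand() < 0.5:
        img = img[::-1, :, :]            # vertical flip
        label_img = label_img[::-1, :, :]
    return img, label_img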
The code
Model creation and compilation
Focus on how tf.distribute.MirroredStrategy is used; choose the loss function and optimizer according to your own habits. But never pick acc as the metric! (The custom loss weighted_bce_dice_loss_v1_7_3 is not shown in this post; a generic sketch of that style of loss follows the code block below.)
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

gpus = tf.config.list_physical_devices('GPU')
batchsize = 8
print('apply: Adam + weighted_bce_dice_loss_v1_7_3')
if len(gpus) > 1:
    # Allow memory growth so TensorFlow does not grab all GPU memory up front
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(device=gpu, enable=True)
    # The global batch size is split across replicas, so scale it by the GPU count
    batchsize *= len(gpus)
    mirrored_strategy = tf.distribute.MirroredStrategy()
    # Model creation and compilation must happen inside the strategy scope
    with mirrored_strategy.scope():
        model = table_line.get_model(input_shape=(512, 512, 3),
                                     is_resnest_unet=is_resnest_unet,
                                     is_swin_unet=is_swin_unet,
                                     resnest_pretrain_model=resnest_pretrain_model)
        # apply custom loss; metrics must be a named metric, never 'acc'
        model.compile(optimizer=Adam(lr=0.0001),
                      loss=weighted_bce_dice_loss_v1_7_3,
                      metrics=['binary_accuracy'])
else:
    model = table_line.get_model(input_shape=(512, 512, 3),
                                 is_resnest_unet=is_resnest_unet,
                                 is_swin_unet=is_swin_unet,
                                 resnest_pretrain_model=resnest_pretrain_model)
    model.compile(optimizer=Adam(lr=0.0001),
                  loss=weighted_bce_dice_loss_v1_7_3,
                  metrics=['binary_accuracy'])
print('batch size: {0}, GPUs: {1}'.format(batchsize, gpus))
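weighted_bce_dice_loss_v1_7_3 is the author's own loss function and its implementation is not part of this post. For orientation only, a generic weighted BCE + Dice loss (not the author's v1_7_3 version; bce_weight and smooth are illustrative parameters) typically looks something like this:

import tensorflow as tf
from tensorflow.keras import backend as K

def bce_dice_loss(y_true, y_pred, bce_weight=0.5, smooth=1.0):
    # Binary cross-entropy term, averaged over all pixels and channels
    bce = K.mean(K.binary_crossentropy(y_true, y_pred))
    # Soft Dice term computed on the flattened masks
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    dice = (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    # Weighted combination of the two terms
    return bce_weight * bce + (1.0 - bce_weight) * (1.0 - dice)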
Dataset iterator creation
def makeDataset(generator_func,
                data_list,
                line_path,
                batchsize,
                draw_line,
                is_raw,
                need_rotate,
                only_flip,
                is_wide_line,
                strategy=None):
    # Make a dataset from the generator. MAKE SURE TO SPECIFY THE DATA TYPE!!!
    ds = tf.data.Dataset.from_generator(generator_func,
                                        args=[data_list, line_path, batchsize,
                                              draw_line, is_raw, need_rotate,
                                              only_flip, is_wide_line],
                                        output_types=(tf.float64, tf.float64))
    # Explicitly turn auto-sharding off; the generator already yields full batches
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    ds = ds.with_options(options)
    # Optional: make it a distributed dataset if you're using a strategy
    if strategy is not None:
        ds = strategy.experimental_distribute_dataset(ds)
    return ds
Getting the iterators for the training and validation data
Here gen is the data-generating function; all of the remaining arguments, except the final strategy argument, are parameters required by that generator function.
training_ds = makeDataset(gen,
                          data_list=trainP,
                          line_path=line_path,
                          batchsize=batchsize,
                          draw_line=False,
                          is_raw=is_raw,
                          need_rotate=need_rotate,
                          only_flip=only_flip,
                          is_wide_line=is_wide_line,
                          strategy=None)
validation_ds = makeDataset(gen,
                            data_list=testP,
                            line_path=line_path,
                            batchsize=batchsize,
                            draw_line=False,
                            is_raw=is_raw,
                            need_rotate=need_rotate,
                            only_flip=only_flip,
                            is_wide_line=is_wide_line,
                            strategy=None)
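Since strategy=None is passed here, the returned object is a plain tf.data.Dataset, so a quick optional sanity check is to pull a single batch from the generator-backed dataset and look at the shapes:

# Pull one batch from the (infinite) generator-backed dataset and check shapes
for X, Y in training_ds.take(1):
    print(X.shape)  # expected: (batchsize, 512, 512, 3)
    print(Y.shape)  # expected: (batchsize, 512, 512, 2)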
An example of the data-generating function; anyone who has worked with iterators will recognize what is going on.
import numpy as np

def gen(paths,
        line_path,
        batchsize=2,
        draw_line=True,
        is_raw=False,
        need_rotate=False,
        only_flip: bool = True,
        is_wide_line=False):
    num = len(paths)
    i = 0
    while True:
        # sizes = [512, 512, 512, 512, 640, 1024]  # multi-scale training
        # size = np.random.choice(sizes, 1)[0]
        size = 512
        X = np.zeros((batchsize, size, size, 3))
        Y = np.zeros((batchsize, size, size, 2))
        print(i)  # debug: current sample index
        for j in range(batchsize):
            while True:
                # Reshuffle and start over once the file list is exhausted
                if i >= num:
                    i = 0
                    np.random.shuffle(paths)
                p = paths[i]
                i += 1
                try:
                    if is_raw:
                        img, lines, labelImg = get_img_label_raw(p,
                                                                 line_path,
                                                                 size=(size, size),
                                                                 draw_line=draw_line,
                                                                 is_wide_line=is_wide_line)
                    else:
                        img, lines, labelImg = get_img_label_transform(p,
                                                                       line_path,
                                                                       size=(size, size),
                                                                       draw_line=draw_line,
                                                                       need_rotate=need_rotate,
                                                                       only_flip=only_flip,
                                                                       is_wide_line=is_wide_line)
                    break
                except Exception as e:
                    # Skip unreadable samples and try the next path
                    print(e)
            X[j] = img
            Y[j] = labelImg
        yield X, Y
Model training code
Training method: fit
With a data generator, the training method used to be fit_generator; since TF2 everything is unified under fit.
How to write the steps arguments (important!)
Note how steps_per_epoch and validation_steps are written: batchsize must have the same value that was passed to makeDataset, otherwise the correct number of steps cannot be computed.
model.fit(training_ds,
          callbacks=[checkpointer, earlyStopping],
          steps_per_epoch=max(1, len(trainP) // batchsize),
          validation_data=validation_ds,
          validation_steps=max(1, len(testP) // batchsize),
          epochs=300)
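The checkpointer and earlyStopping callbacks are defined elsewhere in the original script. A typical definition, assuming standard Keras ModelCheckpoint and EarlyStopping callbacks with a placeholder file path and placeholder patience value, might look like this:

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Save the weights whenever the validation loss improves (path is a placeholder)
checkpointer = ModelCheckpoint(filepath='table-line-{epoch:02d}.hdf5',
                               monitor='val_loss',
                               save_best_only=True,
                               save_weights_only=True,
                               verbose=1)
# Stop training when the validation loss has stopped improving
earlyStopping = EarlyStopping(monitor='val_loss',
                              patience=10,
                              restore_best_weights=True)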