Minor Update test.py
This commit is contained in:
Родитель
e6b6271028
Коммит
b01c01a061
|
@ -12,7 +12,6 @@ from PIL import Image
|
|||
import torch
|
||||
import torchvision.utils as vutils
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
@ -140,11 +139,11 @@ if __name__ == "__main__":
|
|||
if opt.NL_use_mask:
|
||||
mask_name = mask_loader[i]
|
||||
mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB")
|
||||
if opt.mask_dilation!=0:
|
||||
kernel=np.ones((3,3),np.uint8)
|
||||
mask=np.array(mask)
|
||||
mask=cv2.dilate(mask,kernel,iterations=opt.mask_dilation)
|
||||
mask=Image.fromarray(mask.astype('uint8'))
|
||||
if opt.mask_dilation != 0:
|
||||
kernel = np.ones((3,3),np.uint8)
|
||||
mask = np.array(mask)
|
||||
mask = cv2.dilate(mask,kernel,iterations = opt.mask_dilation)
|
||||
mask = Image.fromarray(mask.astype('uint8'))
|
||||
origin = input
|
||||
input = irregular_hole_synthesize(input, mask)
|
||||
mask = mask_transform(mask)
|
||||
|
@ -190,5 +189,4 @@ if __name__ == "__main__":
|
|||
normalize=True,
|
||||
)
|
||||
|
||||
origin.save(opt.outputs_dir + "/origin/" + input_name)
|
||||
|
||||
origin.save(opt.outputs_dir + "/origin/" + input_name)
|
Загрузка…
Ссылка в новой задаче