set tf.device('/cpu:0') for render op explicitly
This commit is contained in:
Родитель
c6fd2eb1c5
Коммит
9121818a7b
|
@ -294,34 +294,11 @@ class Face3D():
|
||||||
#using tf_mesh_renderer for rasterization (https://github.com/google/tf_mesh_renderer)
|
#using tf_mesh_renderer for rasterization (https://github.com/google/tf_mesh_renderer)
|
||||||
# img: [batchsize,224,224,3] images in RGB order (0-255)
|
# img: [batchsize,224,224,3] images in RGB order (0-255)
|
||||||
# mask:[batchsize,224,224,1] transparency for img ({0,1} value)
|
# mask:[batchsize,224,224,1] transparency for img ({0,1} value)
|
||||||
img_rgba = mesh_renderer.mesh_renderer(face_shape,
|
with tf.device('/cpu:0'):
|
||||||
tf.cast(facemodel.face_buf-1,tf.int32),
|
img_rgba = mesh_renderer.mesh_renderer(face_shape,
|
||||||
face_norm,
|
tf.cast(facemodel.face_buf-1,tf.int32),
|
||||||
face_color,
|
face_norm,
|
||||||
camera_position = camera_position,
|
face_color,
|
||||||
camera_lookat = camera_lookat,
|
|
||||||
camera_up = camera_up,
|
|
||||||
light_positions = light_positions,
|
|
||||||
light_intensities = light_intensities,
|
|
||||||
image_width = 224,
|
|
||||||
image_height = 224,
|
|
||||||
fov_y = fov_y,
|
|
||||||
near_clip = 0.01,
|
|
||||||
far_clip = 50.0,
|
|
||||||
ambient_color = ambient_color)
|
|
||||||
|
|
||||||
img = img_rgba[:,:,:,:3]
|
|
||||||
mask = img_rgba[:,:,:,3:]
|
|
||||||
|
|
||||||
img = tf.cast(img[:,:,:,::-1],tf.float32) #transfer RGB to BGR
|
|
||||||
mask = tf.cast(mask,tf.float32) # full face region
|
|
||||||
|
|
||||||
if is_train:
|
|
||||||
# compute mask for small face region
|
|
||||||
img_crop_rgba = mesh_renderer.mesh_renderer(mask_face_shape,
|
|
||||||
tf.cast(facemodel.mask_face_buf-1,tf.int32),
|
|
||||||
mask_face_norm,
|
|
||||||
mask_face_color,
|
|
||||||
camera_position = camera_position,
|
camera_position = camera_position,
|
||||||
camera_lookat = camera_lookat,
|
camera_lookat = camera_lookat,
|
||||||
camera_up = camera_up,
|
camera_up = camera_up,
|
||||||
|
@ -334,6 +311,31 @@ class Face3D():
|
||||||
far_clip = 50.0,
|
far_clip = 50.0,
|
||||||
ambient_color = ambient_color)
|
ambient_color = ambient_color)
|
||||||
|
|
||||||
|
img = img_rgba[:,:,:,:3]
|
||||||
|
mask = img_rgba[:,:,:,3:]
|
||||||
|
|
||||||
|
img = tf.cast(img[:,:,:,::-1],tf.float32) #transfer RGB to BGR
|
||||||
|
mask = tf.cast(mask,tf.float32) # full face region
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
# compute mask for small face region
|
||||||
|
with tf.device('/cpu:0'):
|
||||||
|
img_crop_rgba = mesh_renderer.mesh_renderer(mask_face_shape,
|
||||||
|
tf.cast(facemodel.mask_face_buf-1,tf.int32),
|
||||||
|
mask_face_norm,
|
||||||
|
mask_face_color,
|
||||||
|
camera_position = camera_position,
|
||||||
|
camera_lookat = camera_lookat,
|
||||||
|
camera_up = camera_up,
|
||||||
|
light_positions = light_positions,
|
||||||
|
light_intensities = light_intensities,
|
||||||
|
image_width = 224,
|
||||||
|
image_height = 224,
|
||||||
|
fov_y = fov_y,
|
||||||
|
near_clip = 0.01,
|
||||||
|
far_clip = 50.0,
|
||||||
|
ambient_color = ambient_color)
|
||||||
|
|
||||||
mask_f = img_crop_rgba[:,:,:,3:]
|
mask_f = img_crop_rgba[:,:,:,3:]
|
||||||
mask_f = tf.cast(mask_f,tf.float32) # small face region
|
mask_f = tf.cast(mask_f,tf.float32) # small face region
|
||||||
return img,mask,mask_f
|
return img,mask,mask_f
|
||||||
|
|
Загрузка…
Ссылка в новой задаче