transfer reconstruction process from numpy to tensorflow, add rendering process, update image preprocessing method
44
demo.py
|
@ -8,7 +8,7 @@ from scipy.io import loadmat,savemat
|
|||
|
||||
from preprocess_img import Preprocess
|
||||
from load_data import *
|
||||
from reconstruct_mesh import Reconstruction
|
||||
from face_decoder import Face3D
|
||||
|
||||
def load_graph(graph_filename):
|
||||
with tf.gfile.GFile(graph_filename,'rb') as f:
|
||||
|
@ -21,56 +21,68 @@ def demo():
|
|||
# input and output folder
|
||||
image_path = 'input'
|
||||
save_path = 'output'
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
img_list = glob.glob(image_path + '/' + '*.png')
|
||||
img_list +=glob.glob(image_path + '/' + '*.jpg')
|
||||
|
||||
# read BFM face model
|
||||
# transfer original BFM model to our model
|
||||
if not os.path.isfile('./BFM/BFM_model_front.mat'):
|
||||
transferBFM09()
|
||||
|
||||
# read face model
|
||||
facemodel = BFM()
|
||||
# read standard landmarks for preprocessing images
|
||||
lm3D = load_lm3d()
|
||||
batchsize = 1
|
||||
n = 0
|
||||
|
||||
# build reconstruction model
|
||||
with tf.Graph().as_default() as graph,tf.device('/cpu:0'):
|
||||
|
||||
images = tf.placeholder(name = 'input_imgs', shape = [None,224,224,3], dtype = tf.float32)
|
||||
FaceReconstructor = Face3D()
|
||||
images = tf.placeholder(name = 'input_imgs', shape = [batchsize,224,224,3], dtype = tf.float32)
|
||||
graph_def = load_graph('network/FaceReconModel.pb')
|
||||
tf.import_graph_def(graph_def,name='resnet',input_map={'input_imgs:0': images})
|
||||
|
||||
# output coefficients of R-Net (dim = 257)
|
||||
coeff = graph.get_tensor_by_name('resnet/coeff:0')
|
||||
|
||||
# reconstructing faces
|
||||
FaceReconstructor.Reconstruction_Block(coeff,batchsize)
|
||||
face_shape = FaceReconstructor.face_shape_t
|
||||
face_texture = FaceReconstructor.face_texture
|
||||
face_color = FaceReconstructor.face_color
|
||||
landmarks_2d = FaceReconstructor.landmark_p
|
||||
recon_img = FaceReconstructor.render_imgs
|
||||
tri = FaceReconstructor.facemodel.face_buf
|
||||
|
||||
|
||||
with tf.Session() as sess:
|
||||
print('reconstructing...')
|
||||
for file in img_list:
|
||||
n += 1
|
||||
print(n)
|
||||
# load images and corresponding 5 facial landmarks
|
||||
img,lm = load_img(file,file.replace('png','txt'))
|
||||
img,lm = load_img(file,file.replace('png','txt').replace('jpg','txt'))
|
||||
# preprocess input image
|
||||
input_img,lm_new,transform_params = Preprocess(img,lm,lm3D)
|
||||
|
||||
coef = sess.run(coeff,feed_dict = {images: input_img})
|
||||
coeff_,face_shape_,face_texture_,face_color_,landmarks_2d_,recon_img_,tri_ = sess.run([coeff,\
|
||||
face_shape,face_texture,face_color,landmarks_2d,recon_img,tri],feed_dict = {images: input_img})
|
||||
|
||||
# reconstruct 3D face with output coefficients and face model
|
||||
face_shape,face_texture,face_color,tri,face_projection,z_buffer,landmarks_2d = Reconstruction(coef,facemodel)
|
||||
|
||||
# reshape outputs
|
||||
input_img = np.squeeze(input_img)
|
||||
shape = np.squeeze(face_shape, (0))
|
||||
color = np.squeeze(face_color, (0))
|
||||
landmarks_2d = np.squeeze(landmarks_2d, (0))
|
||||
face_shape_ = np.squeeze(face_shape_, (0))
|
||||
face_texture_ = np.squeeze(face_texture_, (0))
|
||||
face_color_ = np.squeeze(face_color_, (0))
|
||||
landmarks_2d_ = np.squeeze(landmarks_2d_, (0))
|
||||
recon_img_ = np.squeeze(recon_img_, (0))
|
||||
|
||||
# save output files
|
||||
# cropped image, which is the direct input to our R-Net
|
||||
# 257 dim output coefficients by R-Net
|
||||
# 68 face landmarks of cropped image
|
||||
savemat(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','.mat')),{'cropped_img':input_img[:,:,::-1],'coeff':coef,'landmarks_2d':landmarks_2d,'lm_5p':lm_new})
|
||||
save_obj(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','_mesh.obj')),shape,tri,np.clip(color,0,255)/255) # 3D reconstruction face (in canonical view)
|
||||
savemat(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','.mat').replace('jpg','mat')),{'cropped_img':input_img[:,:,::-1],'recon_img':recon_img_,'coeff':coeff_,\
|
||||
'face_shape':face_shape_,'face_texture':face_texture_,'face_color':face_color_,'lm_68p':landmarks_2d_,'lm_5p':lm_new})
|
||||
save_obj(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','_mesh.obj').replace('jpg','_mesh.obj')),face_shape_,tri_,np.clip(face_color_,0,255)/255) # 3D reconstruction face (in canonical view)
|
||||
|
||||
if __name__ == '__main__':
|
||||
demo()
|
||||
|
|
|
@ -0,0 +1,293 @@
|
|||
import tensorflow as tf
|
||||
import math as m
|
||||
import numpy as np
|
||||
import mesh_renderer
|
||||
from scipy.io import loadmat
|
||||
###############################################################################################
|
||||
# Reconstruct 3D face based on output coefficients and facemodel
|
||||
###############################################################################################
|
||||
|
||||
# BFM 3D face model
|
||||
class BFM():
    """BFM 3D face model whose arrays are exposed as TensorFlow constants.

    The arrays are read from a MATLAB ``.mat`` file with ``scipy.io.loadmat``.
    Shape notes below are carried over from the model's own annotations
    (N = number of vertices, F = number of triangles).
    """

    def __init__(self,model_path = 'BFM/BFM_model_front.mat'):
        bfm_mat = loadmat(model_path)

        # mean face shape. [3*N,1]
        self.meanshape = tf.constant(bfm_mat['meanshape'])
        # identity basis. [3*N,80]
        self.idBase = tf.constant(bfm_mat['idBase'])
        # expression basis. [3*N,64]; explicitly cast to float32 before use
        expression_basis = bfm_mat['exBase'].astype(np.float32)
        self.exBase = tf.constant(expression_basis)
        # mean face texture. [3*N,1] (0-255)
        self.meantex = tf.constant(bfm_mat['meantex'])
        # texture basis. [3*N,80]
        self.texBase = tf.constant(bfm_mat['texBase'])
        # triangle indices for each vertex that lies in. starts from 1. [N,8]
        self.point_buf = tf.constant(bfm_mat['point_buf'])
        # vertex indices in each triangle. starts from 1. [F,3]
        self.face_buf = tf.constant(bfm_mat['tri'])
        # vertex indices of 68 facial landmarks. starts from 1. [68,1]
        # (squeezed so that it can be used directly as a gather index)
        landmark_indices = tf.constant(bfm_mat['keypoints'])
        self.keypoints = tf.squeeze(landmark_indices)
|
||||
|
||||
# Analytic 3D face reconstructor
|
||||
class Face3D():
    """Analytic 3D face reconstructor expressed as TensorFlow graph ops.

    Given the 257-dim coefficient vector predicted by R-Net, builds the face
    shape, texture, illumination, 2D landmarks and a rendered image.  After
    ``Reconstruction_Block`` has been called the results are available as the
    attributes ``face_texture``, ``face_shape_t``, ``landmark_p``,
    ``face_color`` and ``render_imgs``.
    """

    def __init__(self):
        facemodel = BFM()
        self.facemodel = facemodel

    # analytic 3D face reconstructions with coefficients from R-Net
    def Reconstruction_Block(self,coeff,batchsize):
        """Build the full reconstruction graph and store the outputs on self.

        Args:
            coeff: [batchsize,257] reconstruction coefficients.
            batchsize: batch size forwarded to ``Render_block``, where it is
                used to build fixed-size tensors and reshape targets.
        """
        id_coeff,ex_coeff,tex_coeff,angles,translation,gamma = self.Split_coeff(coeff)
        # [batchsize,N,3] canonical face shape in BFM space
        face_shape = self.Shape_formation_block(id_coeff,ex_coeff,self.facemodel)
        # [batchsize,N,3] vertex texture (in RGB order)
        face_texture = self.Texture_formation_block(tex_coeff,self.facemodel)
        self.face_texture = face_texture
        # [batchsize,3,3] rotation matrix for face shape
        rotation = self.Compute_rotation_matrix(angles)
        # [batchsize,N,3] vertex normal
        face_norm = self.Compute_norm(face_shape,self.facemodel)
        norm_r = tf.matmul(face_norm,rotation)

        # do rigid transformation for face shape using predicted rotation and translation
        face_shape_t = self.Rigid_transform_block(face_shape,rotation,translation)
        self.face_shape_t = face_shape_t
        # compute 2d landmark projections
        # landmark_p: [batchsize,68,2]
        face_landmark_t = self.Compute_landmark(face_shape_t,self.facemodel)
        # projected with the default principal point 112, i.e. a 224*224 image
        landmark_p = self.Projection_block(face_landmark_t)
        self.landmark_p = landmark_p

        # [batchsize,N,3] vertex color (in RGB order)
        face_color = self.Illumination_block(face_texture, norm_r, gamma)
        self.face_color = face_color

        # reconstruction images
        render_imgs = self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,batchsize)
        render_imgs = tf.clip_by_value(render_imgs,0,255)
        render_imgs = tf.cast(render_imgs,tf.float32)
        self.render_imgs = render_imgs

    ######################################################################################################
    def Split_coeff(self,coeff):
        """Slice the 257-dim coefficient tensor into its semantic parts.

        Returns:
            (id_coeff, ex_coeff, tex_coeff, angles, translation, gamma).
            Note that translation precedes gamma in the returned tuple even
            though gamma precedes translation in the coefficient layout.
        """
        id_coeff = coeff[:,:80] #identity
        ex_coeff = coeff[:,80:144] #expression
        tex_coeff = coeff[:,144:224] #texture
        angles = coeff[:,224:227] #euler angles for pose
        gamma = coeff[:,227:254] #lighting
        translation = coeff[:,254:257] #translation

        return id_coeff,ex_coeff,tex_coeff,angles,translation,gamma

    def Shape_formation_block(self,id_coeff,ex_coeff,facemodel):
        """Return the [batchsize,N,3] face shape, re-centered on the mean shape."""
        face_shape = tf.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \
                     tf.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + facemodel.meanshape

        # reshape face shape to [batchsize,N,3]
        face_shape = tf.reshape(face_shape,[tf.shape(face_shape)[0],-1,3])
        # re-centering the face shape with mean shape
        face_shape = face_shape - tf.reshape(tf.reduce_mean(tf.reshape(facemodel.meanshape,[-1,3]),0),[1,1,3])

        return face_shape

    def Compute_norm(self,face_shape,facemodel):
        """Return [batchsize,N,3] unit vertex normals from one-ring neighborhoods."""
        shape = face_shape
        face_id = facemodel.face_buf
        point_id = facemodel.point_buf

        # face_id and point_id index starts from 1
        face_id = tf.cast(face_id - 1,tf.int32)
        point_id = tf.cast(point_id - 1,tf.int32)

        #compute normal for each face
        v1 = tf.gather(shape,face_id[:,0], axis = 1)
        v2 = tf.gather(shape,face_id[:,1], axis = 1)
        v3 = tf.gather(shape,face_id[:,2], axis = 1)
        e1 = v1 - v2
        e2 = v2 - v3
        face_norm = tf.cross(e1,e2)

        face_norm = tf.nn.l2_normalize(face_norm, dim = 2) # normalized face_norm first
        # append one all-zero normal after the last face; it is gathered below
        # by any point_buf entry that indexes one past the last face
        face_norm = tf.concat([face_norm,tf.zeros([tf.shape(face_shape)[0],1,3])], axis = 1)

        #compute normal for each vertex using one-ring neighborhood
        v_norm = tf.reduce_sum(tf.gather(face_norm, point_id, axis = 1), axis = 2)
        v_norm = tf.nn.l2_normalize(v_norm, dim = 2)

        return v_norm

    def Texture_formation_block(self,tex_coeff,facemodel):
        """Return the [batchsize,N,3] vertex texture (RGB order)."""
        face_texture = tf.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex

        # reshape face texture to [batchsize,N,3], note that texture is in RGB order
        face_texture = tf.reshape(face_texture,[tf.shape(face_texture)[0],-1,3])

        return face_texture

    def Compute_rotation_matrix(self,angles):
        """Return the transposed rotation matrices (Rz*Ry*Rx)^T, [batchsize,3,3].

        Args:
            angles: [batchsize,3] rotation angles about the X, Y and Z axes.
        """
        n_data = tf.shape(angles)[0]

        # compute rotation matrix for X-axis, Y-axis, Z-axis respectively
        # (each is assembled row-major as 9 columns and reshaped to 3x3 below)
        rotation_X = tf.concat([tf.ones([n_data,1]),
            tf.zeros([n_data,3]),
            tf.reshape(tf.cos(angles[:,0]),[n_data,1]),
            -tf.reshape(tf.sin(angles[:,0]),[n_data,1]),
            tf.zeros([n_data,1]),
            tf.reshape(tf.sin(angles[:,0]),[n_data,1]),
            tf.reshape(tf.cos(angles[:,0]),[n_data,1])],
            axis = 1
            )

        rotation_Y = tf.concat([tf.reshape(tf.cos(angles[:,1]),[n_data,1]),
            tf.zeros([n_data,1]),
            tf.reshape(tf.sin(angles[:,1]),[n_data,1]),
            tf.zeros([n_data,1]),
            tf.ones([n_data,1]),
            tf.zeros([n_data,1]),
            -tf.reshape(tf.sin(angles[:,1]),[n_data,1]),
            tf.zeros([n_data,1]),
            tf.reshape(tf.cos(angles[:,1]),[n_data,1])],
            axis = 1
            )

        rotation_Z = tf.concat([tf.reshape(tf.cos(angles[:,2]),[n_data,1]),
            -tf.reshape(tf.sin(angles[:,2]),[n_data,1]),
            tf.zeros([n_data,1]),
            tf.reshape(tf.sin(angles[:,2]),[n_data,1]),
            tf.reshape(tf.cos(angles[:,2]),[n_data,1]),
            tf.zeros([n_data,3]),
            tf.ones([n_data,1])],
            axis = 1
            )

        rotation_X = tf.reshape(rotation_X,[n_data,3,3])
        rotation_Y = tf.reshape(rotation_Y,[n_data,3,3])
        rotation_Z = tf.reshape(rotation_Z,[n_data,3,3])

        # R = RzRyRx
        rotation = tf.matmul(tf.matmul(rotation_Z,rotation_Y),rotation_X)

        # because our face shape is N*3, so compute the transpose of R, so that rotation shapes can be calculated as face_shape*R
        rotation = tf.transpose(rotation, perm = [0,2,1])

        return rotation

    def Projection_block(self,face_shape,focal=1015.0,half_image_width=112.):
        """Perspective-project 3D points to 2D.

        Args:
            face_shape: [batchsize,N,3] points in the transformed face space.
            focal: pre-defined camera focal length.
            half_image_width: principal point offset, used for both x and y.

        Returns:
            [batchsize,N,2] 2D projections.
        """
        # pre-defined camera focal for perspective projection
        focal = tf.constant(focal)
        focal = tf.reshape(focal,[-1,1])
        batchsize = tf.shape(face_shape)[0]

        # define camera position
        camera_pos = tf.reshape(tf.constant([0.0,0.0,10.0]),[1,1,3])

        # compute projection matrix (row-major 3x3: [f,0,c; 0,f,c; 0,0,1])
        p_matrix = tf.concat([focal*tf.ones([batchsize,1]),tf.zeros([batchsize,1]),half_image_width*tf.ones([batchsize,1]),tf.zeros([batchsize,1]),\
            focal*tf.ones([batchsize,1]),half_image_width*tf.ones([batchsize,1]),tf.zeros([batchsize,2]),tf.ones([batchsize,1])],axis = 1)
        p_matrix = tf.reshape(p_matrix,[-1,3,3])

        # convert z in canonical space to the distance to camera
        reverse_z = tf.tile(tf.reshape(tf.constant([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3]),[tf.shape(face_shape)[0],1,1])
        face_shape = tf.matmul(face_shape,reverse_z) + camera_pos
        aug_projection = tf.matmul(face_shape,tf.transpose(p_matrix,[0,2,1]))

        # [batchsize, N,2] 2d face projection
        face_projection = aug_projection[:,:,0:2]/tf.reshape(aug_projection[:,:,2],[tf.shape(face_shape)[0],tf.shape(aug_projection)[1],1])

        return face_projection

    def Compute_landmark(self,face_shape,facemodel):
        """Gather the landmark vertices (facemodel.keypoints, 1-based) from face_shape."""
        # compute 3D landmark positions with pre-computed 3D face shape
        keypoints_idx = facemodel.keypoints
        keypoints_idx = tf.cast(keypoints_idx - 1,tf.int32)
        face_landmark = tf.gather(face_shape,keypoints_idx,axis = 1)

        return face_landmark

    def Illumination_block(self,face_texture,norm_r,gamma):
        """Shade the vertex texture with a 9-term SH lighting approximation.

        Args:
            face_texture: [batchsize,N,3] vertex texture in RGB order.
            norm_r: [batchsize,N,3] rotated vertex normals.
            gamma: [batchsize,27] lighting coefficients (9 per color channel).

        Returns:
            [batchsize,N,3] vertex color in RGB order.
        """
        n_data = tf.shape(gamma)[0]
        n_point = tf.shape(norm_r)[1]
        gamma = tf.reshape(gamma,[n_data,3,9])
        # set initial lighting with an ambient lighting
        init_lit = tf.constant([0.8,0,0,0,0,0,0,0,0])
        gamma = gamma + tf.reshape(init_lit,[1,1,9])

        # compute vertex color using SH function approximation
        a0 = m.pi
        a1 = 2*m.pi/tf.sqrt(3.0)
        a2 = 2*m.pi/tf.sqrt(8.0)
        c0 = 1/tf.sqrt(4*m.pi)
        c1 = tf.sqrt(3.0)/tf.sqrt(4*m.pi)
        c2 = 3*tf.sqrt(5.0)/tf.sqrt(12*m.pi)

        Y = tf.concat([tf.tile(tf.reshape(a0*c0,[1,1,1]),[n_data,n_point,1]),
            tf.expand_dims(-a1*c1*norm_r[:,:,1],2),
            tf.expand_dims(a1*c1*norm_r[:,:,2],2),
            tf.expand_dims(-a1*c1*norm_r[:,:,0],2),
            tf.expand_dims(a2*c2*norm_r[:,:,0]*norm_r[:,:,1],2),
            tf.expand_dims(-a2*c2*norm_r[:,:,1]*norm_r[:,:,2],2),
            tf.expand_dims(a2*c2*0.5/tf.sqrt(3.0)*(3*tf.square(norm_r[:,:,2])-1),2),
            tf.expand_dims(-a2*c2*norm_r[:,:,0]*norm_r[:,:,2],2),
            tf.expand_dims(a2*c2*0.5*(tf.square(norm_r[:,:,0])-tf.square(norm_r[:,:,1])),2)],axis = 2)

        color_r = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,0,:],2)),axis = 2)
        color_g = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,1,:],2)),axis = 2)
        color_b = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,2,:],2)),axis = 2)

        #[batchsize,N,3] vertex color in RGB order
        face_color = tf.stack([color_r*face_texture[:,:,0],color_g*face_texture[:,:,1],color_b*face_texture[:,:,2]],axis = 2)

        return face_color

    def Rigid_transform_block(self,face_shape,rotation,translation):
        """Apply rotation then translation to a [batchsize,N,3] face shape."""
        # do rigid transformation for 3D face shape
        face_shape_r = tf.matmul(face_shape,rotation)
        face_shape_t = face_shape_r + tf.reshape(translation,[tf.shape(face_shape)[0],1,3])

        return face_shape_t

    def Render_block(self,face_shape,face_norm,face_color,facemodel,batchsize,focal=1015.0,half_image_width=112.):
        """Rasterize the colored mesh with tf_mesh_renderer.

        Args:
            face_shape: vertex positions, reshaped to [batchsize,N,3].
            face_norm: vertex normals, reshaped to [batchsize,N,3].
            face_color: vertex colors, reshaped to [batchsize,N,3].
            facemodel: model providing ``idBase`` (for N) and ``face_buf``.
            batchsize: batch size used for fixed-size tensors and reshapes.
            focal: camera focal length; the default matches Projection_block.
            half_image_width: half of the (square) output image size; the
                default matches Projection_block and gives a 224*224 image.

        Returns:
            The tensor returned by ``mesh_renderer.mesh_renderer``; with the
            defaults, [batchsize,224,224,4] images in RGBA order (0-255).
        """
        # render reconstruction images
        n_vex = int(facemodel.idBase.shape[0].value/3)
        # vertical field of view (degrees) equivalent to the pinhole camera
        # used in Projection_block
        fov_y = 2*tf.atan(half_image_width/focal)*180./m.pi + tf.zeros([batchsize])
        image_size = int(2*half_image_width)

        # full face region
        face_shape = tf.reshape(face_shape,[batchsize,n_vex,3])
        face_norm = tf.reshape(face_norm,[batchsize,n_vex,3])
        face_color = tf.reshape(face_color,[batchsize,n_vex,3])

        #camera settings
        # same as in Projection_block
        camera_position = tf.constant([[0,0,10.0]]) + tf.zeros([batchsize,3])
        camera_lookat = tf.constant([[0,0,0.0]]) + tf.zeros([batchsize,3])
        camera_up = tf.constant([[0,1.0,0]]) + tf.zeros([batchsize,3])

        # setting light source position(intensities are set to 0 because we have already computed the vertex color)
        light_positions = tf.reshape(tf.constant([0,0,1e5]),[1,1,3]) + tf.zeros([batchsize,1,3])
        light_intensities = tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3])+tf.zeros([batchsize,1,3])
        ambient_color = tf.reshape(tf.constant([1.0,1,1]),[1,3])+ tf.zeros([batchsize,3])

        near_clip = 0.01*tf.ones([batchsize])
        far_clip = 50*tf.ones([batchsize])
        #using tf_mesh_renderer for rasterization (https://github.com/google/tf_mesh_renderer)
        # img: [batchsize,image_size,image_size,4] images in RGBA order (0-255)
        with tf.device('/cpu:0'):
            img = mesh_renderer.mesh_renderer(face_shape,
                tf.cast(facemodel.face_buf-1,tf.int32),
                face_norm,
                face_color,
                camera_position = camera_position,
                camera_lookat = camera_lookat,
                camera_up = camera_up,
                light_positions = light_positions,
                light_intensities = light_intensities,
                image_width = image_size,
                image_height = image_size,
                fov_y = fov_y, #12.5936 with the default focal and half_image_width
                ambient_color = ambient_color,
                near_clip = near_clip,
                far_clip = far_clip)

        return img
|
||||
|
||||
|
После Ширина: | Высота: | Размер: 31 KiB |
|
@ -0,0 +1,5 @@
|
|||
142.84 207.18
|
||||
222.02 203.9
|
||||
159.24 253.57
|
||||
146.59 290.93
|
||||
227.52 284.74
|
После Ширина: | Высота: | Размер: 30 KiB |
|
@ -0,0 +1,5 @@
|
|||
199.93 158.28
|
||||
255.34 166.54
|
||||
236.08 198.92
|
||||
198.83 229.24
|
||||
245.23 234.52
|
После Ширина: | Высота: | Размер: 27 KiB |
|
@ -0,0 +1,5 @@
|
|||
129.36 198.28
|
||||
204.47 191.47
|
||||
164.42 240.51
|
||||
140.74 277.77
|
||||
205.4 270.9
|
После Ширина: | Высота: | Размер: 30 KiB |
|
@ -0,0 +1,5 @@
|
|||
151.23 240.71
|
||||
274.05 235.52
|
||||
217.37 305.99
|
||||
158.03 346.06
|
||||
272.17 341.09
|
После Ширина: | Высота: | Размер: 9.4 KiB |
|
@ -0,0 +1,5 @@
|
|||
119.09 94.291
|
||||
158.31 96.472
|
||||
136.76 121.4
|
||||
119.33 134.49
|
||||
154.66 136.68
|
После Ширина: | Высота: | Размер: 14 KiB |
|
@ -0,0 +1,5 @@
|
|||
147.37 159.39
|
||||
196.94 163.26
|
||||
190.68 194.36
|
||||
153.72 228.44
|
||||
193.94 229.7
|
После Ширина: | Высота: | Размер: 8.7 KiB |
|
@ -0,0 +1,5 @@
|
|||
150.4 94.799
|
||||
205.14 102.07
|
||||
179.54 131.16
|
||||
144.45 147.42
|
||||
193.39 154.14
|
После Ширина: | Высота: | Размер: 17 KiB |
|
@ -0,0 +1,5 @@
|
|||
114.26 193.42
|
||||
205.8 190.27
|
||||
154.15 244.02
|
||||
124.69 295.22
|
||||
200.88 292.69
|
После Ширина: | Высота: | Размер: 42 KiB |
|
@ -0,0 +1,5 @@
|
|||
217.52 152.95
|
||||
281.48 147.14
|
||||
253.02 196.03
|
||||
225.79 221.6
|
||||
288.25 214.44
|
После Ширина: | Высота: | Размер: 13 KiB |
|
@ -0,0 +1,5 @@
|
|||
90.928 99.858
|
||||
146.87 100.33
|
||||
114.22 130.36
|
||||
91.579 153.32
|
||||
143.63 153.56
|
После Ширина: | Высота: | Размер: 73 KiB |
|
@ -0,0 +1,5 @@
|
|||
307.56 166.54
|
||||
387.06 159.62
|
||||
335.52 222.26
|
||||
319.3 248.85
|
||||
397.71 239.14
|
После Ширина: | Высота: | Размер: 20 KiB |
|
@ -0,0 +1,5 @@
|
|||
226.38 193.65
|
||||
319.12 208.97
|
||||
279.99 245.88
|
||||
213.79 290.55
|
||||
303.03 302.1
|
После Ширина: | Высота: | Размер: 64 KiB |
|
@ -0,0 +1,5 @@
|
|||
208.4 410.08
|
||||
364.41 388.68
|
||||
291.6 503.57
|
||||
244.82 572.86
|
||||
383.18 553.49
|
После Ширина: | Высота: | Размер: 134 KiB |
|
@ -0,0 +1,5 @@
|
|||
284.61 496.57
|
||||
562.77 550.78
|
||||
395.85 712.84
|
||||
238.92 786.8
|
||||
495.61 827.22
|
После Ширина: | Высота: | Размер: 20 KiB |
|
@ -0,0 +1,5 @@
|
|||
153.95 153.43
|
||||
211.13 161.54
|
||||
197.28 190.26
|
||||
150.82 215.98
|
||||
202.32 223.12
|
После Ширина: | Высота: | Размер: 59 KiB |
|
@ -0,0 +1,5 @@
|
|||
481.31 396.88
|
||||
667.75 392.43
|
||||
557.81 440.55
|
||||
490.44 586.28
|
||||
640.56 583.2
|
После Ширина: | Высота: | Размер: 33 KiB |
|
@ -0,0 +1,5 @@
|
|||
191.79 143.97
|
||||
271.86 151.23
|
||||
191.25 210.29
|
||||
187.82 257.12
|
||||
258.82 261.96
|
Двоичные данные
input/vd005.png
До Ширина: | Высота: | Размер: 394 KiB |
|
@ -1,5 +0,0 @@
|
|||
287 299
|
||||
502 311
|
||||
393 411
|
||||
304 542
|
||||
476 547
|
|
@ -1,5 +1,5 @@
|
|||
118 115
|
||||
178 124
|
||||
127 147
|
||||
117 181
|
||||
167 191
|
||||
123.12 117.58
|
||||
176.59 122.09
|
||||
126.99 144.68
|
||||
117.61 183.43
|
||||
163.94 186.41
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
184 115
|
||||
261 98
|
||||
230 156
|
||||
204 205
|
||||
281 187
|
||||
180.12 116.13
|
||||
263.18 98.397
|
||||
230.48 154.72
|
||||
201.37 199.01
|
||||
279.18 182.56
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
171 258
|
||||
284 259
|
||||
211 330
|
||||
173 383
|
||||
274 389
|
||||
171.27 263.54
|
||||
286.58 263.88
|
||||
203.35 333.02
|
||||
170.6 389.42
|
||||
281.73 386.84
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
131 168
|
||||
192 154
|
||||
155 194
|
||||
148 238
|
||||
201 226
|
||||
136.01 167.83
|
||||
195.25 151.71
|
||||
152.89 191.45
|
||||
149.85 235.5
|
||||
201.16 222.8
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
162 292
|
||||
255 285
|
||||
213 349
|
||||
178 390
|
||||
252 382
|
||||
161.92 292.04
|
||||
254.21 283.81
|
||||
212.75 342.06
|
||||
170.78 387.28
|
||||
254.6 379.82
|
||||
|
|
Двоичные данные
input/vd057.png
До Ширина: | Высота: | Размер: 281 KiB |
|
@ -1,5 +0,0 @@
|
|||
252 178
|
||||
403 151
|
||||
332 221
|
||||
281 329
|
||||
405 312
|
|
@ -1,5 +1,5 @@
|
|||
275 290
|
||||
383 295
|
||||
312 356
|
||||
278 407
|
||||
359 412
|
||||
276.53 290.35
|
||||
383.38 294.75
|
||||
314.48 354.66
|
||||
275.08 407.72
|
||||
364.94 411.48
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
112 150
|
||||
160 148
|
||||
136 182
|
||||
123 202
|
||||
159 201
|
||||
108.59 149.07
|
||||
157.35 143.85
|
||||
134.4 173.2
|
||||
117.88 200.79
|
||||
159.56 196.36
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
124 227
|
||||
192 220
|
||||
164 274
|
||||
136 306
|
||||
187 298
|
||||
121.62 225.96
|
||||
186.73 223.07
|
||||
162.99 269.82
|
||||
132.12 302.62
|
||||
186.42 299.21
|
||||
|
|
14
load_data.py
|
@ -3,20 +3,6 @@ from PIL import Image
|
|||
from scipy.io import loadmat,savemat
|
||||
from array import array
|
||||
|
||||
# define facemodel for reconstruction
|
||||
class BFM():
    """Face model for reconstruction, loaded from a MATLAB ``.mat`` file.

    Args:
        model_path: path of the model file.  Defaults to the location that
            was previously hard-coded, so ``BFM()`` behaves as before.
    """

    def __init__(self, model_path='./BFM/BFM_model_front.mat'):
        model = loadmat(model_path)
        self.meanshape = model['meanshape'] # mean face shape
        self.idBase = model['idBase'] # identity basis
        self.exBase = model['exBase'] # expression basis
        self.meantex = model['meantex'] # mean face texture
        self.texBase = model['texBase'] # texture basis
        self.point_buf = model['point_buf'] # adjacent face index for each vertex, starts from 1 (only used for calculating face normal)
        self.tri = model['tri'] # vertex index for each triangle face, starts from 1
        # 68 face landmark index; stored 1-based in the file, converted here to start from 0
        self.keypoints = np.squeeze(model['keypoints']).astype(np.int32) - 1
|
||||
|
||||
# load expression basis
|
||||
def LoadExpBasis():
|
||||
n_vertex = 53215
|
||||
|
|
|
@ -27,25 +27,23 @@ def POS(xp,x):
|
|||
|
||||
return t,s
|
||||
|
||||
def process_img(img,lm,t,s):
|
||||
def process_img(img,lm,t,s,target_size = 224.):
|
||||
w0,h0 = img.size
|
||||
img = img.transform(img.size, Image.AFFINE, (1, 0, t[0] - w0/2, 0, 1, h0/2 - t[1]))
|
||||
w = (w0/s*102).astype(np.int32)
|
||||
h = (h0/s*102).astype(np.int32)
|
||||
img = img.resize((w,h),resample = Image.BILINEAR)
|
||||
lm = np.stack([lm[:,0] - t[0] + w0/2,lm[:,1] - t[1] + h0/2],axis = 1)/s*102
|
||||
img = img.resize((w,h),resample = Image.BICUBIC)
|
||||
|
||||
# crop the image to 224*224 from image center
|
||||
left = (w/2 - 112).astype(np.int32)
|
||||
right = left + 224
|
||||
up = (h/2 - 112).astype(np.int32)
|
||||
below = up + 224
|
||||
left = (w/2 - target_size/2 + float((t[0] - w0/2)*102/s)).astype(np.int32)
|
||||
right = left + target_size
|
||||
up = (h/2 - target_size/2 + float((h0/2 - t[1])*102/s)).astype(np.int32)
|
||||
below = up + target_size
|
||||
|
||||
img = img.crop((left,up,right,below))
|
||||
img = np.array(img)
|
||||
img = img[:,:,::-1]
|
||||
img = img[:,:,::-1] #RGBtoBGR
|
||||
img = np.expand_dims(img,0)
|
||||
lm = lm - np.reshape(np.array([(w/2 - 112),(h/2-112)]),[1,2])
|
||||
lm = np.stack([lm[:,0] - t[0] + w0/2,lm[:,1] - t[1] + h0/2],axis = 1)/s*102
|
||||
lm = lm - np.reshape(np.array([(w/2 - target_size/2),(h/2-target_size/2)]),[1,2])
|
||||
|
||||
return img,lm
|
||||
|
||||
|
|
|
@ -1,204 +0,0 @@
|
|||
import numpy as np
|
||||
|
||||
# input: coeff with shape [1,257]
|
||||
def Split_coeff(coeff):
    """Slice a [batch,257] coefficient array into its semantic parts.

    Returns:
        (id_coeff, ex_coeff, tex_coeff, angles, gamma, translation) with
        80, 64, 80, 3, 27 and 3 columns respectively.
    """
    id_coeff = coeff[:,:80] # identity(shape) coeff of dim 80
    ex_coeff = coeff[:,80:144] # expression coeff of dim 64
    tex_coeff = coeff[:,144:224] # texture(albedo) coeff of dim 80
    angles = coeff[:,224:227] # Euler angles(x,y,z) for rotation of dim 3
    gamma = coeff[:,227:254] # lighting coeff for 3 channel SH function of dim 27
    # bounded slice so that translation is always of dim 3, even if the input
    # carries extra trailing columns
    translation = coeff[:,254:257] # translation coeff of dim 3

    return id_coeff,ex_coeff,tex_coeff,angles,gamma,translation
|
||||
|
||||
|
||||
# compute face shape with identity and expression coeff, based on BFM model
|
||||
# input: id_coeff with shape [1,80]
|
||||
# ex_coeff with shape [1,64]
|
||||
# output: face_shape with shape [1,N,3], N is number of vertices
|
||||
def Shape_formation(id_coeff,ex_coeff,facemodel):
    """Compute the face shape from identity and expression coefficients.

    Args:
        id_coeff: [batch,n_id] identity coefficients.
        ex_coeff: [batch,n_ex] expression coefficients.
        facemodel: object providing ``idBase``, ``exBase`` and ``meanshape``.

    Returns:
        face_shape with shape [batch,N,3], re-centered on the mean shape.
    """
    face_shape = np.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \
                 np.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + \
                 facemodel.meanshape

    # keep the batch dimension instead of assuming a single sample
    n_batch = face_shape.shape[0]
    face_shape = np.reshape(face_shape,[n_batch,-1,3])
    # re-center face shape
    face_shape = face_shape - np.mean(np.reshape(facemodel.meanshape,[1,-1,3]), axis = 1, keepdims = True)

    return face_shape
|
||||
|
||||
# compute vertex normal using one-ring neighborhood
|
||||
# input: face_shape with shape [1,N,3]
|
||||
# output: v_norm with shape [1,N,3]
|
||||
def Compute_norm(face_shape,facemodel):
    """Compute unit vertex normals using the one-ring neighborhood.

    Args:
        face_shape: [batch,N,3] vertex positions.
        facemodel: object providing ``tri`` ([F,3] vertex index per triangle)
            and ``point_buf`` (adjacent face indices per vertex); both are
            1-based.  A ``point_buf`` entry of F+1 selects an all-zero normal
            and therefore acts as padding.

    Returns:
        v_norm with shape [batch,N,3].  A vertex whose summed face normal is
        exactly zero gets a zero normal instead of NaN.
    """
    face_id = (facemodel.tri - 1).astype(np.int32)
    point_id = (facemodel.point_buf - 1).astype(np.int32)

    v1 = face_shape[:,face_id[:,0],:]
    v2 = face_shape[:,face_id[:,1],:]
    v3 = face_shape[:,face_id[:,2],:]
    e1 = v1 - v2
    e2 = v2 - v3
    face_norm = np.cross(e1,e2) # compute normal for each face

    # concat face_normal with a zero vector at the end (one per batch element)
    n_batch = face_norm.shape[0]
    face_norm = np.concatenate([face_norm,np.zeros([n_batch,1,3])], axis = 1)

    # compute vertex normal using one-ring neighborhood
    v_norm = np.sum(face_norm[:,point_id,:], axis = 2)

    # normalize normal vectors, leaving zero vectors untouched
    length = np.linalg.norm(v_norm,axis = 2,keepdims = True)
    v_norm = v_norm/np.where(length > 0,length,1.0)

    return v_norm
|
||||
|
||||
# compute vertex texture(albedo) with tex_coeff
|
||||
# input: tex_coeff with shape [1,80]
|
||||
# output: face_texture with shape [1,N,3], RGB order, range from 0-255
|
||||
def Texture_formation(tex_coeff,facemodel):
    """Compute the vertex texture (albedo) from texture coefficients.

    Args:
        tex_coeff: [batch,n_tex] texture coefficients.
        facemodel: object providing ``texBase`` and ``meantex``.

    Returns:
        face_texture with shape [batch,N,3], RGB order.
    """
    face_texture = np.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex
    # keep the batch dimension instead of assuming a single sample
    face_texture = np.reshape(face_texture,[face_texture.shape[0],-1,3])

    return face_texture
|
||||
|
||||
# compute rotation matrix based on 3 Euler angles
|
||||
# input: angles with shape [1,3]
|
||||
# output: rotation matrix with shape [1,3,3]
|
||||
def Compute_rotation_matrix(angles):
    """Compute transposed rotation matrices from Euler angles.

    Args:
        angles: [batch,3] rotation angles about the X, Y and Z axes.

    Returns:
        [batch,3,3] array holding (Rz*Ry*Rx)^T for every batch element, so
        that row-vector points can be rotated as ``points @ rotation``.
    """
    n_batch = np.shape(angles)[0]
    cos_x, sin_x = np.cos(angles[:,0]), np.sin(angles[:,0])
    cos_y, sin_y = np.cos(angles[:,1]), np.sin(angles[:,1])
    cos_z, sin_z = np.cos(angles[:,2]), np.sin(angles[:,2])

    # compute rotation matrix for X,Y,Z axis respectively
    rotation_X = np.zeros([n_batch,3,3])
    rotation_X[:,0,0] = 1.0
    rotation_X[:,1,1] = cos_x
    rotation_X[:,1,2] = -sin_x
    rotation_X[:,2,1] = sin_x
    rotation_X[:,2,2] = cos_x

    rotation_Y = np.zeros([n_batch,3,3])
    rotation_Y[:,0,0] = cos_y
    rotation_Y[:,0,2] = sin_y
    rotation_Y[:,1,1] = 1.0
    rotation_Y[:,2,0] = -sin_y
    rotation_Y[:,2,2] = cos_y

    rotation_Z = np.zeros([n_batch,3,3])
    rotation_Z[:,0,0] = cos_z
    rotation_Z[:,0,1] = -sin_z
    rotation_Z[:,1,0] = sin_z
    rotation_Z[:,1,1] = cos_z
    rotation_Z[:,2,2] = 1.0

    rotation = np.matmul(np.matmul(rotation_Z,rotation_Y),rotation_X)
    # transpose row and column (dimension 1 and 2)
    rotation = np.transpose(rotation, axes = [0,2,1])

    return rotation
|
||||
|
||||
# project 3D face onto image plane
|
||||
# input: face_shape with shape [1,N,3]
|
||||
# rotation with shape [1,3,3]
|
||||
# translation with shape [1,3]
|
||||
# output: face_projection with shape [1,N,2]
|
||||
# z_buffer with shape [1,N,1]
|
||||
def Projection_layer(face_shape,rotation,translation,focal=1015.0,center=112.0): # we choose the focal length and camera position empirically
    """Project a 3D face onto the image plane with a perspective camera.

    Args:
        face_shape: [batch,N,3] vertex positions.
        rotation: [batch,3,3] (or broadcastable) rotation applied as
            ``face_shape @ rotation``.
        translation: [batch,3] translation added after the rotation.
        focal: camera focal length.
        center: principal point, used for both x and y.

    Returns:
        face_projection with shape [batch,N,2] and z_buffer with shape
        [batch,N,1].
    """
    camera_pos = np.reshape(np.array([0.0,0.0,10.0]),[1,1,3]) # camera position
    reverse_z = np.reshape(np.array([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3])

    # projection matrix
    p_matrix = np.array([[focal,0.0,center],
                         [0.0,focal,center],
                         [0.0,0.0,1.0]])
    p_matrix = np.reshape(p_matrix,[1,3,3])

    # calculate face position in camera space
    face_shape_r = np.matmul(face_shape,rotation)
    # one translation per batch element
    face_shape_t = face_shape_r + np.reshape(translation,[-1,1,3])
    face_shape_t = np.matmul(face_shape_t,reverse_z) + camera_pos

    # calculate projection of face vertex using perspective projection
    aug_projection = np.matmul(face_shape_t,np.transpose(p_matrix,[0,2,1]))
    n_batch = np.shape(aug_projection)[0]
    n_point = np.shape(aug_projection)[1]
    face_projection = aug_projection[:,:,0:2]/np.reshape(aug_projection[:,:,2],[n_batch,n_point,1])
    z_buffer = np.reshape(aug_projection[:,:,2],[n_batch,-1,1])

    return face_projection,z_buffer
|
||||
|
||||
# compute vertex color using face_texture and SH function lighting approximation
|
||||
# input: face_texture with shape [1,N,3]
|
||||
# norm with shape [1,N,3]
|
||||
# gamma with shape [1,27]
|
||||
# output: face_color with shape [1,N,3], RGB order, range from 0-255
|
||||
# lighting with shape [1,N,3], color under uniform texture
|
||||
def Illumination_layer(face_texture,norm,gamma):
    """Compute vertex color with a 9-term SH lighting approximation.

    Args:
        face_texture: [batch,N,3] vertex texture, RGB order.
        norm: [batch,N,3] vertex normals.
        gamma: [batch,27] lighting coefficients (9 per color channel).

    Returns:
        face_color with shape [batch,N,3] (texture times lighting) and
        lighting with shape [batch,N,3] (the per-channel lighting times 128).
    """
    init_lit = np.array([0.8,0,0,0,0,0,0,0,0])
    gamma = np.reshape(gamma,[-1,3,9])
    gamma = gamma + np.reshape(init_lit,[1,1,9])

    # parameter of 9 SH function
    a0 = np.pi
    a1 = 2*np.pi/np.sqrt(3.0)
    a2 = 2*np.pi/np.sqrt(8.0)
    c0 = 1/np.sqrt(4*np.pi)
    c1 = np.sqrt(3.0)/np.sqrt(4*np.pi)
    c2 = 3*np.sqrt(5.0)/np.sqrt(12*np.pi)

    norm_x = norm[:,:,0]
    norm_y = norm[:,:,1]
    norm_z = norm[:,:,2]

    # Y shape:[batch,N,9]; stacking on a new last axis keeps the batch
    # dimension instead of assuming a single sample
    Y = np.stack([
        np.full(np.shape(norm_x),a0*c0),
        -a1*c1*norm_y,
        a1*c1*norm_z,
        -a1*c1*norm_x,
        a2*c2*norm_x*norm_y,
        -a2*c2*norm_y*norm_z,
        a2*c2*0.5/np.sqrt(3.0)*(3*np.square(norm_z)-1),
        -a2*c2*norm_x*norm_z,
        a2*c2*0.5*(np.square(norm_x)-np.square(norm_y)),
    ],axis = 2)

    #[batch,N,9] * [batch,9,1] = [batch,N]
    lit_r = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,0,:],2)),2)
    lit_g = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,1,:],2)),2)
    lit_b = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,2,:],2)),2)

    # shape:[batch,N,3]
    face_color = np.stack([lit_r*face_texture[:,:,0],lit_g*face_texture[:,:,1],lit_b*face_texture[:,:,2]],axis = 2)
    lighting = np.stack([lit_r,lit_g,lit_b],axis = 2)*128

    return face_color,lighting
|
||||
|
||||
# face reconstruction with coeff and BFM model
|
||||
def Reconstruction(coeff,facemodel):
    """Reconstruct a face from network coefficients and the BFM face model.

    Returns:
        (face_shape, face_texture, face_color, tri, face_projection,
        z_buffer, landmarks_2d).  ``face_projection`` has its y coordinate
        replaced by ``224 - y``; ``landmarks_2d`` are the rows of
        ``face_projection`` selected by ``facemodel.keypoints``.
    """
    id_part,ex_part,tex_part,euler,sh_gamma,shift = Split_coeff(coeff)

    # geometry and albedo in the canonical model space
    shape = Shape_formation(id_part, ex_part, facemodel)
    albedo = Texture_formation(tex_part, facemodel)

    # vertex normals, rotated with the same matrix as the shape
    normals = Compute_norm(shape,facemodel)
    rot = Compute_rotation_matrix(euler)
    normals_rot = np.matmul(normals,rot)

    # vertex projection on the image plane (with image sized 224*224)
    projected,z_buffer = Projection_layer(shape,rot,shift)
    proj_x = projected[:,:,0]
    proj_y = 224 - projected[:,:,1]
    projected = np.stack([proj_x,proj_y], axis = 2)

    # 68 landmarks on the image plane
    landmarks_2d = projected[:,facemodel.keypoints,:]

    # vertex color using SH function lighting approximation
    # (the second return value, the lighting, is not needed here)
    shaded,_ = Illumination_layer(albedo, normals_rot, sh_gamma)

    # vertex index for each face of BFM model
    triangles = facemodel.tri

    return shape,albedo,shaded,triangles,projected,z_buffer,landmarks_2d
|
||||
|
||||
# def Reconstruction_for_render(coeff,facemodel):
|
||||
# id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff)
|
||||
# face_shape = Shape_formation(id_coeff, ex_coeff, facemodel)
|
||||
# face_texture = Texture_formation(tex_coeff, facemodel)
|
||||
# face_norm = Compute_norm(face_shape,facemodel)
|
||||
# rotation = Compute_rotation_matrix(angles)
|
||||
# face_shape_r = np.matmul(face_shape,rotation)
|
||||
# face_shape_r = face_shape_r + np.reshape(translation,[1,1,3])
|
||||
# face_norm_r = np.matmul(face_norm,rotation)
|
||||
# face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma)
|
||||
# tri = facemodel.face_buf
|
||||
|
||||
# return face_shape_r,face_norm_r,face_color,tri
|