transfer reconstruction process from numpy to tensorflow, add rendering process, update image preprocessing method

This commit is contained in:
dengy 2020-04-08 15:32:34 +08:00
Родитель 451e82e5ab
Коммит 9d3bc46e43
51 изменённых файлов: 455 добавлений и 295 удалений

44
demo.py
Просмотреть файл

@ -8,7 +8,7 @@ from scipy.io import loadmat,savemat
from preprocess_img import Preprocess
from load_data import *
from reconstruct_mesh import Reconstruction
from face_decoder import Face3D
def load_graph(graph_filename):
with tf.gfile.GFile(graph_filename,'rb') as f:
@ -21,56 +21,68 @@ def demo():
# input and output folder
image_path = 'input'
save_path = 'output'
if not os.path.exists(save_path):
os.makedirs(save_path)
img_list = glob.glob(image_path + '/' + '*.png')
img_list +=glob.glob(image_path + '/' + '*.jpg')
# read BFM face model
# transfer original BFM model to our model
if not os.path.isfile('./BFM/BFM_model_front.mat'):
transferBFM09()
# read face model
facemodel = BFM()
# read standard landmarks for preprocessing images
lm3D = load_lm3d()
batchsize = 1
n = 0
# build reconstruction model
with tf.Graph().as_default() as graph,tf.device('/cpu:0'):
images = tf.placeholder(name = 'input_imgs', shape = [None,224,224,3], dtype = tf.float32)
FaceReconstructor = Face3D()
images = tf.placeholder(name = 'input_imgs', shape = [batchsize,224,224,3], dtype = tf.float32)
graph_def = load_graph('network/FaceReconModel.pb')
tf.import_graph_def(graph_def,name='resnet',input_map={'input_imgs:0': images})
# output coefficients of R-Net (dim = 257)
coeff = graph.get_tensor_by_name('resnet/coeff:0')
# reconstructing faces
FaceReconstructor.Reconstruction_Block(coeff,batchsize)
face_shape = FaceReconstructor.face_shape_t
face_texture = FaceReconstructor.face_texture
face_color = FaceReconstructor.face_color
landmarks_2d = FaceReconstructor.landmark_p
recon_img = FaceReconstructor.render_imgs
tri = FaceReconstructor.facemodel.face_buf
with tf.Session() as sess:
print('reconstructing...')
for file in img_list:
n += 1
print(n)
# load images and corresponding 5 facial landmarks
img,lm = load_img(file,file.replace('png','txt'))
img,lm = load_img(file,file.replace('png','txt').replace('jpg','txt'))
# preprocess input image
input_img,lm_new,transform_params = Preprocess(img,lm,lm3D)
coef = sess.run(coeff,feed_dict = {images: input_img})
coeff_,face_shape_,face_texture_,face_color_,landmarks_2d_,recon_img_,tri_ = sess.run([coeff,\
face_shape,face_texture,face_color,landmarks_2d,recon_img,tri],feed_dict = {images: input_img})
# reconstruct 3D face with output coefficients and face model
face_shape,face_texture,face_color,tri,face_projection,z_buffer,landmarks_2d = Reconstruction(coef,facemodel)
# reshape outputs
input_img = np.squeeze(input_img)
shape = np.squeeze(face_shape, (0))
color = np.squeeze(face_color, (0))
landmarks_2d = np.squeeze(landmarks_2d, (0))
face_shape_ = np.squeeze(face_shape_, (0))
face_texture_ = np.squeeze(face_texture_, (0))
face_color_ = np.squeeze(face_color_, (0))
landmarks_2d_ = np.squeeze(landmarks_2d_, (0))
recon_img_ = np.squeeze(recon_img_, (0))
# save output files
# cropped image, which is the direct input to our R-Net
# 257 dim output coefficients by R-Net
# 68 face landmarks of cropped image
savemat(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','.mat')),{'cropped_img':input_img[:,:,::-1],'coeff':coef,'landmarks_2d':landmarks_2d,'lm_5p':lm_new})
save_obj(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','_mesh.obj')),shape,tri,np.clip(color,0,255)/255) # 3D reconstruction face (in canonical view)
savemat(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','.mat').replace('jpg','mat')),{'cropped_img':input_img[:,:,::-1],'recon_img':recon_img_,'coeff':coeff_,\
'face_shape':face_shape_,'face_texture':face_texture_,'face_color':face_color_,'lm_68p':landmarks_2d_,'lm_5p':lm_new})
save_obj(os.path.join(save_path,file.split(os.path.sep)[-1].replace('.png','_mesh.obj').replace('jpg','_mesh.obj')),face_shape_,tri_,np.clip(face_color_,0,255)/255) # 3D reconstruction face (in canonical view)
if __name__ == '__main__':
demo()

293
face_decoder.py Normal file
Просмотреть файл

@ -0,0 +1,293 @@
import tensorflow as tf
import math as m
import numpy as np
import mesh_renderer
from scipy.io import loadmat
###############################################################################################
# Reconstruct 3D face based on output coefficients and facemodel
###############################################################################################
# BFM 3D face model
class BFM():
def __init__(self,model_path = 'BFM/BFM_model_front.mat'):
model = loadmat(model_path)
self.meanshape = tf.constant(model['meanshape']) # mean face shape. [3*N,1]
self.idBase = tf.constant(model['idBase']) # identity basis. [3*N,80]
self.exBase = tf.constant(model['exBase'].astype(np.float32)) # expression basis. [3*N,64]
self.meantex = tf.constant(model['meantex']) # mean face texture. [3*N,1] (0-255)
self.texBase = tf.constant(model['texBase']) # texture basis. [3*N,80]
self.point_buf = tf.constant(model['point_buf']) # triangle indices for each vertex that lies in. starts from 1. [N,8]
self.face_buf = tf.constant(model['tri']) # vertex indices in each triangle. starts from 1. [F,3]
self.keypoints = tf.squeeze(tf.constant(model['keypoints'])) # vertex indices of 68 facial landmarks. starts from 1. [68,1]
# Analytic 3D face reconstructor
class Face3D():
def __init__(self):
facemodel = BFM()
self.facemodel = facemodel
# analytic 3D face reconstructions with coefficients from R-Net
def Reconstruction_Block(self,coeff,batchsize):
#coeff: [batchsize,257] reconstruction coefficients
id_coeff,ex_coeff,tex_coeff,angles,translation,gamma = self.Split_coeff(coeff)
# [batchsize,N,3] canonical face shape in BFM space
face_shape = self.Shape_formation_block(id_coeff,ex_coeff,self.facemodel)
# [batchsize,N,3] vertex texture (in RGB order)
face_texture = self.Texture_formation_block(tex_coeff,self.facemodel)
self.face_texture = face_texture
# [batchsize,3,3] rotation matrix for face shape
rotation = self.Compute_rotation_matrix(angles)
# [batchsize,N,3] vertex normal
face_norm = self.Compute_norm(face_shape,self.facemodel)
norm_r = tf.matmul(face_norm,rotation)
# do rigid transformation for face shape using predicted rotation and translation
face_shape_t = self.Rigid_transform_block(face_shape,rotation,translation)
self.face_shape_t = face_shape_t
# compute 2d landmark projections
# landmark_p: [batchsize,68,2]
face_landmark_t = self.Compute_landmark(face_shape_t,self.facemodel)
landmark_p = self.Projection_block(face_landmark_t) # 256*256 image
self.landmark_p = landmark_p
# [batchsize,N,3] vertex color (in RGB order)
face_color = self.Illumination_block(face_texture, norm_r, gamma)
self.face_color = face_color
# reconstruction images
render_imgs = self.Render_block(face_shape_t,norm_r,face_color,self.facemodel,batchsize)
render_imgs = tf.clip_by_value(render_imgs,0,255)
render_imgs = tf.cast(render_imgs,tf.float32)
self.render_imgs = render_imgs
######################################################################################################
def Split_coeff(self,coeff):
id_coeff = coeff[:,:80] #identity
ex_coeff = coeff[:,80:144] #expression
tex_coeff = coeff[:,144:224] #texture
angles = coeff[:,224:227] #euler angles for pose
gamma = coeff[:,227:254] #lighting
translation = coeff[:,254:257] #translation
return id_coeff,ex_coeff,tex_coeff,angles,translation,gamma
def Shape_formation_block(self,id_coeff,ex_coeff,facemodel):
face_shape = tf.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \
tf.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + facemodel.meanshape
# reshape face shape to [batchsize,N,3]
face_shape = tf.reshape(face_shape,[tf.shape(face_shape)[0],-1,3])
# re-centering the face shape with mean shape
face_shape = face_shape - tf.reshape(tf.reduce_mean(tf.reshape(facemodel.meanshape,[-1,3]),0),[1,1,3])
return face_shape
def Compute_norm(self,face_shape,facemodel):
shape = face_shape
face_id = facemodel.face_buf
point_id = facemodel.point_buf
# face_id and point_id index starts from 1
face_id = tf.cast(face_id - 1,tf.int32)
point_id = tf.cast(point_id - 1,tf.int32)
#compute normal for each face
v1 = tf.gather(shape,face_id[:,0], axis = 1)
v2 = tf.gather(shape,face_id[:,1], axis = 1)
v3 = tf.gather(shape,face_id[:,2], axis = 1)
e1 = v1 - v2
e2 = v2 - v3
face_norm = tf.cross(e1,e2)
face_norm = tf.nn.l2_normalize(face_norm, dim = 2) # normalized face_norm first
face_norm = tf.concat([face_norm,tf.zeros([tf.shape(face_shape)[0],1,3])], axis = 1)
#compute normal for each vertex using one-ring neighborhood
v_norm = tf.reduce_sum(tf.gather(face_norm, point_id, axis = 1), axis = 2)
v_norm = tf.nn.l2_normalize(v_norm, dim = 2)
return v_norm
def Texture_formation_block(self,tex_coeff,facemodel):
face_texture = tf.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex
# reshape face texture to [batchsize,N,3], note that texture is in RGB order
face_texture = tf.reshape(face_texture,[tf.shape(face_texture)[0],-1,3])
return face_texture
def Compute_rotation_matrix(self,angles):
n_data = tf.shape(angles)[0]
# compute rotation matrix for X-axis, Y-axis, Z-axis respectively
rotation_X = tf.concat([tf.ones([n_data,1]),
tf.zeros([n_data,3]),
tf.reshape(tf.cos(angles[:,0]),[n_data,1]),
-tf.reshape(tf.sin(angles[:,0]),[n_data,1]),
tf.zeros([n_data,1]),
tf.reshape(tf.sin(angles[:,0]),[n_data,1]),
tf.reshape(tf.cos(angles[:,0]),[n_data,1])],
axis = 1
)
rotation_Y = tf.concat([tf.reshape(tf.cos(angles[:,1]),[n_data,1]),
tf.zeros([n_data,1]),
tf.reshape(tf.sin(angles[:,1]),[n_data,1]),
tf.zeros([n_data,1]),
tf.ones([n_data,1]),
tf.zeros([n_data,1]),
-tf.reshape(tf.sin(angles[:,1]),[n_data,1]),
tf.zeros([n_data,1]),
tf.reshape(tf.cos(angles[:,1]),[n_data,1])],
axis = 1
)
rotation_Z = tf.concat([tf.reshape(tf.cos(angles[:,2]),[n_data,1]),
-tf.reshape(tf.sin(angles[:,2]),[n_data,1]),
tf.zeros([n_data,1]),
tf.reshape(tf.sin(angles[:,2]),[n_data,1]),
tf.reshape(tf.cos(angles[:,2]),[n_data,1]),
tf.zeros([n_data,3]),
tf.ones([n_data,1])],
axis = 1
)
rotation_X = tf.reshape(rotation_X,[n_data,3,3])
rotation_Y = tf.reshape(rotation_Y,[n_data,3,3])
rotation_Z = tf.reshape(rotation_Z,[n_data,3,3])
# R = RzRyRx
rotation = tf.matmul(tf.matmul(rotation_Z,rotation_Y),rotation_X)
# because our face shape is N*3, so compute the transpose of R, so that rotation shapes can be calculated as face_shape*R
rotation = tf.transpose(rotation, perm = [0,2,1])
return rotation
def Projection_block(self,face_shape,focal=1015.0,half_image_width=112.):
# pre-defined camera focal for pespective projection
focal = tf.constant(focal)
# focal = tf.constant(400.0)
focal = tf.reshape(focal,[-1,1])
batchsize = tf.shape(face_shape)[0]
# center = tf.constant(112.0)
# define camera position
camera_pos = tf.reshape(tf.constant([0.0,0.0,10.0]),[1,1,3])
# compute projection matrix
p_matrix = tf.concat([focal*tf.ones([batchsize,1]),tf.zeros([batchsize,1]),half_image_width*tf.ones([batchsize,1]),tf.zeros([batchsize,1]),\
focal*tf.ones([batchsize,1]),half_image_width*tf.ones([batchsize,1]),tf.zeros([batchsize,2]),tf.ones([batchsize,1])],axis = 1)
# p_matrix = tf.tile(tf.reshape(p_matrix,[1,3,3]),[tf.shape(face_shape)[0],1,1])
p_matrix = tf.reshape(p_matrix,[-1,3,3])
# convert z in canonical space to the distance to camera
reverse_z = tf.tile(tf.reshape(tf.constant([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3]),[tf.shape(face_shape)[0],1,1])
face_shape = tf.matmul(face_shape,reverse_z) + camera_pos
aug_projection = tf.matmul(face_shape,tf.transpose(p_matrix,[0,2,1]))
# [batchsize, N,2] 2d face projection
face_projection = aug_projection[:,:,0:2]/tf.reshape(aug_projection[:,:,2],[tf.shape(face_shape)[0],tf.shape(aug_projection)[1],1])
return face_projection
def Compute_landmark(self,face_shape,facemodel):
# compute 3D landmark postitions with pre-computed 3D face shape
keypoints_idx = facemodel.keypoints
keypoints_idx = tf.cast(keypoints_idx - 1,tf.int32)
face_landmark = tf.gather(face_shape,keypoints_idx,axis = 1)
return face_landmark
def Illumination_block(self,face_texture,norm_r,gamma):
n_data = tf.shape(gamma)[0]
n_point = tf.shape(norm_r)[1]
gamma = tf.reshape(gamma,[n_data,3,9])
# set initial lighting with an ambient lighting
init_lit = tf.constant([0.8,0,0,0,0,0,0,0,0])
gamma = gamma + tf.reshape(init_lit,[1,1,9])
# compute vertex color using SH function approximation
a0 = m.pi
a1 = 2*m.pi/tf.sqrt(3.0)
a2 = 2*m.pi/tf.sqrt(8.0)
c0 = 1/tf.sqrt(4*m.pi)
c1 = tf.sqrt(3.0)/tf.sqrt(4*m.pi)
c2 = 3*tf.sqrt(5.0)/tf.sqrt(12*m.pi)
Y = tf.concat([tf.tile(tf.reshape(a0*c0,[1,1,1]),[n_data,n_point,1]),
tf.expand_dims(-a1*c1*norm_r[:,:,1],2),
tf.expand_dims(a1*c1*norm_r[:,:,2],2),
tf.expand_dims(-a1*c1*norm_r[:,:,0],2),
tf.expand_dims(a2*c2*norm_r[:,:,0]*norm_r[:,:,1],2),
tf.expand_dims(-a2*c2*norm_r[:,:,1]*norm_r[:,:,2],2),
tf.expand_dims(a2*c2*0.5/tf.sqrt(3.0)*(3*tf.square(norm_r[:,:,2])-1),2),
tf.expand_dims(-a2*c2*norm_r[:,:,0]*norm_r[:,:,2],2),
tf.expand_dims(a2*c2*0.5*(tf.square(norm_r[:,:,0])-tf.square(norm_r[:,:,1])),2)],axis = 2)
color_r = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,0,:],2)),axis = 2)
color_g = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,1,:],2)),axis = 2)
color_b = tf.squeeze(tf.matmul(Y,tf.expand_dims(gamma[:,2,:],2)),axis = 2)
#[batchsize,N,3] vertex color in RGB order
face_color = tf.stack([color_r*face_texture[:,:,0],color_g*face_texture[:,:,1],color_b*face_texture[:,:,2]],axis = 2)
return face_color
def Rigid_transform_block(self,face_shape,rotation,translation):
# do rigid transformation for 3D face shape
face_shape_r = tf.matmul(face_shape,rotation)
face_shape_t = face_shape_r + tf.reshape(translation,[tf.shape(face_shape)[0],1,3])
return face_shape_t
def Render_block(self,face_shape,face_norm,face_color,facemodel,batchsize):
# render reconstruction images
n_vex = int(facemodel.idBase.shape[0].value/3)
fov_y = 2*tf.atan(112/(1015.))*180./m.pi + tf.zeros([batchsize])
# full face region
face_shape = tf.reshape(face_shape,[batchsize,n_vex,3])
face_norm = tf.reshape(face_norm,[batchsize,n_vex,3])
face_color = tf.reshape(face_color,[batchsize,n_vex,3])
#cammera settings
# same as in Projection_block
camera_position = tf.constant([[0,0,10.0]]) + tf.zeros([batchsize,3])
camera_lookat = tf.constant([[0,0,0.0]]) + tf.zeros([batchsize,3])
camera_up = tf.constant([[0,1.0,0]]) + tf.zeros([batchsize,3])
# setting light source position(intensities are set to 0 because we have already computed the vertex color)
light_positions = tf.reshape(tf.constant([0,0,1e5]),[1,1,3]) + tf.zeros([batchsize,1,3])
light_intensities = tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3])+tf.zeros([batchsize,1,3])
ambient_color = tf.reshape(tf.constant([1.0,1,1]),[1,3])+ tf.zeros([batchsize,3])
near_clip = 0.01*tf.ones([batchsize])
far_clip = 50*tf.ones([batchsize])
#using tf_mesh_renderer for rasterization (https://github.com/google/tf_mesh_renderer)
# img: [batchsize,224,224,4] images in RGBA order (0-255)
with tf.device('/cpu:0'):
img = mesh_renderer.mesh_renderer(face_shape,
tf.cast(facemodel.face_buf-1,tf.int32),
face_norm,
face_color,
camera_position = camera_position,
camera_lookat = camera_lookat,
camera_up = camera_up,
light_positions = light_positions,
light_intensities = light_intensities,
image_width = 224,
image_height = 224,
fov_y = fov_y, #12.5936
ambient_color = ambient_color,
near_clip = near_clip,
far_clip = far_clip)
return img

Двоичные данные
input/000002.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 31 KiB

5
input/000002.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
142.84 207.18
222.02 203.9
159.24 253.57
146.59 290.93
227.52 284.74

Двоичные данные
input/000006.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 30 KiB

5
input/000006.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
199.93 158.28
255.34 166.54
236.08 198.92
198.83 229.24
245.23 234.52

Двоичные данные
input/000007.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 27 KiB

5
input/000007.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
129.36 198.28
204.47 191.47
164.42 240.51
140.74 277.77
205.4 270.9

Двоичные данные
input/000031.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 30 KiB

5
input/000031.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
151.23 240.71
274.05 235.52
217.37 305.99
158.03 346.06
272.17 341.09

Двоичные данные
input/000033.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 9.4 KiB

5
input/000033.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
119.09 94.291
158.31 96.472
136.76 121.4
119.33 134.49
154.66 136.68

Двоичные данные
input/000037.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 14 KiB

5
input/000037.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
147.37 159.39
196.94 163.26
190.68 194.36
153.72 228.44
193.94 229.7

Двоичные данные
input/000050.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 8.7 KiB

5
input/000050.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
150.4 94.799
205.14 102.07
179.54 131.16
144.45 147.42
193.39 154.14

Двоичные данные
input/000055.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 17 KiB

5
input/000055.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
114.26 193.42
205.8 190.27
154.15 244.02
124.69 295.22
200.88 292.69

Двоичные данные
input/000114.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 42 KiB

5
input/000114.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
217.52 152.95
281.48 147.14
253.02 196.03
225.79 221.6
288.25 214.44

Двоичные данные
input/000125.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 13 KiB

5
input/000125.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
90.928 99.858
146.87 100.33
114.22 130.36
91.579 153.32
143.63 153.56

Двоичные данные
input/000126.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 73 KiB

5
input/000126.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
307.56 166.54
387.06 159.62
335.52 222.26
319.3 248.85
397.71 239.14

Двоичные данные
input/015259.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 20 KiB

5
input/015259.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
226.38 193.65
319.12 208.97
279.99 245.88
213.79 290.55
303.03 302.1

Двоичные данные
input/015270.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 64 KiB

5
input/015270.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
208.4 410.08
364.41 388.68
291.6 503.57
244.82 572.86
383.18 553.49

Двоичные данные
input/015309.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 134 KiB

5
input/015309.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
284.61 496.57
562.77 550.78
395.85 712.84
238.92 786.8
495.61 827.22

Двоичные данные
input/015310.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 20 KiB

5
input/015310.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
153.95 153.43
211.13 161.54
197.28 190.26
150.82 215.98
202.32 223.12

Двоичные данные
input/015316.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 59 KiB

5
input/015316.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
481.31 396.88
667.75 392.43
557.81 440.55
490.44 586.28
640.56 583.2

Двоичные данные
input/015384.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 33 KiB

5
input/015384.txt Normal file
Просмотреть файл

@ -0,0 +1,5 @@
191.79 143.97
271.86 151.23
191.25 210.29
187.82 257.12
258.82 261.96

Двоичные данные
input/vd005.png

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 394 KiB

Просмотреть файл

@ -1,5 +0,0 @@
287 299
502 311
393 411
304 542
476 547

Просмотреть файл

@ -1,5 +1,5 @@
118 115
178 124
127 147
117 181
167 191
123.12 117.58
176.59 122.09
126.99 144.68
117.61 183.43
163.94 186.41

Просмотреть файл

@ -1,5 +1,5 @@
184 115
261 98
230 156
204 205
281 187
180.12 116.13
263.18 98.397
230.48 154.72
201.37 199.01
279.18 182.56

Просмотреть файл

@ -1,5 +1,5 @@
171 258
284 259
211 330
173 383
274 389
171.27 263.54
286.58 263.88
203.35 333.02
170.6 389.42
281.73 386.84

Просмотреть файл

@ -1,5 +1,5 @@
131 168
192 154
155 194
148 238
201 226
136.01 167.83
195.25 151.71
152.89 191.45
149.85 235.5
201.16 222.8

Просмотреть файл

@ -1,5 +1,5 @@
162 292
255 285
213 349
178 390
252 382
161.92 292.04
254.21 283.81
212.75 342.06
170.78 387.28
254.6 379.82

Двоичные данные
input/vd057.png

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 281 KiB

Просмотреть файл

@ -1,5 +0,0 @@
252 178
403 151
332 221
281 329
405 312

Просмотреть файл

@ -1,5 +1,5 @@
275 290
383 295
312 356
278 407
359 412
276.53 290.35
383.38 294.75
314.48 354.66
275.08 407.72
364.94 411.48

Просмотреть файл

@ -1,5 +1,5 @@
112 150
160 148
136 182
123 202
159 201
108.59 149.07
157.35 143.85
134.4 173.2
117.88 200.79
159.56 196.36

Просмотреть файл

@ -1,5 +1,5 @@
124 227
192 220
164 274
136 306
187 298
121.62 225.96
186.73 223.07
162.99 269.82
132.12 302.62
186.42 299.21

Просмотреть файл

@ -3,20 +3,6 @@ from PIL import Image
from scipy.io import loadmat,savemat
from array import array
# define facemodel for reconstruction
class BFM():
def __init__(self):
model_path = './BFM/BFM_model_front.mat'
model = loadmat(model_path)
self.meanshape = model['meanshape'] # mean face shape
self.idBase = model['idBase'] # identity basis
self.exBase = model['exBase'] # expression basis
self.meantex = model['meantex'] # mean face texture
self.texBase = model['texBase'] # texture basis
self.point_buf = model['point_buf'] # adjacent face index for each vertex, starts from 1 (only used for calculating face normal)
self.tri = model['tri'] # vertex index for each triangle face, starts from 1
self.keypoints = np.squeeze(model['keypoints']).astype(np.int32) - 1 # 68 face landmark index, starts from 0
# load expression basis
def LoadExpBasis():
n_vertex = 53215

Просмотреть файл

@ -27,25 +27,23 @@ def POS(xp,x):
return t,s
def process_img(img,lm,t,s):
def process_img(img,lm,t,s,target_size = 224.):
w0,h0 = img.size
img = img.transform(img.size, Image.AFFINE, (1, 0, t[0] - w0/2, 0, 1, h0/2 - t[1]))
w = (w0/s*102).astype(np.int32)
h = (h0/s*102).astype(np.int32)
img = img.resize((w,h),resample = Image.BILINEAR)
lm = np.stack([lm[:,0] - t[0] + w0/2,lm[:,1] - t[1] + h0/2],axis = 1)/s*102
img = img.resize((w,h),resample = Image.BICUBIC)
# crop the image to 224*224 from image center
left = (w/2 - 112).astype(np.int32)
right = left + 224
up = (h/2 - 112).astype(np.int32)
below = up + 224
left = (w/2 - target_size/2 + float((t[0] - w0/2)*102/s)).astype(np.int32)
right = left + target_size
up = (h/2 - target_size/2 + float((h0/2 - t[1])*102/s)).astype(np.int32)
below = up + target_size
img = img.crop((left,up,right,below))
img = np.array(img)
img = img[:,:,::-1]
img = img[:,:,::-1] #RGBtoBGR
img = np.expand_dims(img,0)
lm = lm - np.reshape(np.array([(w/2 - 112),(h/2-112)]),[1,2])
lm = np.stack([lm[:,0] - t[0] + w0/2,lm[:,1] - t[1] + h0/2],axis = 1)/s*102
lm = lm - np.reshape(np.array([(w/2 - target_size/2),(h/2-target_size/2)]),[1,2])
return img,lm

Просмотреть файл

@ -1,204 +0,0 @@
import numpy as np
# input: coeff with shape [1,257]
def Split_coeff(coeff):
id_coeff = coeff[:,:80] # identity(shape) coeff of dim 80
ex_coeff = coeff[:,80:144] # expression coeff of dim 64
tex_coeff = coeff[:,144:224] # texture(albedo) coeff of dim 80
angles = coeff[:,224:227] # ruler angles(x,y,z) for rotation of dim 3
gamma = coeff[:,227:254] # lighting coeff for 3 channel SH function of dim 27
translation = coeff[:,254:] # translation coeff of dim 3
return id_coeff,ex_coeff,tex_coeff,angles,gamma,translation
# compute face shape with identity and expression coeff, based on BFM model
# input: id_coeff with shape [1,80]
# ex_coeff with shape [1,64]
# output: face_shape with shape [1,N,3], N is number of vertices
def Shape_formation(id_coeff,ex_coeff,facemodel):
face_shape = np.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \
np.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + \
facemodel.meanshape
face_shape = np.reshape(face_shape,[1,-1,3])
# re-center face shape
face_shape = face_shape - np.mean(np.reshape(facemodel.meanshape,[1,-1,3]), axis = 1, keepdims = True)
return face_shape
# compute vertex normal using one-ring neighborhood
# input: face_shape with shape [1,N,3]
# output: v_norm with shape [1,N,3]
def Compute_norm(face_shape,facemodel):
face_id = facemodel.tri # vertex index for each triangle face, with shape [F,3], F is number of faces
point_id = facemodel.point_buf # adjacent face index for each vertex, with shape [N,8], N is number of vertex
shape = face_shape
face_id = (face_id - 1).astype(np.int32)
point_id = (point_id - 1).astype(np.int32)
v1 = shape[:,face_id[:,0],:]
v2 = shape[:,face_id[:,1],:]
v3 = shape[:,face_id[:,2],:]
e1 = v1 - v2
e2 = v2 - v3
face_norm = np.cross(e1,e2) # compute normal for each face
face_norm = np.concatenate([face_norm,np.zeros([1,1,3])], axis = 1) # concat face_normal with a zero vector at the end
v_norm = np.sum(face_norm[:,point_id,:], axis = 2) # compute vertex normal using one-ring neighborhood
v_norm = v_norm/np.expand_dims(np.linalg.norm(v_norm,axis = 2),2) # normalize normal vectors
return v_norm
# compute vertex texture(albedo) with tex_coeff
# input: tex_coeff with shape [1,N,3]
# output: face_texture with shape [1,N,3], RGB order, range from 0-255
def Texture_formation(tex_coeff,facemodel):
face_texture = np.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex
face_texture = np.reshape(face_texture,[1,-1,3])
return face_texture
# compute rotation matrix based on 3 ruler angles
# input: angles with shape [1,3]
# output: rotation matrix with shape [1,3,3]
def Compute_rotation_matrix(angles):
angle_x = angles[:,0][0]
angle_y = angles[:,1][0]
angle_z = angles[:,2][0]
# compute rotation matrix for X,Y,Z axis respectively
rotation_X = np.array([1.0,0,0,\
0,np.cos(angle_x),-np.sin(angle_x),\
0,np.sin(angle_x),np.cos(angle_x)])
rotation_Y = np.array([np.cos(angle_y),0,np.sin(angle_y),\
0,1,0,\
-np.sin(angle_y),0,np.cos(angle_y)])
rotation_Z = np.array([np.cos(angle_z),-np.sin(angle_z),0,\
np.sin(angle_z),np.cos(angle_z),0,\
0,0,1])
rotation_X = np.reshape(rotation_X,[1,3,3])
rotation_Y = np.reshape(rotation_Y,[1,3,3])
rotation_Z = np.reshape(rotation_Z,[1,3,3])
rotation = np.matmul(np.matmul(rotation_Z,rotation_Y),rotation_X)
rotation = np.transpose(rotation, axes = [0,2,1]) #transpose row and column (dimension 1 and 2)
return rotation
# project 3D face onto image plane
# input: face_shape with shape [1,N,3]
# rotation with shape [1,3,3]
# translation with shape [1,3]
# output: face_projection with shape [1,N,2]
# z_buffer with shape [1,N,1]
def Projection_layer(face_shape,rotation,translation,focal=1015.0,center=112.0): # we choose the focal length and camera position empirically
camera_pos = np.reshape(np.array([0.0,0.0,10.0]),[1,1,3]) # camera position
reverse_z = np.reshape(np.array([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3])
p_matrix = np.concatenate([[focal],[0.0],[center],[0.0],[focal],[center],[0.0],[0.0],[1.0]],axis = 0) # projection matrix
p_matrix = np.reshape(p_matrix,[1,3,3])
# calculate face position in camera space
face_shape_r = np.matmul(face_shape,rotation)
face_shape_t = face_shape_r + np.reshape(translation,[1,1,3])
face_shape_t = np.matmul(face_shape_t,reverse_z) + camera_pos
# calculate projection of face vertex using perspective projection
aug_projection = np.matmul(face_shape_t,np.transpose(p_matrix,[0,2,1]))
face_projection = aug_projection[:,:,0:2]/np.reshape(aug_projection[:,:,2],[1,np.shape(aug_projection)[1],1])
z_buffer = np.reshape(aug_projection[:,:,2],[1,-1,1])
return face_projection,z_buffer
# compute vertex color using face_texture and SH function lighting approximation
# input: face_texture with shape [1,N,3]
# norm with shape [1,N,3]
# gamma with shape [1,27]
# output: face_color with shape [1,N,3], RGB order, range from 0-255
# lighting with shape [1,N,3], color under uniform texture
def Illumination_layer(face_texture,norm,gamma):
num_vertex = np.shape(face_texture)[1]
init_lit = np.array([0.8,0,0,0,0,0,0,0,0])
gamma = np.reshape(gamma,[-1,3,9])
gamma = gamma + np.reshape(init_lit,[1,1,9])
# parameter of 9 SH function
a0 = np.pi
a1 = 2*np.pi/np.sqrt(3.0)
a2 = 2*np.pi/np.sqrt(8.0)
c0 = 1/np.sqrt(4*np.pi)
c1 = np.sqrt(3.0)/np.sqrt(4*np.pi)
c2 = 3*np.sqrt(5.0)/np.sqrt(12*np.pi)
Y0 = np.tile(np.reshape(a0*c0,[1,1,1]),[1,num_vertex,1])
Y1 = np.reshape(-a1*c1*norm[:,:,1],[1,num_vertex,1])
Y2 = np.reshape(a1*c1*norm[:,:,2],[1,num_vertex,1])
Y3 = np.reshape(-a1*c1*norm[:,:,0],[1,num_vertex,1])
Y4 = np.reshape(a2*c2*norm[:,:,0]*norm[:,:,1],[1,num_vertex,1])
Y5 = np.reshape(-a2*c2*norm[:,:,1]*norm[:,:,2],[1,num_vertex,1])
Y6 = np.reshape(a2*c2*0.5/np.sqrt(3.0)*(3*np.square(norm[:,:,2])-1),[1,num_vertex,1])
Y7 = np.reshape(-a2*c2*norm[:,:,0]*norm[:,:,2],[1,num_vertex,1])
Y8 = np.reshape(a2*c2*0.5*(np.square(norm[:,:,0])-np.square(norm[:,:,1])),[1,num_vertex,1])
Y = np.concatenate([Y0,Y1,Y2,Y3,Y4,Y5,Y6,Y7,Y8],axis=2)
# Y shape:[batch,N,9].
lit_r = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,0,:],2)),2) #[batch,N,9] * [batch,9,1] = [batch,N]
lit_g = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,1,:],2)),2)
lit_b = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,2,:],2)),2)
# shape:[batch,N,3]
face_color = np.stack([lit_r*face_texture[:,:,0],lit_g*face_texture[:,:,1],lit_b*face_texture[:,:,2]],axis = 2)
lighting = np.stack([lit_r,lit_g,lit_b],axis = 2)*128
return face_color,lighting
# face reconstruction with coeff and BFM model
def Reconstruction(coeff,facemodel):
id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff)
# compute face shape
face_shape = Shape_formation(id_coeff, ex_coeff, facemodel)
# compute vertex texture(albedo)
face_texture = Texture_formation(tex_coeff, facemodel)
# vertex normal
face_norm = Compute_norm(face_shape,facemodel)
# rotation matrix
rotation = Compute_rotation_matrix(angles)
face_norm_r = np.matmul(face_norm,rotation)
# compute vertex projection on image plane (with image sized 224*224)
face_projection,z_buffer = Projection_layer(face_shape,rotation,translation)
face_projection = np.stack([face_projection[:,:,0],224 - face_projection[:,:,1]], axis = 2)
# compute 68 landmark on image plane
landmarks_2d = face_projection[:,facemodel.keypoints,:]
# compute vertex color using SH function lighting approximation
face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma)
# vertex index for each face of BFM model
tri = facemodel.tri
return face_shape,face_texture,face_color,tri,face_projection,z_buffer,landmarks_2d
# def Reconstruction_for_render(coeff,facemodel):
# id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff)
# face_shape = Shape_formation(id_coeff, ex_coeff, facemodel)
# face_texture = Texture_formation(tex_coeff, facemodel)
# face_norm = Compute_norm(face_shape,facemodel)
# rotation = Compute_rotation_matrix(angles)
# face_shape_r = np.matmul(face_shape,rotation)
# face_shape_r = face_shape_r + np.reshape(translation,[1,1,3])
# face_norm_r = np.matmul(face_norm,rotation)
# face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma)
# tri = facemodel.face_buf
# return face_shape_r,face_norm_r,face_color,tri