зеркало из https://github.com/microsoft/FERPlus.git
Fix cross-entropy.
This commit is contained in:
Родитель
0714f4ee36
Коммит
45edecd6e9
47
README.md
47
README.md
|
@ -11,6 +11,53 @@ We also provide a simple parsing code in Python to show how to parse the new lab
|
|||
|
||||
The format of the CSV file is as follows: usage, neutral, happiness, surprise, sadness, anger, disgust, fear, contempt, unknown, NF. Columns "usage" is the same as the original FER label to differentiate between training, public test and private test sets. The other columns are the vote count for each emotion with the addition of unknown and NF (Not a Face).
|
||||
|
||||
## Training
|
||||
We also provide a training code with implementation for all the training modes (majority, probability, cross entropy and multi-label) described in https://arxiv.org/abs/1608.01041. The training code uses MS Cognitive Toolkit (formerly CNTK) available in: https://github.com/Microsoft/CNTK.
|
||||
|
||||
After installing Cognitive Toolkit and downloading the dataset (we will discuss the dataset layout next), you can simply run the following to start the training:
|
||||
|
||||
#### For majority voting mode
|
||||
```
|
||||
python train.py -d <dataset base folder> -m majority
|
||||
```
|
||||
|
||||
#### For probability mode
|
||||
```
|
||||
python train.py -d <dataset base folder> -m probability
|
||||
```
|
||||
|
||||
#### For cross entropy mode
|
||||
```
|
||||
python train.py -d <dataset base folder> -m crossentropy
|
||||
```
|
||||
|
||||
#### For multi-target mode
|
||||
```
|
||||
python train.py -d <dataset base folder> -m multi_target
|
||||
```
|
||||
|
||||
## FER+ layout for Training
|
||||
There is a folder named data that has the following layout:
|
||||
|
||||
```
|
||||
/data
|
||||
/FER2013Test
|
||||
label.csv
|
||||
/FER2013Train
|
||||
label.csv
|
||||
/FER2013Valid
|
||||
label.csv
|
||||
```
|
||||
*label.csv* in each folder contains the actual label for each image, the image name is in the following format: ferXXXXXXXX.png, where XXXXXXXX is the row index of the original FER csv file. So here the names of the first few images:
|
||||
|
||||
```
|
||||
fer0000000.png
|
||||
fer0000001.png
|
||||
fer0000002.png
|
||||
fer0000003.png
|
||||
```
|
||||
The folders don't contain the actual images, you will need to download them from https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge/data, then extract the images from the FER csv file in such a way, that all images corresponding to "Training" go to FER2013Train folder, all images corresponding to "PublicTest" go to FER2013Valid folder and all images corresponding to "PrivateTest" go to FER2013Test folder.
|
||||
|
||||
# Citation
|
||||
If you use the new FER+ label or the sample code or part of it in your research, please cite the following:
|
||||
|
||||
|
|
|
@ -214,7 +214,7 @@ class FERPlusReader(object):
|
|||
emotion[np.argmax(emotion_raw)] = maxval
|
||||
else:
|
||||
emotion = emotion_unknown # force setting as unknown
|
||||
elif (mode == 'probability') or (mode == 'crossentropy'):
|
||||
elif (mode == 'probability') or (mode == 'crossentropy'):
|
||||
sum_part = 0
|
||||
count = 0
|
||||
valid_emotion = True
|
||||
|
|
15
src/train.py
15
src/train.py
|
@ -40,12 +40,11 @@ def cost_func(training_mode, prediction, target):
|
|||
multiple labels exactly the same.
|
||||
'''
|
||||
train_loss = None
|
||||
|
||||
if training_mode == 'majority' or training_mode == 'probability' or training_mode == 'crossentropy':
|
||||
# Cross Entropy.
|
||||
train_loss = ct.reduce_sum(ct.minus(ct.reduce_log_sum(prediction, axis=0), ct.reduce_sum(ct.element_times(target, prediction), axis=0)))
|
||||
train_loss = ct.negate(ct.reduce_sum(ct.element_times(target, ct.log(prediction)), axis=-1))
|
||||
elif training_mode == 'multi_target':
|
||||
train_loss = ct.negate(ct.log(ct.reduce_max(ct.element_times(target, prediction), axis=0)))
|
||||
train_loss = ct.negate(ct.log(ct.reduce_max(ct.element_times(target, prediction), axis=-1)))
|
||||
|
||||
return train_loss
|
||||
|
||||
|
@ -91,15 +90,17 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
|
|||
minibatch_size = 32
|
||||
|
||||
# Training config
|
||||
lr_schedule = model.learning_rate
|
||||
lr_schedule = [model.learning_rate]*20 + [model.learning_rate / 2.0]*20 + [model.learning_rate / 10.0]
|
||||
lr_per_minibatch = learning_rate_schedule(lr_schedule, epoch_size, UnitType.minibatch)
|
||||
|
||||
momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9), epoch_size)
|
||||
|
||||
# loss and error cost
|
||||
train_loss = cost_func(training_mode, pred, label_var)
|
||||
pe = classification_error(pred, label_var)
|
||||
pe = classification_error(z, label_var)
|
||||
|
||||
# construct the trainer
|
||||
learner = sgd(pred.parameters, lr = lr_per_minibatch)
|
||||
learner = momentum_sgd(z.parameters, lr = lr_per_minibatch, momentum = momentum_time_constant)
|
||||
#learner = sgd(pred.parameters, lr = lr_per_minibatch)
|
||||
trainer = Trainer(z, train_loss, pe, learner)
|
||||
|
||||
# Get minibatches of images to train with and perform model training
|
||||
|
|
Загрузка…
Ссылка в новой задаче