From 72df5c876f368ae4a1b594e7a740ff966dbbd3ba Mon Sep 17 00:00:00 2001 From: JiaqiGuo Date: Sat, 7 Dec 2019 11:24:42 +0900 Subject: [PATCH] Fix bug in measurement --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 007f5be..d50388c 100644 --- a/train.py +++ b/train.py @@ -103,11 +103,11 @@ def train(args): print(tb) else: utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model')) - json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, + json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, beam_size=args.beam_size) - acc = utils.eval_acc(json_datas, val_sql_data) + # acc = utils.eval_acc(json_datas, val_sql_data) - print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (acc, acc, acc,)) + print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (sketch_acc, acc, acc,)) if __name__ == '__main__':