Update code for supporting predicting on custom data.

SivilTaram 2020-11-22 00:25:42 +08:00
Parent 5ddfd6542b
Commit 49293a895a
2 changed files with 42 additions and 3 deletions


@@ -22,7 +22,8 @@ If you find our code useful, please consider citing our paper:
- [Install Requirements](#requirements)
- [Prepare Dataset](#data)
- [Train Model](#train)
- [Predict Model](#predict)
- [Evaluate Model](#evaluate)
- [Predict on Custom Data](#predict-on-custom-data)
- [Demo on Web](#demo)
- [Pretrained Weights](#experiment)
- [Fine-grained Analysis](#analysis)
@@ -160,9 +161,9 @@ allennlp train -s %model_file% %config_file% ^
-o {"""model.serialization_dir""":"""%model_file%""","""random_seed""":"""%seed%""","""numpy_seed""":"""%seed%""","""pytorch_seed""":"""%seed%""","""dataset_reader.tables_file""":"""%tables_file%""","""dataset_reader.database_path""":"""%database_path%""","""train_data_path""":"""%train_data_path%""","""validation_data_path""":"""%validation_data_path%""","""model.text_embedder.tokens.pretrained_file""":"""%pretrained_file%""","""model.dataset_path""":"""%dataset_path%"""}
```
## Predict
## Evaluate
You could predict SQLs using a trained model checkpoint file (e.g. `checkpoints_sparc/sparc_concat_model/model.tar.gz`) with the following command:
You could predict and evaluate SQLs using a trained model checkpoint file (e.g. `checkpoints_sparc/sparc_concat_model/model.tar.gz`) with the following command:
- Under Linux
```bash
@@ -201,6 +202,10 @@ allennlp predict ^
%model_file%/model.tar.gz %validation_out_file
```
## Predict on Custom Data
Our code also supports a function call to predict SQLs on custom data; you can find it in `predict.py`. Before running it, you should first store your own database under the paths corresponding to the arguments of `PredictManager`, in the same format as SParC/CoSQL.
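As a minimal sketch of such a function call (not part of this commit): `PredictManager.predict_result` in `predict.py` takes the whole dialogue as a single string whose turns are separated by `;`, plus a database id. The helpers `join_turns` and `predict_dialogue` below are hypothetical names introduced only for illustration. They assume nothing more than that `manager` exposes `predict_result(ques_inter, ques_database)` as defined in `predict.py`.

```python
from typing import Sequence


def join_turns(turns: Sequence[str]) -> str:
    """Join dialogue turns into the `;`-separated string that predict_result takes."""
    cleaned = [turn.strip() for turn in turns]
    if not cleaned or any(not turn for turn in cleaned):
        raise ValueError("every turn must be a non-empty string")
    if any(";" in turn for turn in cleaned):
        # `;` is the turn separator, so it cannot appear inside a turn.
        raise ValueError("a turn must not contain ';'")
    return ";".join(cleaned)


def predict_dialogue(manager, turns: Sequence[str], database_id: str):
    """Hypothetical wrapper; `manager` is assumed to be a PredictManager from predict.py."""
    return manager.predict_result(join_turns(turns), database_id)


# Usage (requires the trained archive and dataset files, so it is left commented out):
# manager = PredictManager(archive_file="model.tar.gz",
#                          tables_file="dataset_sparc/tables.json",
#                          database_path="dataset_sparc/database")
# predict_dialogue(manager,
#                  ["What are all the airlines", "Of these, which is Jetblue Airways"],
#                  "flight_2")
```

Keeping the turns in a list and joining them in one place avoids separator mistakes when the questions come from user input.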
## Demo
You could also host a demo page from a well-trained archived model (e.g. `checkpoints_sparc/sparc_concat_model/model.tar.gz`) using the following command:


@@ -0,0 +1,34 @@
import json

from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor

# WARNING: Do not remove these imports, even though they look unused
from predictor.sparc_predictor import SparcPredictor
from dataset_reader.sparc_reader import SparcDatasetReader
from models.sparc_parser import SparcParser


class PredictManager:

    def __init__(self, archive_file, tables_file, database_path):
        # Serialize the overrides with json.dumps so that quotes and
        # backslashes in the paths are escaped correctly.
        overrides = json.dumps({
            "dataset_reader.tables_file": tables_file,
            "dataset_reader.database_path": database_path
        })
        archive = load_archive(archive_file,
                               overrides=overrides)
        self.predictor = Predictor.from_archive(
            archive, predictor_name="sparc")

    def predict_result(self, ques_inter: str, ques_database: str):
        param = {
            "database_id": ques_database,
            "question": ques_inter
        }
        restate = self.predictor.predict_json(param)["best_predict_sql"]
        return restate


if __name__ == '__main__':
    manager = PredictManager(archive_file="model.tar.gz",
                             tables_file="dataset_sparc/tables.json",
                             database_path="dataset_sparc/database")
    # The turns of the input dialogue are separated by `;`,
    # and the second argument is the database_id.
    result = manager.predict_result("What are all the airlines;Of these, which is Jetblue Airways", "flight_2")
    print(result)