Update code to support predicting on custom data.
Parent: 5ddfd6542b
Commit: 49293a895a
README.md

@@ -22,7 +22,8 @@ If you find our code useful, please consider citing our paper:
 - [Install Requirements](#requirements)
 - [Prepare Dataset](#data)
 - [Train Model](#train)
 - [Predict Model](#predict)
 - [Evaluate Model](#evaluate)
+- [Predict on Custom Data](#predict-on-custom-data)
 - [Demo on Web](#demo)
 - [Pretrained Weights](#experiment)
 - [Fine-grained Analysis](#analysis)
@@ -160,9 +161,9 @@ allennlp train -s %model_file% %config_file% ^
 -o {"""model.serialization_dir""":"""%model_file%""","""random_seed""":"""%seed%""","""numpy_seed""":"""%seed%""","""pytorch_seed""":"""%seed%""","""dataset_reader.tables_file""":"""%tables_file%""","""dataset_reader.database_path""":"""%database_path%""","""train_data_path""":"""%train_data_path%""","""validation_data_path""":"""%validation_data_path%""","""model.text_embedder.tokens.pretrained_file""":"""%pretrained_file%""","""model.dataset_path""":"""%dataset_path%"""}
 ```
 
-## Predict
+## Evaluate
 
-You could predict SQLs using a trained model checkpoint file (e.g. `checkpoints_sparc/sparc_concat_model/model.tar.gz`) with the following command:
+You could predict and evaluate SQLs using a trained model checkpoint file (e.g. `checkpoints_sparc/sparc_concat_model/model.tar.gz`) with the following command:
 
 - Under Linux
 ```bash
@@ -201,6 +202,10 @@ allennlp predict ^
 %model_file%/model.tar.gz %validation_out_file%
 ```
 
+## Predict On Custom Data
+
+Our code also supports predicting SQLs on custom data through a function call, which you could find in `predict.py`. Before running it, you should first store your own database in the paths corresponding to the arguments of `PredictManager`, in the same format as SParC/CoSQL.
+
 ## Demo
 
 You could also host a demo page using a well-trained archived model (e.g. `checkpoints_sparc/sparc_concat_model/model.tar.gz`) with the following command:
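
As a rough illustration of the data layout the "Predict On Custom Data" section above expects, the sketch below registers a hypothetical database `my_db`. It assumes the standard SParC/Spider convention: one schema entry per database in `tables.json`, and the SQLite file at `<database_path>/<db_id>/<db_id>.sqlite`. The table and column names here are made up.

```python
# Sketch: register a custom database for PredictManager, assuming the
# standard SParC/Spider layout. `my_db` and its tables are hypothetical.
import json
import os
import sqlite3

db_id = "my_db"

# One schema entry per database; tables.json holds a list of these.
# In practice you would append to the existing list rather than overwrite it.
schema = {
    "db_id": db_id,
    "table_names_original": ["singer"],
    "table_names": ["singer"],
    # each entry is [table_index, column_name]; [-1, "*"] comes first
    "column_names_original": [[-1, "*"], [0, "id"], [0, "name"]],
    "column_names": [[-1, "*"], [0, "id"], [0, "name"]],
    "column_types": ["text", "number", "text"],
    "primary_keys": [1],
    "foreign_keys": [],
}

with open("dataset_sparc/tables.json", "w") as f:
    json.dump([schema], f, indent=2)

# The SQLite file is expected at <database_path>/<db_id>/<db_id>.sqlite.
os.makedirs(f"dataset_sparc/database/{db_id}", exist_ok=True)
with sqlite3.connect(f"dataset_sparc/database/{db_id}/{db_id}.sqlite") as conn:
    conn.execute("CREATE TABLE singer (id INTEGER PRIMARY KEY, name TEXT)")
```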
predict.py

@@ -0,0 +1,34 @@
+from allennlp.models.archival import load_archive
+from allennlp.predictors.predictor import Predictor
+# WARNING: do not remove these imports; they register the predictor, dataset reader and model with AllenNLP
+from predictor.sparc_predictor import SparcPredictor
+from dataset_reader.sparc_reader import SparcDatasetReader
+from models.sparc_parser import SparcParser
+
+
+class PredictManager:
+
+    def __init__(self, archive_file, tables_file, database_path):
+        overrides = "{\"dataset_reader.tables_file\":\"" + tables_file + "\",\"dataset_reader.database_path\":" + \
+                    "\"" + database_path + "\"}"
+        archive = load_archive(archive_file,
+                               overrides=overrides)
+        self.predictor = Predictor.from_archive(
+            archive, predictor_name="sparc")
+
+    def predict_result(self, ques_inter: str, ques_database: str):
+        param = {
+            "database_id": ques_database,
+            "question": ques_inter
+        }
+        restate = self.predictor.predict_json(param)["best_predict_sql"]
+        return restate
+
+
+if __name__ == '__main__':
+    manager = PredictManager(archive_file="model.tar.gz",
+                             tables_file="dataset_sparc/tables.json",
+                             database_path="dataset_sparc/database")
+    # turns in the input dialogue are separated by `;`; the second argument is the database_id
+    result = manager.predict_result("What are all the airlines;Of these, which is Jetblue Airways", "flight_2")
+    print(result)
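
For multi-turn use, a minimal sketch (assuming the same archive and dataset paths as the `__main__` block above) that grows the `;`-separated dialogue one question at a time:

```python
# Sketch: query the model turn by turn. PredictManager expects all previous
# questions of the interaction joined by `;`, as in the example above.
from predict import PredictManager

manager = PredictManager(archive_file="model.tar.gz",
                         tables_file="dataset_sparc/tables.json",
                         database_path="dataset_sparc/database")

turns = []
for question in ["What are all the airlines",
                 "Of these, which is Jetblue Airways"]:
    turns.append(question)
    sql = manager.predict_result(";".join(turns), "flight_2")
    print(f"{question} -> {sql}")
```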