add interactive prediction and register block (#70)
This commit is contained in:
Родитель
47008bb4d6
Коммит
676c8283bd
|
@ -42,6 +42,7 @@ class LearningMachine(object):
|
|||
device = 'GPU' if 'cuda' in emb_weight_device else 'CPU'
|
||||
logging.info(
|
||||
"The embedding matrix is on %s now, you can modify the weight_on_gpu parameter to change embeddings weight device." % device)
|
||||
logging.info("="*100 + '\n' + "*"*15 + "Model Achitecture" + "*"*15)
|
||||
logging.info(self.model)
|
||||
#logging.info("Total parameters: %d; trainable parameters: %d" % (get_param_num(self.model), get_trainable_param_num(self.model)))
|
||||
logging.info("Total trainable parameters: %d" % (get_trainable_param_num(self.model)))
|
||||
|
@ -91,6 +92,7 @@ class LearningMachine(object):
|
|||
|
||||
def train(self, optimizer, loss_fn):
|
||||
self.model.train()
|
||||
logging.info("="*100 + '\n' + "*"*15 + 'Prepare data for training' + "*"*15)
|
||||
if not self.conf.train_data_path.endswith('.pkl'):
|
||||
train_data, train_length, train_target = self.problem.encode(self.conf.train_data_path, self.conf.file_columns,
|
||||
self.conf.input_types, self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
|
||||
|
@ -134,6 +136,7 @@ class LearningMachine(object):
|
|||
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
|
||||
streaming_recoder = StreamingRecorder(['prediction', 'answer_text'])
|
||||
|
||||
logging.info("=" * 100 + '\n' + "*" * 15 + 'Start training' + "*" * 15)
|
||||
while not stop_training and epoch <= self.conf.max_epoch:
|
||||
logging.info('Training: Epoch ' + str(epoch))
|
||||
|
||||
|
@ -789,6 +792,107 @@ class LearningMachine(object):
|
|||
|
||||
fin.close()
|
||||
|
||||
def interactive(self, sample, file_columns, predict_fields=['prediction'], predict_mode='batch'):
|
||||
""" interactive prediction
|
||||
|
||||
Args:
|
||||
file_columns: representation the columns of sample
|
||||
predict_mode: interactive|batch(need a predict file)
|
||||
"""
|
||||
predict_data, predict_length, _, _, _ = \
|
||||
self.problem.encode_data_list(sample, file_columns, self.conf.input_types, self.conf.object_inputs, None,
|
||||
self.conf.min_sentence_len, self.conf.extra_feature, self.conf.max_lengths,
|
||||
self.conf.fixed_lengths, predict_mode=predict_mode)
|
||||
if predict_data is None:
|
||||
return 'Wrong Case!'
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
data_batches, length_batches, _ = \
|
||||
get_batches(self.problem, predict_data, predict_length, None, 1,
|
||||
self.conf.input_types, None, permutate=False, transform_tensor=True, predict_mode=predict_mode)
|
||||
streaming_recoder = StreamingRecorder(predict_fields)
|
||||
|
||||
key_random = random.choice(
|
||||
list(length_batches[0].keys()).remove('target') if 'target' in list(length_batches[0].keys()) else
|
||||
list(length_batches[0].keys()))
|
||||
param_list, inputs_desc, length_desc = transform_params2tensors(data_batches[0], length_batches[0])
|
||||
logits = self.model(inputs_desc, length_desc, *param_list)
|
||||
|
||||
logits_softmax = {}
|
||||
if isinstance(self.model, nn.DataParallel):
|
||||
for tmp_output_layer_id in self.model.module.output_layer_id:
|
||||
if isinstance(self.model.module.layers[tmp_output_layer_id], Linear) and \
|
||||
(not self.model.module.layers[tmp_output_layer_id].layer_conf.last_hidden_softmax):
|
||||
logits_softmax[tmp_output_layer_id] = nn.functional.softmax(
|
||||
logits[tmp_output_layer_id], dim=-1)
|
||||
else:
|
||||
logits_softmax[tmp_output_layer_id] = logits[tmp_output_layer_id]
|
||||
else:
|
||||
for tmp_output_layer_id in self.model.output_layer_id:
|
||||
if isinstance(self.model.layers[tmp_output_layer_id], Linear) and \
|
||||
(not self.model.layers[tmp_output_layer_id].layer_conf.last_hidden_softmax):
|
||||
logits_softmax[tmp_output_layer_id] = nn.functional.softmax(
|
||||
logits[tmp_output_layer_id], dim=-1)
|
||||
else:
|
||||
logits_softmax[tmp_output_layer_id] = logits[tmp_output_layer_id]
|
||||
|
||||
if ProblemTypes[self.problem.problem_type] == ProblemTypes.sequence_tagging:
|
||||
logits = list(logits.values())[0]
|
||||
if isinstance(get_layer_class(self.model, tmp_output_layer_id), CRF):
|
||||
forward_score, scores, masks, tag_seq, transitions, layer_conf = logits
|
||||
prediction_indices = tag_seq.cpu().numpy()
|
||||
else:
|
||||
logits_softmax = list(logits_softmax.values())[0]
|
||||
# Transform output shapes for metric evaluation
|
||||
# for seq_tag_f1 metric
|
||||
prediction_indices = logits_softmax.data.max(2)[1].cpu().numpy() # [batch_size, seq_len]
|
||||
prediction_batch = self.problem.decode(prediction_indices, length_batches[0][key_random].numpy())
|
||||
for prediction_sample in prediction_batch:
|
||||
streaming_recoder.record('prediction', " ".join(prediction_sample))
|
||||
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.classification:
|
||||
logits = list(logits.values())[0]
|
||||
logits_softmax = list(logits_softmax.values())[0]
|
||||
prediction_indices = logits_softmax.data.max(1)[1].cpu().numpy()
|
||||
|
||||
for field in predict_fields:
|
||||
if field == 'prediction':
|
||||
streaming_recoder.record(field,
|
||||
self.problem.decode(prediction_indices,
|
||||
length_batches[0][key_random].numpy()))
|
||||
elif field == 'confidence':
|
||||
prediction_scores = logits_softmax.cpu().data.numpy()
|
||||
for prediction_score, prediction_idx in zip(prediction_scores, prediction_indices):
|
||||
streaming_recoder.record(field, prediction_score[prediction_idx])
|
||||
elif field.startswith('confidence') and field.find('@') != -1:
|
||||
label_specified = field.split('@')[1]
|
||||
label_specified_idx = self.problem.output_dict.id(label_specified)
|
||||
confidence_specified = torch.index_select(logits_softmax.cpu(), 1, torch.tensor([label_specified_idx], dtype=torch.long)).squeeze(1)
|
||||
streaming_recoder.record(field, confidence_specified.data.numpy())
|
||||
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.regression:
|
||||
logits = list(logits.values())[0]
|
||||
# logits_softmax is unuseful for regression task!
|
||||
logits_softmax = list(logits_softmax.values())[0]
|
||||
logits_flat = logits.squeeze(1)
|
||||
prediction_scores = logits_flat.detach().cpu().numpy()
|
||||
streaming_recoder.record_one_row([prediction_scores])
|
||||
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
|
||||
for key, value in logits.items():
|
||||
logits[key] = value.squeeze()
|
||||
for key, value in logits_softmax.items():
|
||||
logits_softmax[key] = value.squeeze()
|
||||
passage_identify = None
|
||||
for type_key in data_batches[0].keys():
|
||||
if 'p' in type_key.lower():
|
||||
passage_identify = type_key
|
||||
break
|
||||
if not passage_identify:
|
||||
raise Exception('MRC task need passage information.')
|
||||
prediction = self.problem.decode(logits_softmax, lengths=length_batches[0][passage_identify],
|
||||
batch_data=data_batches[0][passage_identify])
|
||||
streaming_recoder.record_one_row([prediction])
|
||||
|
||||
return "\t".join([str(streaming_recoder.get(field)[0]) for field in predict_fields])
|
||||
|
||||
def load_model(self, model_path):
|
||||
if self.use_gpu is True:
|
||||
self.model = torch.load(model_path)
|
||||
|
|
|
@ -406,7 +406,8 @@ class ModelConf(object):
|
|||
"The configuration file %s is illegal. There should be an item configuration[%s], "
|
||||
"but the item %s is not found." % (self.conf_path, "][".join(error_keys), key))
|
||||
else:
|
||||
print("configuration[%s] is not found in %s, use default value %s" % ("][".join(error_keys), self.conf_path, repr(default)))
|
||||
# print("configuration[%s] is not found in %s, use default value %s" %
|
||||
# ("][".join(error_keys), self.conf_path, repr(default)))
|
||||
item = default
|
||||
|
||||
return item
|
||||
|
|
|
@ -458,7 +458,7 @@ This task is to train a query regression model to learn from a heavy teacher mod
|
|||
3. Calculate AUC metric
|
||||
```bash
|
||||
cd PROJECT_ROOT
|
||||
python tools/calculate_AUC.py --input_file models/kdqbc_bilstmattn_cnn/train/predict.tsv --predict_index 2 --label_index 1
|
||||
python tools/calculate_auc.py --input_file models/kdqbc_bilstmattn_cnn/train/predict.tsv --predict_index 2 --label_index 1
|
||||
```
|
||||
|
||||
*Tips: you can try different models by running different JSON config files.*
|
||||
|
@ -502,7 +502,7 @@ This task is to train a query-passage regression model to learn from a heavy tea
|
|||
3. Calculate AUC metric
|
||||
```bash
|
||||
cd PROJECT_ROOT
|
||||
python tools/calculate_AUC.py --input_file=models/kdtm_match_linearAttn/predict.tsv --predict_index=3 --label_index=2
|
||||
python tools/calculate_auc.py --input_file=models/kdtm_match_linearAttn/predict.tsv --predict_index=3 --label_index=2
|
||||
```
|
||||
|
||||
*Tips: you can try different models by running different JSON config files.*
|
||||
|
@ -574,7 +574,7 @@ Sequence Labeling is an important NLP task, which includes NER, Slot Tagging, Po
|
|||
|
||||
- NeuronBlocks support both BIO and BIOES tag schemes.
|
||||
- The IOB scheme is not supported, because of its worse performance in most [experiment](https://arxiv.org/pdf/1707.06799.pdf).
|
||||
- NeuronBlocks provides a [script](./tools/taggingSchemes_Converter.py) that converts the tag scheme among IOB/BIO/BIOES (NOTE: the script only supports tsv file which has data and label in two columns).
|
||||
- NeuronBlocks provides a [script](tools/tagging_schemes_converter.py) that converts the tag scheme among IOB/BIO/BIOES (NOTE: the script only supports tsv file which has data and label in two columns).
|
||||
|
||||
- ***Usages***
|
||||
|
||||
|
|
|
@ -447,7 +447,7 @@ This task is to train a query regression model to learn from a heavy teacher mod
|
|||
3. Calculate AUC metric
|
||||
```bash
|
||||
cd PROJECT_ROOT
|
||||
python tools/calculate_AUC.py --input_file models/kdqbc_bilstmattn_cnn/train/predict.tsv --predict_index 2 --label_index 1
|
||||
python tools/calculate_auc.py --input_file models/kdqbc_bilstmattn_cnn/train/predict.tsv --predict_index 2 --label_index 1
|
||||
```
|
||||
|
||||
*Tips: you can try different models by running different JSON config files.*
|
||||
|
@ -491,7 +491,7 @@ This task is to train a query-passage regression model to learn from a heavy tea
|
|||
3. Calculate AUC metric
|
||||
```bash
|
||||
cd PROJECT_ROOT
|
||||
python tools/calculate_AUC.py --input_file=models/kdtm_match_linearAttn/predict.tsv --predict_index=3 --label_index=2
|
||||
python tools/calculate_auc.py --input_file=models/kdtm_match_linearAttn/predict.tsv --predict_index=3 --label_index=2
|
||||
```
|
||||
|
||||
*Tips: you can try different models by running different JSON config files.*
|
||||
|
@ -564,7 +564,7 @@ This task is to train a query-passage regression model to learn from a heavy tea
|
|||
|
||||
- NeuronBlocks 支持 BIO 和 BIOES 标注策略。
|
||||
- IOB 标注标注是不被支持的,因为在大多[实验](https://arxiv.org/pdf/1707.06799.pdf)中它具有很差的表现。
|
||||
- NeuronBlocks 提供一个在不同标注策略(IOB/BIO/BIOES)中的[转化脚本](./tools/taggingSchemes_Converter.py)(脚本仅支持具有 数据和标签 的两列tsv文件输入)。
|
||||
- NeuronBlocks 提供一个在不同标注策略(IOB/BIO/BIOES)中的[转化脚本](tools/tagging_schemes_converter.py)(脚本仅支持具有 数据和标签 的两列tsv文件输入)。
|
||||
|
||||
- ***用法***
|
||||
|
||||
|
|
|
@ -7,7 +7,9 @@ ZIPTOOL="unzip"
|
|||
|
||||
# GloVe
|
||||
echo $glovepath
|
||||
mkdir GloVe
|
||||
if [ ! -d "/GloVe/"];then
|
||||
mkdir GloVe
|
||||
fi
|
||||
curl -LO $glovepath
|
||||
$ZIPTOOL glove.840B.300d.zip -d GloVe/
|
||||
rm glove.840B.300d.zip
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
preprocess_exec="sed -f tokenizer.sed"
|
||||
|
||||
glovepath='http://nlp.stanford.edu/data/glove.6B.zip'
|
||||
|
||||
ZIPTOOL="unzip"
|
||||
|
||||
# GloVe
|
||||
echo $glovepath
|
||||
if [ ! -d "/GloVe/"];then
|
||||
mkdir GloVe
|
||||
fi
|
||||
curl -LO $glovepath
|
||||
$ZIPTOOL glove.6B.zip -d GloVe/
|
||||
rm glove.6B.zip
|
||||
|
|
@ -14,11 +14,11 @@
|
|||
},
|
||||
"add_start_end_for_seq": false,
|
||||
"file_header": {
|
||||
"word": 0,
|
||||
"sequence": 0,
|
||||
"tag": 1
|
||||
},
|
||||
"model_inputs": {
|
||||
"words": ["word"]
|
||||
"words": ["sequence"]
|
||||
},
|
||||
"target": ["tag"]
|
||||
},
|
||||
|
@ -50,7 +50,7 @@
|
|||
"use_gpu": true,
|
||||
"batch_size": 10,
|
||||
"batch_num_to_show_results": 500,
|
||||
"max_epoch": 50,
|
||||
"max_epoch": 2,
|
||||
"valid_times_per_epoch": 1
|
||||
},
|
||||
"architecture":[
|
||||
|
@ -59,7 +59,7 @@
|
|||
"weight_on_gpu": true,
|
||||
"conf": {
|
||||
"word": {
|
||||
"cols": ["word"],
|
||||
"cols": ["sequence"],
|
||||
"dim": 100
|
||||
}
|
||||
}
|
||||
|
|
39
predict.py
39
predict.py
|
@ -34,13 +34,46 @@ def main(params):
|
|||
lm = LearningMachine('predict', conf, problem, vocab_info=None, initialize=False, use_gpu=conf.use_gpu)
|
||||
lm.load_model(conf.previous_model_path)
|
||||
|
||||
logging.info('Predicting %s with the model saved at %s' % (conf.predict_data_path, conf.previous_model_path))
|
||||
lm.predict(conf.predict_data_path, conf.predict_output_path, conf.predict_file_columns, conf.predict_fields)
|
||||
logging.info("Predict done! The predict result: %s" % conf.predict_output_path)
|
||||
if params.predict_mode == 'batch':
|
||||
logging.info('Predicting %s with the model saved at %s' % (conf.predict_data_path, conf.previous_model_path))
|
||||
if params.predict_mode == 'batch':
|
||||
lm.predict(conf.predict_data_path, conf.predict_output_path, conf.predict_file_columns, conf.predict_fields)
|
||||
logging.info("Predict done! The predict result: %s" % conf.predict_output_path)
|
||||
elif params.predict_mode == 'interactive':
|
||||
print('='*80)
|
||||
task_type = str(ProblemTypes[problem.problem_type]).split('.')[1]
|
||||
sample_format = list(conf.predict_file_columns.keys())
|
||||
target_ = conf.conf['inputs'].get('target', None)
|
||||
target_list = list(target_) if target_ else []
|
||||
for single_element in sample_format[:]:
|
||||
if single_element in target_list:
|
||||
sample_format.remove(single_element)
|
||||
predict_file_columns = {}
|
||||
for index, single in enumerate(sample_format):
|
||||
predict_file_columns[single] = index
|
||||
print('Enabling Interactive Inference Mode for %s Task...' % (task_type.upper()))
|
||||
print('%s Task Interactive. The sample format is <%s>' % (task_type.upper(), ', '.join(sample_format)))
|
||||
case_cnt = 1
|
||||
while True:
|
||||
print('Case%d:' % case_cnt)
|
||||
sample = []
|
||||
for single in sample_format:
|
||||
temp_ = input('\t%s: ' % single)
|
||||
if temp_.lower() == 'exit':
|
||||
exit(0)
|
||||
sample.append(temp_)
|
||||
sample = '\t'.join(sample)
|
||||
result = lm.interactive([sample], predict_file_columns, conf.predict_fields, params.predict_mode)
|
||||
print('\tInference result: %s' % result)
|
||||
case_cnt += 1
|
||||
else:
|
||||
raise Exception('Predict mode support interactive|batch, get %s' % params.predict_mode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Prediction')
|
||||
parser.add_argument("--conf_path", type=str, help="configuration path")
|
||||
parser.add_argument("--predict_mode", type=str, default='batch', help='interactive|batch')
|
||||
parser.add_argument("--predict_data_path", type=str, help='specify another predict data path, instead of the one defined in configuration file')
|
||||
parser.add_argument("--previous_model_path", type=str, help='load model trained previously.')
|
||||
parser.add_argument("--predict_output_path", type=str, help='specify another prediction output path, instead of conf[outputs][save_base_dir] + conf[outputs][predict_output_name] defined in configuration file')
|
||||
|
|
15
problem.py
15
problem.py
|
@ -425,7 +425,7 @@ class Problem():
|
|||
yield output_data, lengths, target, cnt_legal, cnt_illegal
|
||||
|
||||
def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
|
||||
extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
|
||||
extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None, predict_mode='batch'):
|
||||
data = dict()
|
||||
lengths = dict()
|
||||
char_emb = True if 'char' in [single_input_type.lower() for single_input_type in input_types] else False
|
||||
|
@ -483,11 +483,14 @@ class Problem():
|
|||
line_split = line.rstrip().split('\t')
|
||||
cnt_all += 1
|
||||
if len(line_split) != len(file_columns):
|
||||
# logging.warning("Current line is inconsistent with configuration/inputs/file_header. Ingore now. %s" % line)
|
||||
cnt_illegal += 1
|
||||
if cnt_illegal / cnt_all > 0.33:
|
||||
raise PreprocessError('The illegal data is too much. Please check the number of data columns or text token version.')
|
||||
continue
|
||||
if predict_mode == 'batch':
|
||||
cnt_illegal += 1
|
||||
if cnt_illegal / cnt_all > 0.33:
|
||||
raise PreprocessError('The illegal data is too much. Please check the number of data columns or text token version.')
|
||||
continue
|
||||
else:
|
||||
print('\tThe case is illegal! Please check your case and input again!')
|
||||
return [None]*5
|
||||
# cnt_legal += 1
|
||||
length_appended_set = set() # to store branches whose length have been appended to lengths[branch]
|
||||
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
def get_block_path(block_name, path='./block_zoo'):
|
||||
''' find the block_name.py file in block_zoo
|
||||
Args:
|
||||
block_name: the name need to be registered. eg. BiLSTM/ CRF
|
||||
'''
|
||||
get_dir = os.listdir(path)
|
||||
for single in get_dir:
|
||||
sub_dir = os.path.join(path, single)
|
||||
if os.path.isdir(sub_dir):
|
||||
result = get_block_path(block_name, path=sub_dir)
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
if block_name + '.py' == single:
|
||||
return sub_dir
|
||||
return None
|
||||
|
||||
|
||||
def write_file(new_block_path, file_path):
|
||||
init_path = os.path.join(file_path, '__init__.py')
|
||||
diff = new_block_path[len(file_path):].split('/')
|
||||
if diff[0] == '':
|
||||
diff.pop(0)
|
||||
# delete '.py' in the last str
|
||||
diff[-1] = diff[-1][:-3]
|
||||
line = 'from .' + diff[0] + ' import ' + diff[-1] + ', ' + diff[-1] + 'Conf'
|
||||
with open(init_path, 'a', encoding='utf-8') as fin:
|
||||
fin.write('\n' + line + '\n')
|
||||
|
||||
|
||||
def register(block_name, new_block_path):
|
||||
''' Add import code in the corresponding file. eg. block_zoo/__init__.py or block_zoo/subdir/__init__.py
|
||||
|
||||
'''
|
||||
# check if block exist or not
|
||||
if new_block_path:
|
||||
block_path_split = new_block_path.split('/')
|
||||
for i in range(len(block_path_split)-1, 1, -1):
|
||||
# need_add_file.append(os.path.join('/'.join(block_path_split[:i])))
|
||||
write_file(new_block_path, os.path.join('/'.join(block_path_split[:i])))
|
||||
print('The block %s is registered successfully.' % block_name)
|
||||
else:
|
||||
raise Exception('The %s.py file does not exist! Please check your program or file name.' % block_name)
|
||||
|
||||
|
||||
def main(params):
|
||||
new_block_path = get_block_path(params.block_name)
|
||||
register(params.block_name, new_block_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parse = argparse.ArgumentParser(description='Register Block')
|
||||
parse.add_argument("--block_name", type=str, help="block name want to be registered")
|
||||
params, _ = parse.parse_known_args()
|
||||
assert params.block_name, 'Please specify a block_name via --block_name'
|
||||
main(params)
|
1
train.py
1
train.py
|
@ -171,6 +171,7 @@ def main(params):
|
|||
# data preprocessing
|
||||
## build dictionary when (not in finetune model) and (not use cache or cache invalid)
|
||||
if (not conf.pretrained_model_path) and ((conf.use_cache == False) or cache.dictionary_invalid):
|
||||
logging.info("="*100)
|
||||
logging.info("Preprocessing... Depending on your corpus size, this step may take a while.")
|
||||
# modify train_data_path to [train_data_path, valid_data_path, test_data_path]
|
||||
# remember the test_data may be None
|
||||
|
|
|
@ -173,7 +173,7 @@ def corpus_permutation(*corpora):
|
|||
return corpora_perm
|
||||
|
||||
|
||||
def get_batches(problem, data, length, target, batch_size, input_types, pad_ids=None, permutate=False, transform_tensor=True):
|
||||
def get_batches(problem, data, length, target, batch_size, input_types, pad_ids=None, permutate=False, transform_tensor=True, predict_mode='batch'):
|
||||
"""
|
||||
|
||||
Args:
|
||||
|
@ -232,7 +232,8 @@ def get_batches(problem, data, length, target, batch_size, input_types, pad_ids=
|
|||
target_batches: ndarray/Variable shape: [number of batches, batch_size, targets]
|
||||
|
||||
"""
|
||||
logging.info("Start making batches")
|
||||
if predict_mode == 'batch':
|
||||
logging.info("Start making batches")
|
||||
if permutate is True:
|
||||
#CAUTION! data and length would be revised
|
||||
data = copy.deepcopy(data)
|
||||
|
@ -392,7 +393,8 @@ def get_batches(problem, data, length, target, batch_size, input_types, pad_ids=
|
|||
|
||||
target_batches.append(target_batch)
|
||||
|
||||
logging.info("Batches got!")
|
||||
if predict_mode == 'batch':
|
||||
logging.info("Batches got!")
|
||||
return data_batches, length_batches, target_batches
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче