Adding summarization, minor fixes
This commit is contained in:
Родитель
dbd824621f
Коммит
00825c3109
|
@ -42,7 +42,7 @@ aml_run = he.get_context()
|
|||
|
||||
def doc_classification(task, model_type, n_epochs, batch_size, embeds_dropout, evaluate_every,
|
||||
use_cuda, max_seq_len, learning_rate, do_lower_case,
|
||||
register_model, save_model=True, early_stopping=True):
|
||||
register_model, save_model=True, early_stopping=False):
|
||||
|
||||
language = cu.params.get('language')
|
||||
|
||||
|
|
|
@ -100,7 +100,8 @@ class Data():
|
|||
_dir = f"{_base}/{self.model_dir}/"
|
||||
if os.path.isdir(_dir):
|
||||
### Deployed with multiple model objects in AML
|
||||
self.root_dir = f'{_base}/{os.listdir(_dir)[0]}/'
|
||||
self.root_dir = f'{_base}/'
|
||||
self.model_dir = f'{self.model_dir}/{os.listdir(_dir)[0]}/{self.model_dir}/'
|
||||
# NOTE: gets the latest version of the model
|
||||
else:
|
||||
### Deployed with single model objects in AML
|
||||
|
|
|
@ -48,7 +48,7 @@ class Rank():
|
|||
# Filter by classified label
|
||||
if cats is not None and cats != '':
|
||||
#TODO: does not work for lists
|
||||
_data = _data[_data.appliesTo.str.contains(cats)].reset_index(drop=True)
|
||||
_data = _data[_data.label_classification_simple.str.contains(cats)].reset_index(drop=True)
|
||||
logger.warning(f'[INFO] Reduced answer selection to {len(_data)} from {len(self.data)}.')
|
||||
|
||||
# BM25 Score threshold
|
||||
|
@ -60,7 +60,7 @@ class Rank():
|
|||
"""Used for inference
|
||||
NOTE: expects one input, one output given
|
||||
"""
|
||||
return self.run(dicts[0]['text'], cats=dicts[0]['cat'])[['question_clean','answer_text_clean','appliesTo','score']].to_dict(orient='records')
|
||||
return self.run(dicts[0]['text'], cats=dicts[0]['cat'])[['question_clean','answer_text_clean','label_classification_multi','score']].to_dict(orient='records')
|
||||
|
||||
def create_bm25():
|
||||
"""Function to create or update BM25 object"""
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
from summarizer import Summarizer
|
||||
import nltk
|
||||
nltk.download('punkt')
|
||||
import re
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from nltk.corpus import stopwords
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from gensim.summarization.summarizer import summarize
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
''' BERTABS '''
|
||||
def summarizeText(text, minLength=60):
    """Summarize *text* with the module-level ``model`` and return one string.

    Args:
        text: Input passed unchanged as the first argument to ``model``.
        minLength: Forwarded to ``model`` as its ``min_length`` keyword.

    Returns:
        The pieces produced by ``model`` concatenated with no separator.
    """
    # NOTE(review): ``model`` is presumably the ``Summarizer()`` instance
    # created at module level further down this file -- confirm it is
    # defined before this function is first called.
    pieces = model(text, min_length=minLength)
    return ''.join(pieces)
|
||||
|
||||
''' SAMPLING '''
|
||||
def sentencenize(text):
    """Split every item of *text* into sentences and return one flat list.

    Args:
        text: Iterable whose items are each passed to ``sent_tokenize``.

    Returns:
        A single list with the sentences of all items, in input order.
    """
    # Tokenize each item and flatten the per-item results in one pass.
    return [sentence for chunk in text for sentence in sent_tokenize(chunk)]
|
||||
|
||||
def extractWordVectors(file):
    """Parse whitespace-separated word vectors from an open text file.

    Each non-blank line is split on whitespace: the first token is taken as
    the word and the remaining tokens are converted to a ``float32`` NumPy
    array.

    Args:
        file: Iterable of text lines that also provides ``close()`` (for
            example an open file handle). It is always closed before this
            function returns, as in the original implementation.

    Returns:
        dict mapping each word to its coefficient array.
    """
    word_embeddings = {}
    try:
        for line in file:
            values = line.split()
            if not values:
                # Blank lines carry no word; values[0] would raise IndexError.
                continue
            word_embeddings[values[0]] = np.asarray(values[1:], dtype='float32')
    finally:
        # Release the handle even when a malformed line aborts parsing.
        file.close()
    # The original built this mapping but never returned it.
    return word_embeddings
|
||||
|
||||
def removeStopwords(sen, sw):
    """Drop the tokens of *sen* that occur in *sw* and join the rest.

    Args:
        sen: Iterable of string tokens.
        sw: Container of tokens to exclude (tested with ``in``).

    Returns:
        The remaining tokens joined by single spaces, in original order.
    """
    kept_tokens = (token for token in sen if token not in sw)
    return " ".join(kept_tokens)
|
||||
|
||||
''' BERTABS '''
|
||||
model = Summarizer()
|
||||
|
||||
''' SAMPLING '''
|
||||
clean_sentences = [removeStopwords(r.split(), sw) for r in clean_sentences]
|
||||
|
||||
''' GENSIM '''
|
||||
summarize()
|
|
@ -88,7 +88,7 @@ if args.do_deploy:
|
|||
# Fetch Models
|
||||
models = []
|
||||
for task in tasks:
|
||||
model_name = f'{project_name}-model-t{task}'
|
||||
model_name = f'{project_name}-model-{task}' ####
|
||||
if int(task) == 3:
|
||||
# NOTE: task 3 does not have a model
|
||||
continue
|
||||
|
@ -99,9 +99,9 @@ if args.do_deploy:
|
|||
logging.warning(f'[INFO] Added Model : {model.name} (v{model.version})')
|
||||
|
||||
# Deployment Target
|
||||
memory_gb = 2
|
||||
memory_gb = 1
|
||||
if compute_type == 'ACI':
|
||||
compute_config = AciWebservice.deploy_configuration(cpu_cores=2, memory_gb=memory_gb, auth_enabled=auth_enabled)
|
||||
compute_config = AciWebservice.deploy_configuration(cpu_cores=1, memory_gb=memory_gb, auth_enabled=auth_enabled)
|
||||
elif compute_type == 'AKS':
|
||||
compute_config = AksWebservice.deploy_configuration()
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче