diff --git a/requirements.txt b/requirements.txt
index edd9c2b..8640511 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,7 +33,7 @@ opencensus-ext-azure==1
 selenium==3.141.0
 bs4==0.0.
 ##OPTIONAL LOCAL
-azure-ai-textanalytics==1.0.0b3 #TO BE UPGRADED -> 5.0.0 -> TA
+azure-ai-textanalytics==5.1.0
 gensim==3.8.0
 spacy==2.3.2
 farm==0.4.7
diff --git a/src/ner.py b/src/ner.py
index f23dd31..9a695db 100644
--- a/src/ner.py
+++ b/src/ner.py
@@ -23,6 +23,9 @@ from farm.modeling.tokenization import Tokenizer
 from farm.train import Trainer
 from farm.utils import set_all_seeds, initialize_device_settings
 
+from azure.ai.textanalytics import TextAnalyticsClient
+from azure.core.credentials import AzureKeyCredential
+
 # Custom functions
 import sys
 sys.path.append('./src')
@@ -62,15 +65,16 @@ class FlairMatcher(object):
         return doc
 
 class TextAnalyticsMatcher(object):
-    name = "textanalytics"
     def __init__(self):
-        self.endpoint = f"https://{he.get_secret('text-analytics-name')}.cognitiveservices.azure.com/text/analytics/v3.0/entities/recognition/general"
-        self.headers = {"Ocp-Apim-Subscription-Key": he.get_secret('text-analytics-key')}
+        key = he.get_secret('text-analytics-key')
+        self.endpoint = f"https://{he.get_secret('text-analytics-name')}.cognitiveservices.azure.com/"
+        self.key = AzureKeyCredential(key)
+        self.client = TextAnalyticsClient(endpoint = self.endpoint, credential = self.key)
 
     def __call__(self, doc):
-        result = requests.post(self.endpoint, headers=self.headers, json={"documents": [{"id": "0", "language": cu.params.get('language'), "text": doc.text}]}).json()['documents'][0]
-        for entity in result['entities']:
-            span = doc.char_span(entity['offset'], entity['offset'] + entity['length'], label = entity['category'])
+        result = self.client.recognize_entities(documents = [{"id": "0", "language": cu.params.get('language'), "text": doc.text}])[0]
+        for entity in result.entities:
+            span = doc.char_span(entity.offset, entity.offset + entity.length, label = entity.category)
         # Pass, in case a match already exists
         try:
             doc.ents = list(doc.ents) + [span]
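
For reviewers: below is a minimal, self-contained sketch (not part of the commit) of the SDK call path the updated TextAnalyticsMatcher relies on. It assumes azure-ai-textanalytics==5.1.0 and spacy==2.3.2 as pinned in requirements.txt; the endpoint and key values are placeholders for whatever he.get_secret resolves to in the repo, and the blank spaCy pipeline is a stand-in for the project's full pipeline.

# sketch_textanalytics.py -- illustrative sketch only, not part of this change.
# Assumes azure-ai-textanalytics==5.1.0 and spacy==2.3.2; endpoint/key are placeholders.
import spacy
from azure.ai.textanalytics import TextAnalyticsClient
from azure.core.credentials import AzureKeyCredential

endpoint = "https://<your-resource-name>.cognitiveservices.azure.com/"  # placeholder
key = "<your-text-analytics-key>"                                       # placeholder

client = TextAnalyticsClient(endpoint=endpoint, credential=AzureKeyCredential(key))
nlp = spacy.blank("en")  # stand-in for the project's pipeline
doc = nlp("Satya Nadella spoke at the Microsoft event in Seattle.")

# recognize_entities returns one result per input document; indexing [0] mirrors ner.py.
result = client.recognize_entities(documents=[{"id": "0", "language": "en", "text": doc.text}])[0]

for entity in result.entities:
    # offset/length are character positions in the submitted text, so they can be
    # mapped onto a spaCy char_span, as the updated matcher does.
    span = doc.char_span(entity.offset, entity.offset + entity.length, label=entity.category)
    if span is not None:
        doc.ents = list(doc.ents) + [span]

print([(ent.text, ent.label_) for ent in doc.ents])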