Minor fixes to train.py
This commit is contained in:
Родитель
91df1f9874
Коммит
777ed259dd
|
@ -5,7 +5,7 @@ import os
|
|||
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
from utils_nlp.model.gensen import train_mlflow
|
||||
from utils_nlp.model.gensen import train
|
||||
from utils_nlp.model.gensen.create_gensen_model import (
|
||||
create_multiseq2seq_model,
|
||||
)
|
||||
|
@ -112,7 +112,7 @@ class GenSenClassifier:
|
|||
self._validate_params()
|
||||
self.cache_dir = self._get_gensen_tokens(train_df, dev_df, test_df)
|
||||
|
||||
train_mlflow.train(
|
||||
train.train(
|
||||
data_folder=os.path.abspath(self.cache_dir),
|
||||
config=self.config,
|
||||
learning_rate=self.learning_rate,
|
||||
|
|
Загрузка…
Ссылка в новой задаче