Rename NameEntity to NamedEntity (#6917)

This commit is contained in:
Eric StJohn 2023-12-21 10:08:50 -08:00 коммит произвёл GitHub
Родитель b8f71b9c66
Коммит a60be5f215
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 59 добавлений и 16 удалений

Просмотреть файл

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Text;
namespace Microsoft.ML.TorchSharp.NasBert
@ -17,7 +18,10 @@ namespace Microsoft.ML.TorchSharp.NasBert
MaskedLM = 1,
TextClassification = 2,
SentenceRegression = 3,
NameEntityRecognition = 4,
NamedEntityRecognition = 4,
[Obsolete("Please use NamedEntityRecognition instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
NameEntityRecognition = NamedEntityRecognition,
QuestionAnswering = 5
}
}

Просмотреть файл

@ -204,7 +204,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
EnglishRoberta tokenizerModel = Tokenizer.RobertaModel();
NasBertModel model;
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
else
model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
@ -268,7 +268,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
private protected override void RunModelAndBackPropagate(ref List<Tensor> inputTensors, ref Tensor targetsTensor)
{
Tensor logits = default;
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
int[,] lengthArray = new int[inputTensors.Count, 1];
for (int i = 0; i < inputTensors.Count; i++)
@ -293,7 +293,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
torch.Tensor loss;
if (Parent.BertOptions.TaskType == BertTaskType.TextClassification)
loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor);
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
targetsTensor = targetsTensor.@long().view(-1);
logits = logits.view(-1, logits.size(-1));
@ -338,7 +338,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
}
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
var metadata = new List<SchemaShape.Column>();
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
@ -387,7 +387,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString());
}
}
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
if (labelCol.ItemType != NumberDataViewType.UInt32)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
@ -535,7 +535,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
info[1] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, Parent.Options.NumberOfClasses), meta.ToAnnotations());
return info;
}
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
{
var info = new DataViewSchema.DetachedColumn[1];
var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;

Просмотреть файл

@ -35,7 +35,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
/// </summary>
/// <remarks>
/// <format type="text/markdown"><![CDATA[
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NameEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NamedEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
///
/// ### Input and Output Columns
/// The input label column data must be a Vector of [string](xref:Microsoft.ML.Data.TextDataViewType) type and the sentence columns must be of type<xref:Microsoft.ML.Data.TextDataViewType>.
@ -54,7 +54,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
/// | Exportable to ONNX | No |
///
/// ### Training Algorithm Details
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of name entity recognition.
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of named entity recognition.
/// ]]>
/// </format>
/// </remarks>
@ -93,7 +93,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
BatchSize = batchSize,
MaxEpoch = maxEpochs,
ValidationSet = validationSet,
TaskType = BertTaskType.NameEntityRecognition
TaskType = BertTaskType.NamedEntityRecognition
})
{
}
@ -295,7 +295,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
options.Sentence1ColumnName = ctx.LoadString();
options.Sentence2ColumnName = ctx.LoadStringOrNull();
options.TaskType = BertTaskType.NameEntityRecognition;
options.TaskType = BertTaskType.NamedEntityRecognition;
BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments());
DataViewType type;

Просмотреть файл

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Text;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp.AutoFormerV2;
@ -161,7 +162,45 @@ namespace Microsoft.ML.TorchSharp
}
/// <summary>
/// Fine tune a NAS-BERT model for Name Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, int, int, BertArchitecture, IDataView)"/> method instead
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="labelColumnName">Name of the label column. Column should be a key type.</param>
/// <param name="outputColumnName">Name of the output column. It will be a key type. It is the predicted label.</param>
/// <param name="sentence1ColumnName">Name of the column for the first sentence.</param>
/// <param name="batchSize">Number of rows in the batch.</param>
/// <param name="maxEpochs">Maximum number of times to loop through your training set.</param>
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
/// <returns></returns>
[Obsolete("Please use NamedEntityRecognition method instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
public static NerTrainer NameEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumnName = DefaultColumnNames.Label,
string outputColumnName = DefaultColumnNames.PredictedLabel,
string sentence1ColumnName = "Sentence",
int batchSize = 32,
int maxEpochs = 10,
BertArchitecture architecture = BertArchitecture.Roberta,
IDataView validationSet = null)
=> NamedEntityRecognition(catalog, labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, architecture, validationSet);
/// <summary>
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, NerTrainer.NerOptions)"/> method instead
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The full set of advanced options.</param>
/// <returns></returns>
[Obsolete("Please use NamedEntityRecognition method instead", false)]
[EditorBrowsable(EditorBrowsableState.Never)]
public static NerTrainer NameEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
NerTrainer.NerOptions options)
=> NamedEntityRecognition(catalog, options);
/// <summary>
/// Fine tune a NAS-BERT model for Named Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
/// will map to a single token, and we automatically add 2 specical tokens (a start token and a separator token)
/// so in general this limit will be 510 words for all sentences.
/// </summary>
@ -174,7 +213,7 @@ namespace Microsoft.ML.TorchSharp
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
/// <returns></returns>
public static NerTrainer NameEntityRecognition(
public static NerTrainer NamedEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
string labelColumnName = DefaultColumnNames.Label,
string outputColumnName = DefaultColumnNames.PredictedLabel,
@ -186,12 +225,12 @@ namespace Microsoft.ML.TorchSharp
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, validationSet, architecture);
/// <summary>
/// Fine tune a Name Entity Recognition model.
/// Fine tune a Named Entity Recognition model.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The full set of advanced options.</param>
/// <returns></returns>
public static NerTrainer NameEntityRecognition(
public static NerTrainer NamedEntityRecognition(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
NerTrainer.NerOptions options)
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), options);

Просмотреть файл

@ -54,7 +54,7 @@ namespace Microsoft.ML.Tests
}));
var chain = new EstimatorChain<ITransformer>();
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
.Append(ML.MulticlassClassification.Trainers.NameEntityRecognition(outputColumnName: "outputColumn"))
.Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "outputColumn"))
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));
var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));