Rename NameEntity to NamedEntity (#6917)
This commit is contained in:
Родитель
b8f71b9c66
Коммит
a60be5f215
|
@ -4,6 +4,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.ComponentModel;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.ML.TorchSharp.NasBert
|
||||
|
@ -17,7 +18,10 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
MaskedLM = 1,
|
||||
TextClassification = 2,
|
||||
SentenceRegression = 3,
|
||||
NameEntityRecognition = 4,
|
||||
NamedEntityRecognition = 4,
|
||||
[Obsolete("Please use NamedEntityRecognition instead", false)]
|
||||
[EditorBrowsable(EditorBrowsableState.Never)]
|
||||
NameEntityRecognition = NamedEntityRecognition,
|
||||
QuestionAnswering = 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -204,7 +204,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
EnglishRoberta tokenizerModel = Tokenizer.RobertaModel();
|
||||
|
||||
NasBertModel model;
|
||||
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
|
||||
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
|
||||
model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
|
||||
else
|
||||
model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
|
||||
|
@ -268,7 +268,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
private protected override void RunModelAndBackPropagate(ref List<Tensor> inputTensors, ref Tensor targetsTensor)
|
||||
{
|
||||
Tensor logits = default;
|
||||
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
|
||||
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
|
||||
{
|
||||
int[,] lengthArray = new int[inputTensors.Count, 1];
|
||||
for (int i = 0; i < inputTensors.Count; i++)
|
||||
|
@ -293,7 +293,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
torch.Tensor loss;
|
||||
if (Parent.BertOptions.TaskType == BertTaskType.TextClassification)
|
||||
loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor);
|
||||
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
|
||||
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
|
||||
{
|
||||
targetsTensor = targetsTensor.@long().view(-1);
|
||||
logits = logits.view(-1, logits.size(-1));
|
||||
|
@ -338,7 +338,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
|
||||
NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
|
||||
}
|
||||
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
|
||||
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
|
||||
{
|
||||
var metadata = new List<SchemaShape.Column>();
|
||||
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
|
||||
|
@ -387,7 +387,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString());
|
||||
}
|
||||
}
|
||||
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
|
||||
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
|
||||
{
|
||||
if (labelCol.ItemType != NumberDataViewType.UInt32)
|
||||
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
|
||||
|
@ -535,7 +535,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
info[1] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, Parent.Options.NumberOfClasses), meta.ToAnnotations());
|
||||
return info;
|
||||
}
|
||||
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
|
||||
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
|
||||
{
|
||||
var info = new DataViewSchema.DetachedColumn[1];
|
||||
var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
|
||||
|
|
|
@ -35,7 +35,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <format type="text/markdown"><![CDATA[
|
||||
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NameEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
|
||||
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NamedEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
|
||||
///
|
||||
/// ### Input and Output Columns
|
||||
/// The input label column data must be a Vector of [string](xref:Microsoft.ML.Data.TextDataViewType) type and the sentence columns must be of type<xref:Microsoft.ML.Data.TextDataViewType>.
|
||||
|
@ -54,7 +54,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
/// | Exportable to ONNX | No |
|
||||
///
|
||||
/// ### Training Algorithm Details
|
||||
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of name entity recognition.
|
||||
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of named entity recognition.
|
||||
/// ]]>
|
||||
/// </format>
|
||||
/// </remarks>
|
||||
|
@ -93,7 +93,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
BatchSize = batchSize,
|
||||
MaxEpoch = maxEpochs,
|
||||
ValidationSet = validationSet,
|
||||
TaskType = BertTaskType.NameEntityRecognition
|
||||
TaskType = BertTaskType.NamedEntityRecognition
|
||||
})
|
||||
{
|
||||
}
|
||||
|
@ -295,7 +295,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
|
||||
options.Sentence1ColumnName = ctx.LoadString();
|
||||
options.Sentence2ColumnName = ctx.LoadStringOrNull();
|
||||
options.TaskType = BertTaskType.NameEntityRecognition;
|
||||
options.TaskType = BertTaskType.NamedEntityRecognition;
|
||||
|
||||
BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments());
|
||||
DataViewType type;
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.ComponentModel;
|
||||
using System.Text;
|
||||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.TorchSharp.AutoFormerV2;
|
||||
|
@ -161,7 +162,45 @@ namespace Microsoft.ML.TorchSharp
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Fine tune a NAS-BERT model for Name Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
|
||||
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, int, int, BertArchitecture, IDataView)"/> method instead
|
||||
/// </summary>
|
||||
/// <param name="catalog">The transform's catalog.</param>
|
||||
/// <param name="labelColumnName">Name of the label column. Column should be a key type.</param>
|
||||
/// <param name="outputColumnName">Name of the output column. It will be a key type. It is the predicted label.</param>
|
||||
/// <param name="sentence1ColumnName">Name of the column for the first sentence.</param>
|
||||
/// <param name="batchSize">Number of rows in the batch.</param>
|
||||
/// <param name="maxEpochs">Maximum number of times to loop through your training set.</param>
|
||||
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
|
||||
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
|
||||
/// <returns></returns>
|
||||
[Obsolete("Please use NamedEntityRecognition method instead", false)]
|
||||
[EditorBrowsable(EditorBrowsableState.Never)]
|
||||
public static NerTrainer NameEntityRecognition(
|
||||
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string outputColumnName = DefaultColumnNames.PredictedLabel,
|
||||
string sentence1ColumnName = "Sentence",
|
||||
int batchSize = 32,
|
||||
int maxEpochs = 10,
|
||||
BertArchitecture architecture = BertArchitecture.Roberta,
|
||||
IDataView validationSet = null)
|
||||
=> NamedEntityRecognition(catalog, labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, architecture, validationSet);
|
||||
|
||||
/// <summary>
|
||||
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, NerTrainer.NerOptions)"/> method instead
|
||||
/// </summary>
|
||||
/// <param name="catalog">The transform's catalog.</param>
|
||||
/// <param name="options">The full set of advanced options.</param>
|
||||
/// <returns></returns>
|
||||
[Obsolete("Please use NamedEntityRecognition method instead", false)]
|
||||
[EditorBrowsable(EditorBrowsableState.Never)]
|
||||
public static NerTrainer NameEntityRecognition(
|
||||
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
NerTrainer.NerOptions options)
|
||||
=> NamedEntityRecognition(catalog, options);
|
||||
|
||||
/// <summary>
|
||||
/// Fine tune a NAS-BERT model for Named Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
|
||||
/// will map to a single token, and we automatically add 2 specical tokens (a start token and a separator token)
|
||||
/// so in general this limit will be 510 words for all sentences.
|
||||
/// </summary>
|
||||
|
@ -174,7 +213,7 @@ namespace Microsoft.ML.TorchSharp
|
|||
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
|
||||
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
|
||||
/// <returns></returns>
|
||||
public static NerTrainer NameEntityRecognition(
|
||||
public static NerTrainer NamedEntityRecognition(
|
||||
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string outputColumnName = DefaultColumnNames.PredictedLabel,
|
||||
|
@ -186,12 +225,12 @@ namespace Microsoft.ML.TorchSharp
|
|||
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, validationSet, architecture);
|
||||
|
||||
/// <summary>
|
||||
/// Fine tune a Name Entity Recognition model.
|
||||
/// Fine tune a Named Entity Recognition model.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The transform's catalog.</param>
|
||||
/// <param name="options">The full set of advanced options.</param>
|
||||
/// <returns></returns>
|
||||
public static NerTrainer NameEntityRecognition(
|
||||
public static NerTrainer NamedEntityRecognition(
|
||||
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
NerTrainer.NerOptions options)
|
||||
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), options);
|
||||
|
|
|
@ -54,7 +54,7 @@ namespace Microsoft.ML.Tests
|
|||
}));
|
||||
var chain = new EstimatorChain<ITransformer>();
|
||||
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
|
||||
.Append(ML.MulticlassClassification.Trainers.NameEntityRecognition(outputColumnName: "outputColumn"))
|
||||
.Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "outputColumn"))
|
||||
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));
|
||||
|
||||
var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
|
||||
|
|
Загрузка…
Ссылка в новой задаче