Fixes NER to correctly expand/shrink the labels (#6928)
* ner options fix * Ner fixed. * Update src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs Co-authored-by: Eric StJohn <ericstj@microsoft.com> * fixes from PR comments * fixed build --------- Co-authored-by: Eric StJohn <ericstj@microsoft.com>
This commit is contained in:
Родитель
373a86467c
Коммит
8896dd2927
|
@ -66,7 +66,5 @@ namespace Microsoft.ML.Tokenizers
|
|||
/// <param name="ch"></param>
|
||||
/// <returns></returns>
|
||||
public abstract bool IsValidChar(char ch);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -240,7 +240,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
return DataUtils.CollateTokens(inputTensors, Tokenizer.RobertaModel().PadIndex, device: Device);
|
||||
}
|
||||
|
||||
private protected override torch.Tensor PrepareRowTensor()
|
||||
private protected override torch.Tensor PrepareRowTensor(ref TLabelCol target)
|
||||
{
|
||||
ReadOnlyMemory<char> sentence1 = default;
|
||||
Sentence1Getter(ref sentence1);
|
||||
|
@ -494,7 +494,8 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
|
||||
private static readonly FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate> _makeLabelAnnotationGetter
|
||||
= FuncInstanceMethodInfo1<NasBertMapper, DataViewSchema.DetachedColumn, Delegate>.Create(target => target.GetLabelAnnotations<int>);
|
||||
|
||||
internal static readonly int[] InitTokenArray = new[] { 0 /* InitToken */ };
|
||||
internal static readonly int[] SeperatorTokenArray = new[] { 2 /* SeperatorToken */ };
|
||||
|
||||
public NasBertMapper(TorchSharpBaseTransformer<TLabelCol, TTargetsCol> parent, DataViewSchema inputSchema) :
|
||||
base(parent, inputSchema)
|
||||
|
@ -583,13 +584,16 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
getSentence1(ref sentence1);
|
||||
if (getSentence2 == default)
|
||||
{
|
||||
return new[] { 0 /* InitToken */ }.Concat(tokenizer.EncodeToConverted(sentence1.ToString())).ToList();
|
||||
List<int> newList = new List<int>(tokenizer.EncodeToConverted(sentence1.ToString()));
|
||||
// 0 Is the init token and must be at the beginning.
|
||||
newList.Insert(0, 0);
|
||||
return newList;
|
||||
}
|
||||
else
|
||||
{
|
||||
getSentence2(ref sentence2);
|
||||
return new[] { 0 /* InitToken */ }.Concat(tokenizer.EncodeToConverted(sentence1.ToString()))
|
||||
.Concat(new[] { 2 /* SeperatorToken */ }).Concat(tokenizer.EncodeToConverted(sentence2.ToString())).ToList();
|
||||
return InitTokenArray.Concat(tokenizer.EncodeToConverted(sentence1.ToString()))
|
||||
.Concat(SeperatorTokenArray).Concat(tokenizer.EncodeToConverted(sentence2.ToString())).ToList();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ using Microsoft.ML.TorchSharp.NasBert;
|
|||
using Microsoft.ML.TorchSharp.NasBert.Models;
|
||||
using TorchSharp;
|
||||
using static Microsoft.ML.TorchSharp.NasBert.NasBertTrainer;
|
||||
using static TorchSharp.torch;
|
||||
|
||||
[assembly: LoadableClass(typeof(NerTransformer), null, typeof(SignatureLoadModel),
|
||||
NerTransformer.UserName, NerTransformer.LoaderSignature)]
|
||||
|
@ -61,6 +62,8 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
///
|
||||
public class NerTrainer : NasBertTrainer<VBuffer<uint>, TargetType>
|
||||
{
|
||||
private const char StartChar = (char)(' ' + 256);
|
||||
|
||||
public class NerOptions : NasBertOptions
|
||||
{
|
||||
public NerOptions()
|
||||
|
@ -69,6 +72,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
EncoderOutputDim = 384;
|
||||
EmbeddingDim = 128;
|
||||
Arches = new int[] { 15, 16, 14, 0, 0, 0, 15, 16, 14, 0, 0, 0, 17, 14, 15, 0, 0, 0, 17, 14, 15, 0, 0, 0 };
|
||||
TaskType = BertTaskType.NamedEntityRecognition;
|
||||
}
|
||||
}
|
||||
internal NerTrainer(IHostEnvironment env, NerOptions options) : base(env, options)
|
||||
|
@ -93,7 +97,6 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
BatchSize = batchSize,
|
||||
MaxEpoch = maxEpochs,
|
||||
ValidationSet = validationSet,
|
||||
TaskType = BertTaskType.NamedEntityRecognition
|
||||
})
|
||||
{
|
||||
}
|
||||
|
@ -108,9 +111,12 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
return new NerTransformer(host, options as NasBertOptions, model as NasBertModel, labelColumn);
|
||||
}
|
||||
|
||||
/// <summary>
/// Reports whether <paramref name="token"/> begins with <c>StartChar</c>, which is declared in this class as <c>(char)(' ' + 256)</c>.
/// </summary>
/// <param name="token">Token text to inspect; may be null or empty.</param>
/// <returns>
/// <c>true</c> when <paramref name="token"/> is null, or when it is non-empty and its first character equals <c>StartChar</c>;
/// <c>false</c> for an empty string or any other first character.
/// </returns>
// NOTE(review): StartChar is presumably the marker the tokenizer puts on tokens that follow a space (i.e. word starts) - confirm against the tokenizer.
internal static bool TokenStartsWithSpace(string token) => token is null || (token.Length != 0 && token[0] == StartChar);
|
||||
|
||||
private protected class Trainer : NasBertTrainerBase
|
||||
{
|
||||
private const string ModelUrlString = "models/pretrained_NasBert_14M_encoder.tsm";
|
||||
internal static readonly int[] ZeroArray = new int[] { 0 /* InitToken */};
|
||||
|
||||
public Trainer(TorchSharpBaseTrainer<VBuffer<uint>, TargetType> parent, IChannel ch, IDataView input) : base(parent, ch, input, ModelUrlString)
|
||||
{
|
||||
|
@ -155,6 +161,40 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
return torch.tensor(targetArray, device: Device);
|
||||
}
|
||||
|
||||
/// <summary>
/// Builds the input tensor for the current row and, when the label count does not match the token count,
/// rewrites <paramref name="target"/> so that it holds one label per encoded token.
/// </summary>
/// <param name="target">
/// Labels for the current row. If its length differs from <c>encoding.Tokens.Count</c> it is replaced by a
/// buffer of that length in which a label is repeated for every token until the next token that
/// <c>NerTrainer.TokenStartsWithSpace</c> accepts.
/// </param>
/// <returns>
/// A tensor on <c>Device</c> containing <c>ZeroArray</c> followed by
/// <c>IdsToOccurrenceRanks(encoding.Ids)</c>, cut to its first 512 elements when longer.
/// </returns>
private protected override torch.Tensor PrepareRowTensor(ref VBuffer<uint> target)
|
||||
{
|
||||
ReadOnlyMemory<char> sentenceRom = default;
|
||||
Sentence1Getter(ref sentenceRom);
|
||||
var sentence = sentenceRom.ToString();
|
||||
Tensor t;
|
||||
var encoding = Tokenizer.Encode(sentence);
|
||||
|
||||
// Only rebuild the labels when the caller's label count and the tokenizer's token count disagree.
if (target.Length != encoding.Tokens.Count)
|
||||
{
|
||||
// Index into the original labels; starts at the first label and only moves forward.
var targetIndex = 0;
|
||||
// NOTE(review): VBufferEditor.Create may hand back target's existing storage; if so the writes to
// newValues below could overwrite labels that target.GetItemOrDefault has not read yet - confirm.
var targetEditor = VBufferEditor.Create(ref target, encoding.Tokens.Count);
|
||||
var newValues = targetEditor.Values;
|
||||
for (var i = 0; i < encoding.Tokens.Count; i++)
|
||||
{
|
||||
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
|
||||
{
|
||||
// Token accepted by TokenStartsWithSpace: advance to the next original label first, then copy it.
newValues[i] = target.GetItemOrDefault(++targetIndex);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Any other token: repeat the label at the current position.
newValues[i] = target.GetItemOrDefault(targetIndex);
|
||||
}
|
||||
}
|
||||
target = targetEditor.Commit();
|
||||
}
|
||||
// ZeroArray supplies the leading init token (0) ahead of the converted token ids.
t = torch.tensor((ZeroArray).Concat(Tokenizer.RobertaModel().IdsToOccurrenceRanks(encoding.Ids)).ToList(), device: Device);
|
||||
|
||||
// NOTE(review): 512 is presumably the model's maximum sequence length - confirm. target is not truncated here.
if (t.NumberOfElements > 512)
|
||||
t = t.slice(0, 0, 512, 1);
|
||||
|
||||
return t;
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private protected override int GetNumCorrect(torch.Tensor predictions, torch.Tensor targets)
|
||||
{
|
||||
|
@ -334,6 +374,41 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
|
||||
}
|
||||
|
||||
/// <summary>
/// Writes a condensed prediction vector into <paramref name="dst"/>: the prediction at index 0, followed by
/// the prediction at every later token index for which <c>NerTrainer.TokenStartsWithSpace</c> returns true.
/// </summary>
/// <param name="dst">Destination buffer; replaced by the committed editor result.</param>
/// <param name="sentence">Sentence text that is re-encoded to recover the token boundaries.</param>
/// <param name="tokenizer">Tokenizer used to encode <paramref name="sentence"/>.</param>
/// <param name="outputCacher">
/// Cache holding the model output; it is cast with <c>as BertTensorCacher</c> and dereferenced without a null check.
/// </param>
private void CondenseOutput(ref VBuffer<UInt32> dst, string sentence, Tokenizer tokenizer, TensorCacher outputCacher)
|
||||
{
|
||||
// NOTE(review): 'pre' is never read in this method - confirm whether the PreTokenize call is still needed.
var pre = tokenizer.PreTokenizer.PreTokenize(sentence);
|
||||
TokenizerResult encoding = tokenizer.Encode(sentence);
|
||||
|
||||
// Take the index of the highest-scoring entry along the last dimension of the cached result.
var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
|
||||
var prediction = argmax.ToArray<long>();
|
||||
|
||||
var targetIndex = 0;
|
||||
// Figure out actual count of output tokens
|
||||
// NOTE(review): this pass starts at i = 0 while the copy pass below starts at i = 1; if token 0 is ever
// accepted by TokenStartsWithSpace the buffer is sized one slot larger than what gets written - confirm.
for (var i = 0; i < encoding.Tokens.Count; i++)
|
||||
{
|
||||
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
|
||||
{
|
||||
targetIndex++;
|
||||
}
|
||||
}
|
||||
|
||||
// One slot for index 0 plus one per counted token.
var editor = VBufferEditor.Create(ref dst, targetIndex + 1);
|
||||
var newValues = editor.Values;
|
||||
// Reuse targetIndex as the write cursor into newValues.
targetIndex = 0;
|
||||
|
||||
// NOTE(review): prediction is indexed with token positions; if the model input carries a leading init
// token, or is truncated, these indices may be shifted or out of range - confirm.
newValues[targetIndex++] = (uint)prediction[0];
|
||||
|
||||
for (var i = 1; i < encoding.Tokens.Count; i++)
|
||||
{
|
||||
if (NerTrainer.TokenStartsWithSpace(encoding.Tokens[i]))
|
||||
{
|
||||
newValues[targetIndex++] = (uint)prediction[i];
|
||||
}
|
||||
}
|
||||
|
||||
dst = editor.Commit();
|
||||
}
|
||||
|
||||
private Delegate MakePredictedLabelGetter(DataViewRow input, IChannel ch, TensorCacher outputCacher)
|
||||
{
|
||||
ValueGetter<ReadOnlyMemory<char>> getSentence1 = default;
|
||||
|
@ -353,13 +428,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
|
|||
var argmax = (outputCacher as BertTensorCacher).Result.argmax(-1);
|
||||
var prediction = argmax.ToArray<long>();
|
||||
|
||||
var editor = VBufferEditor.Create(ref dst, prediction.Length - 1);
|
||||
for (int i = 1; i < prediction.Length; i++)
|
||||
{
|
||||
editor.Values[i - 1] = (uint)prediction[i];
|
||||
}
|
||||
|
||||
dst = editor.Commit();
|
||||
CondenseOutput(ref dst, sentence1.ToString(), tokenizer, outputCacher);
|
||||
};
|
||||
|
||||
return classification;
|
||||
|
|
|
@ -238,9 +238,9 @@ namespace Microsoft.ML.TorchSharp
|
|||
cursorValid = cursor.MoveNext();
|
||||
if (cursorValid)
|
||||
{
|
||||
inputTensors.Add(PrepareRowTensor());
|
||||
TLabelCol target = default;
|
||||
labelGetter(ref target);
|
||||
inputTensors.Add(PrepareRowTensor(ref target));
|
||||
targets.Add(AddToTargets(target));
|
||||
}
|
||||
else
|
||||
|
@ -312,9 +312,9 @@ namespace Microsoft.ML.TorchSharp
|
|||
cursorValid = cursor.MoveNext();
|
||||
if (cursorValid)
|
||||
{
|
||||
inputTensors.Add(PrepareRowTensor());
|
||||
TLabelCol target = default;
|
||||
labelGetter(ref target);
|
||||
inputTensors.Add(PrepareRowTensor(ref target));
|
||||
targets.Add(AddToTargets(target));
|
||||
}
|
||||
else
|
||||
|
@ -343,7 +343,7 @@ namespace Microsoft.ML.TorchSharp
|
|||
|
||||
private protected abstract void RunModelAndBackPropagate(ref List<Tensor> inputTensorm, ref Tensor targetsTensor);
|
||||
|
||||
private protected abstract torch.Tensor PrepareRowTensor();
|
||||
private protected abstract torch.Tensor PrepareRowTensor(ref TLabelCol target);
|
||||
private protected abstract torch.Tensor PrepareBatchTensor(ref List<Tensor> inputTensors, Device device);
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.RunTests;
|
||||
using Microsoft.ML.TorchSharp;
|
||||
using Microsoft.ML.TorchSharp.NasBert;
|
||||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
|
@ -36,21 +38,34 @@ namespace Microsoft.ML.Tests
|
|||
new[] {
|
||||
new Label { Key = "PERSON" },
|
||||
new Label { Key = "CITY" },
|
||||
new Label { Key = "COUNTRY" }
|
||||
new Label { Key = "COUNTRY" },
|
||||
new Label { Key = "B_WORK_OF_ART" },
|
||||
new Label { Key = "WORK_OF_ART" },
|
||||
new Label { Key = "B_NORP" },
|
||||
});
|
||||
|
||||
var dataView = ML.Data.LoadFromEnumerable(
|
||||
new List<TestSingleSentenceData>(new TestSingleSentenceData[] {
|
||||
new TestSingleSentenceData()
|
||||
{ // Testing longer than 512 words.
|
||||
Sentence = "Alice and Bob live in the USA",
|
||||
Label = new string[]{"PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY"}
|
||||
new()
|
||||
{
|
||||
Sentence = "Alice and Bob live in the liechtenstein",
|
||||
Label = new string[]{"PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY" }
|
||||
},
|
||||
new TestSingleSentenceData()
|
||||
new()
|
||||
{
|
||||
Sentence = "Alice and Bob live in the USA",
|
||||
Label = new string[]{"PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY"}
|
||||
},
|
||||
new()
|
||||
{
|
||||
Sentence = "WW II Landmarks on the Great Earth of China : Eternal Memories of Taihang Mountain",
|
||||
Label = new string[]{"B_WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART" }
|
||||
},
|
||||
new()
|
||||
{
|
||||
Sentence = "This campaign broke through the Japanese army 's blockade to reach base areas behind enemy lines , stirring up anti-Japanese spirit throughout the nation and influencing the situation of the anti-fascist war of the people worldwide .",
|
||||
Label = new string[]{"0", "0", "0", "0", "0", "B_NORP", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "B_NORP", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" }
|
||||
}
|
||||
}));
|
||||
var chain = new EstimatorChain<ITransformer>();
|
||||
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
|
||||
|
@ -68,8 +83,183 @@ namespace Microsoft.ML.Tests
|
|||
Assert.Equal(5, transformerSchema.Count);
|
||||
Assert.Equal("outputColumn", transformerSchema[4].Name);
|
||||
|
||||
var output = transformer.Transform(dataView);
|
||||
var cursor = output.GetRowCursorForAllColumns();
|
||||
|
||||
var labelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[2]);
|
||||
var predictedLabelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[3]);
|
||||
|
||||
VBuffer<uint> labelData = default;
|
||||
VBuffer<uint> predictedLabelData = default;
|
||||
|
||||
while (cursor.MoveNext())
|
||||
{
|
||||
labelGetter(ref labelData);
|
||||
predictedLabelGetter(ref predictedLabelData);
|
||||
|
||||
// Make sure that the expected label and the predicted label have same length
|
||||
Assert.Equal(labelData.Length, predictedLabelData.Length);
|
||||
}
|
||||
|
||||
TestEstimatorCore(estimator, dataView, shouldDispose: true);
|
||||
transformer.Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
/// Trains the NER estimator configured through <c>NerTrainer.NerOptions</c> on four in-memory sentences and
/// checks the estimator/transformer output schemas and that each predicted label vector has the same length
/// as the corresponding label vector.
/// </summary>
[Fact]
|
||||
public void TestSimpleNerOptions()
|
||||
{
|
||||
// Key values handed to MapValueToKey below via keyData.
var labels = ML.Data.LoadFromEnumerable(
|
||||
new[] {
|
||||
new Label { Key = "PERSON" },
|
||||
new Label { Key = "CITY" },
|
||||
new Label { Key = "COUNTRY" },
|
||||
new Label { Key = "B_WORK_OF_ART" },
|
||||
new Label { Key = "WORK_OF_ART" },
|
||||
new Label { Key = "B_NORP" },
|
||||
});
|
||||
|
||||
// Use the options-based overload; only the prediction column name is overridden.
var options = new NerTrainer.NerOptions();
|
||||
options.PredictionColumnName = "outputColumn";
|
||||
|
||||
// Each row pairs a space-separated sentence with one label string per word.
var dataView = ML.Data.LoadFromEnumerable(
|
||||
new List<TestSingleSentenceData>(new TestSingleSentenceData[] {
|
||||
new()
|
||||
{
|
||||
Sentence = "Alice and Bob live in the liechtenstein",
|
||||
Label = new string[]{"PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY" }
|
||||
},
|
||||
new()
|
||||
{
|
||||
Sentence = "Alice and Bob live in the USA",
|
||||
Label = new string[]{"PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY"}
|
||||
},
|
||||
new()
|
||||
{
|
||||
Sentence = "WW II Landmarks on the Great Earth of China : Eternal Memories of Taihang Mountain",
|
||||
Label = new string[]{"B_WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART", "WORK_OF_ART" }
|
||||
},
|
||||
new()
|
||||
{
|
||||
Sentence = "This campaign broke through the Japanese army 's blockade to reach base areas behind enemy lines , stirring up anti-Japanese spirit throughout the nation and influencing the situation of the anti-fascist war of the people worldwide .",
|
||||
Label = new string[]{"0", "0", "0", "0", "0", "B_NORP", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "B_NORP", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0", "0" }
|
||||
}
|
||||
}));
|
||||
// Pipeline: map label strings to keys, train NER, then map the predicted keys back to strings.
var chain = new EstimatorChain<ITransformer>();
|
||||
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
|
||||
.Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(options))
|
||||
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));
|
||||
|
||||
var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
|
||||
Assert.Equal(3, estimatorSchema.Count);
|
||||
Assert.Equal("outputColumn", estimatorSchema[2].Name);
|
||||
Assert.Equal(TextDataViewType.Instance, estimatorSchema[2].ItemType);
|
||||
|
||||
var transformer = estimator.Fit(dataView);
|
||||
var transformerSchema = transformer.GetOutputSchema(dataView.Schema);
|
||||
|
||||
Assert.Equal(5, transformerSchema.Count);
|
||||
Assert.Equal("outputColumn", transformerSchema[4].Name);
|
||||
|
||||
var output = transformer.Transform(dataView);
|
||||
var cursor = output.GetRowCursorForAllColumns();
|
||||
|
||||
// NOTE(review): columns 2 and 3 are read as key vectors; presumably the key-mapped Label and the raw prediction - confirm.
var labelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[2]);
|
||||
var predictedLabelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[3]);
|
||||
|
||||
VBuffer<uint> labelData = default;
|
||||
VBuffer<uint> predictedLabelData = default;
|
||||
|
||||
while (cursor.MoveNext())
|
||||
{
|
||||
labelGetter(ref labelData);
|
||||
predictedLabelGetter(ref predictedLabelData);
|
||||
|
||||
// Make sure that the expected label and the predicted label have same length
|
||||
Assert.Equal(labelData.Length, predictedLabelData.Length);
|
||||
}
|
||||
|
||||
TestEstimatorCore(estimator, dataView, shouldDispose: true);
|
||||
transformer.Dispose();
|
||||
}
|
||||
|
||||
/// <summary>
/// Trains the NER estimator on the data file "ner-conll2012_english_v4_train.txt" with GPU execution forced,
/// then checks label/prediction lengths row by row and requires more than 80% of the labels to match.
/// Skipped by default (see the Skip reason on the attribute).
/// </summary>
[Fact(Skip = "Needs to be on a comp with GPU or will take a LONG time.")]
|
||||
public void TestNERLargeFileGpu()
|
||||
{
|
||||
// Force GPU device 0 and disable the CPU fallback.
ML.FallbackToCpu = false;
|
||||
ML.GpuDeviceId = 0;
|
||||
|
||||
// Key values for MapValueToKey, one per line in column 0 of the key file.
var labelFilePath = GetDataPath("ner-key-info.txt");
|
||||
var labels = ML.Data.LoadFromTextFile(labelFilePath, new TextLoader.Column[]
|
||||
{
|
||||
new TextLoader.Column("Key", DataKind.String, 0)
|
||||
}
|
||||
);
|
||||
|
||||
// Tab-separated rows: column 0 is the sentence, columns 1..end are the labels (variable length).
var dataFilePath = GetDataPath("ner-conll2012_english_v4_train.txt");
|
||||
var dataView = TextLoader.Create(ML, new TextLoader.Options()
|
||||
{
|
||||
Columns = new[]
|
||||
{
|
||||
new TextLoader.Column("Sentence", DataKind.String, 0),
|
||||
new TextLoader.Column("Label", DataKind.String, new TextLoader.Range[]
|
||||
{
|
||||
new TextLoader.Range(1, null) { VariableEnd = true, AutoEnd = false }
|
||||
})
|
||||
},
|
||||
HasHeader = false,
|
||||
Separators = new char[] { '\t' },
|
||||
MaxRows = 75187 // Dataset has 75187 rows, so this cap loads all of them. Lower it (e.g. to 1000) for quicker training.
|
||||
}, new MultiFileSource(dataFilePath));
|
||||
|
||||
var trainTest = ML.Data.TrainTestSplit(dataView);
|
||||
|
||||
var options = new NerTrainer.NerOptions();
|
||||
options.PredictionColumnName = "outputColumn";
|
||||
|
||||
// Pipeline: map label strings to keys, train NER, then map the predicted keys back to strings.
var chain = new EstimatorChain<ITransformer>();
|
||||
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
|
||||
.Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(options))
|
||||
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));
|
||||
|
||||
var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));
|
||||
Assert.Equal(3, estimatorSchema.Count);
|
||||
Assert.Equal("outputColumn", estimatorSchema[2].Name);
|
||||
Assert.Equal(TextDataViewType.Instance, estimatorSchema[2].ItemType);
|
||||
|
||||
var transformer = estimator.Fit(trainTest.TrainSet);
|
||||
var transformerSchema = transformer.GetOutputSchema(dataView.Schema);
|
||||
|
||||
// NOTE(review): accuracy below is measured on the training split (TrainSet), not on trainTest.TestSet - confirm this is intended.
var output = transformer.Transform(trainTest.TrainSet);
|
||||
var cursor = output.GetRowCursorForAllColumns();
|
||||
|
||||
var labelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[2]);
|
||||
var predictedLabelGetter = cursor.GetGetter<VBuffer<uint>>(output.Schema[3]);
|
||||
|
||||
VBuffer<uint> labelData = default;
|
||||
VBuffer<uint> predictedLabelData = default;
|
||||
|
||||
// Running tallies over every label position in every row.
double correct = 0;
|
||||
double total = 0;
|
||||
|
||||
while (cursor.MoveNext())
|
||||
{
|
||||
labelGetter(ref labelData);
|
||||
predictedLabelGetter(ref predictedLabelData);
|
||||
|
||||
Assert.Equal(labelData.Length, predictedLabelData.Length);
|
||||
|
||||
for (var i = 0; i < labelData.Length; i++)
|
||||
{
|
||||
// A position counts as correct when the keys match, or when both are the default/zero key.
if (labelData.GetItemOrDefault(i) == predictedLabelData.GetItemOrDefault(i) || (labelData.GetItemOrDefault(i) == default && predictedLabelData.GetItemOrDefault(i) == 0))
|
||||
correct++;
|
||||
total++;
|
||||
}
|
||||
}
|
||||
Assert.True(correct / total > .80);
|
||||
Assert.Equal(5, transformerSchema.Count);
|
||||
Assert.Equal("outputColumn", transformerSchema[4].Name);
|
||||
|
||||
transformer.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,36 @@
|
|||
B_PERSON
|
||||
PERSON
|
||||
B_NORP
|
||||
NORP
|
||||
B_FAC
|
||||
FAC
|
||||
B_ORG
|
||||
ORG
|
||||
B_GPE
|
||||
GPE
|
||||
B_LOC
|
||||
LOC
|
||||
B_PRODUCT
|
||||
PRODUCT
|
||||
B_DATE
|
||||
DATE
|
||||
B_TIME
|
||||
TIME
|
||||
B_PERCENT
|
||||
PERCENT
|
||||
B_MONEY
|
||||
MONEY
|
||||
B_QUANTITY
|
||||
QUANTITY
|
||||
B_ORDINAL
|
||||
ORDINAL
|
||||
B_CARDINAL
|
||||
CARDINAL
|
||||
B_EVENT
|
||||
EVENT
|
||||
B_WORK_OF_ART
|
||||
WORK_OF_ART
|
||||
B_LAW
|
||||
LAW
|
||||
B_LANGUAGE
|
||||
LANGUAGE
|
Загрузка…
Ссылка в новой задаче