зеркало из https://github.com/microsoft/mwt-ds.git
Родитель
4da2e733d0
Коммит
5069f25dab
|
@ -67,6 +67,26 @@
|
|||
</Reference>
|
||||
<Reference Include="System" />
|
||||
<Reference Include="System.Core" />
|
||||
<Reference Include="System.Reactive.Core, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reactive.Core.3.0.0\lib\net45\System.Reactive.Core.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="System.Reactive.Interfaces, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reactive.Interfaces.3.0.0\lib\net45\System.Reactive.Interfaces.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="System.Reactive.Linq, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reactive.Linq.3.0.0\lib\net45\System.Reactive.Linq.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="System.Reactive.PlatformServices, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reactive.PlatformServices.3.0.0\lib\net45\System.Reactive.PlatformServices.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="System.Reactive.Windows.Threading, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reactive.Windows.Threading.3.0.0\lib\net45\System.Reactive.Windows.Threading.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="System.Spatial, Version=5.6.4.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Spatial.5.6.4\lib\net40\System.Spatial.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
|
@ -81,26 +101,30 @@
|
|||
<Reference Include="System.Data" />
|
||||
<Reference Include="System.Net.Http" />
|
||||
<Reference Include="System.Xml" />
|
||||
<Reference Include="VowpalWabbit, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="VowpalWabbit.Common, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Common.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit.Common, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Common.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="VowpalWabbit.Core, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Core.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit.Core, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Core.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.13\lib\net45\VowpalWabbit.JSON.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.16\lib\net45\VowpalWabbit.JSON.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="WindowsBase" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Compile Include="AzureBlobDownloader.cs" />
|
||||
<Compile Include="FileTransformBlock.cs" />
|
||||
<Compile Include="VowpalWabbitJsonToString.cs" />
|
||||
<Compile Include="JsonTransform.cs" />
|
||||
<Compile Include="Metrics.cs" />
|
||||
<Compile Include="OfflineTrainer.cs" />
|
||||
<Compile Include="Properties\AssemblyInfo.cs" />
|
||||
</ItemGroup>
|
||||
|
@ -109,12 +133,12 @@
|
|||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
|
||||
<Import Project="..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" />
|
||||
<Import Project="..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" />
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets'))" />
|
||||
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets'))" />
|
||||
</Target>
|
||||
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
|
||||
Other similar extension points exist, see Microsoft.Common.targets.
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Dataflow;
|
||||
|
||||
namespace Experimentation
|
||||
{
|
||||
/// <summary>
/// A single line of text together with its position in the file it was read from.
/// </summary>
public class Line
{
    /// <summary>
    /// Text of the line.
    /// </summary>
    public string Content { get; set; }

    /// <summary>
    /// Position of the line in its source file (FileTransformBlock assigns zero-based numbers).
    /// </summary>
    public int Number { get; set; }
}
|
||||
|
||||
|
||||
/// <summary>
/// Factory helpers that stream the lines of a text file into TPL Dataflow blocks.
/// </summary>
public static class FileTransformBlock
{
    /// <summary>
    /// Number of lines collected into one batch by <see cref="CreateBatch{T}"/>.
    /// </summary>
    private const int LinesPerBatch = 512;

    /// <summary>
    /// Reads <paramref name="file"/> line by line on a dedicated thread and pushes every line
    /// through <paramref name="transform"/>.
    /// </summary>
    /// <typeparam name="T">Type produced by <paramref name="transform"/>.</typeparam>
    /// <param name="file">Path of the text file to read.</param>
    /// <param name="transform">Conversion applied to each line (line numbers start at 0).</param>
    /// <returns>The transform block; it is signalled completed once the whole file was posted.</returns>
    public static ISourceBlock<T> Create<T>(string file, Func<Line, T> transform)
    {
        var block = new TransformBlock<Line, T>(
            transform,
            new ExecutionDataflowBlockOptions
            {
                MaxDegreeOfParallelism = 2,
                BoundedCapacity = 128
            });

        new Thread(() =>
        {
            var input = block.AsObserver();
            try
            {
                using (var reader = new StreamReader(file))
                {
                    string line;
                    int lineNr = 0;
                    while ((line = reader.ReadLine()) != null)
                        input.OnNext(new Line { Content = line, Number = lineNr++ });

                    input.OnCompleted();
                }
            }
            catch (Exception ex)
            {
                // hand any failure (e.g. opening or reading the file) over to the block
                input.OnError(ex);
            }
        }).Start();

        return block;
    }

    /// <summary>
    /// Reads <paramref name="file"/> on a dedicated thread, groups the lines into batches of
    /// up to 512 lines and applies <paramref name="transform"/> to every line of a batch.
    /// </summary>
    /// <typeparam name="T">Type produced by <paramref name="transform"/>.</typeparam>
    /// <param name="file">Path of the text file to read.</param>
    /// <param name="transform">Conversion applied to each line (line numbers start at 0).</param>
    /// <returns>The transform block emitting one list per batch.</returns>
    public static ISourceBlock<List<T>> CreateBatch<T>(string file, Func<Line, T> transform)
    {
        var block = new TransformBlock<List<Line>, List<T>>(
            batch => batch.Select(transform).ToList(),
            new ExecutionDataflowBlockOptions
            {
                MaxDegreeOfParallelism = 8,
                BoundedCapacity = 128
            });

        new Thread(() =>
        {
            var input = block.AsObserver();
            try
            {
                using (var reader = new StreamReader(file))
                {
                    string line;
                    int lineNr = 0;
                    var batch = new List<Line>();
                    while ((line = reader.ReadLine()) != null)
                    {
                        batch.Add(new Line { Content = line, Number = lineNr++ });
                        if (batch.Count >= LinesPerBatch)
                        {
                            input.OnNext(batch);
                            // start a fresh list; the posted one now belongs to the block
                            batch = new List<Line>();
                        }
                    }

                    // flush the trailing, partially filled batch
                    if (batch.Count > 0)
                        input.OnNext(batch);

                    input.OnCompleted();
                }
            }
            catch (Exception ex)
            {
                // hand any failure (e.g. opening or reading the file) over to the block
                input.OnError(ex);
            }
        }).Start();

        return block;
    }

    /// <summary>
    /// Links <paramref name="source"/> to a new <see cref="BatchBlock{T}"/> of size
    /// <paramref name="n"/>, propagating completion.
    /// </summary>
    /// <typeparam name="T">Element type of the source.</typeparam>
    /// <param name="n">Batch size passed to the batch block.</param>
    /// <param name="source">Block whose output is batched.</param>
    /// <returns>The batch block.</returns>
    public static ISourceBlock<IList<T>> Batch<T>(int n, ISourceBlock<T> source)
    {
        var block = new BatchBlock<T>(n);
        source.LinkTo(block, new DataflowLinkOptions { PropagateCompletion = true });
        return block;
    }
}
|
||||
}
|
|
@ -11,7 +11,7 @@ namespace Experimentation
|
|||
{
|
||||
public static class JsonTransform
|
||||
{
|
||||
public static void TransformIngoreProperties(string fileIn, string fileOut, params string[] propertiesToIgnore)
|
||||
public static void TransformIgnoreProperties(string fileIn, string fileOut, params string[] propertiesToIgnore)
|
||||
{
|
||||
var ignorePropertiesSet = new HashSet<string>(propertiesToIgnore);
|
||||
Transform(fileIn, fileOut, (reader, writer) =>
|
||||
|
@ -29,8 +29,8 @@ namespace Experimentation
|
|||
|
||||
public static void Transform(string fileIn, string fileOut, Func<JsonTextReader, JsonTextWriter, bool> transform)
|
||||
{
|
||||
using (var reader = new StreamReader(fileIn))
|
||||
using (var writer = new StreamWriter(fileOut))
|
||||
using (var reader = new StreamReader(fileIn, Encoding.UTF8))
|
||||
using (var writer = new StreamWriter(fileOut, false, Encoding.UTF8))
|
||||
{
|
||||
var transformBlock = new TransformBlock<string, string>(
|
||||
evt =>
|
||||
|
@ -54,7 +54,7 @@ namespace Experimentation
|
|||
MaxDegreeOfParallelism = 8 // TODO:parameterize
|
||||
});
|
||||
|
||||
var outputBock = new ActionBlock<string>(l => writer.WriteLine(l),
|
||||
var outputBock = new ActionBlock<string>(l => { if (!string.IsNullOrEmpty(l)) writer.WriteLine(l); },
|
||||
new ExecutionDataflowBlockOptions { BoundedCapacity = 1024, MaxDegreeOfParallelism = 1 });
|
||||
transformBlock.LinkTo(outputBock, new DataflowLinkOptions { PropagateCompletion = true });
|
||||
|
||||
|
|
|
@ -0,0 +1,341 @@
|
|||
using Newtonsoft.Json;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Reactive.Linq;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Dataflow;
|
||||
using VW;
|
||||
|
||||
namespace Experimentation
|
||||
{
|
||||
/// <summary>
/// Label data of one example as written to / read from the label cache file.
/// </summary>
internal class Label
{
    /// <summary>
    /// Action list, serialized as "as".
    /// </summary>
    [JsonProperty(PropertyName = "as", NullValueHandling = NullValueHandling.Ignore)]
    internal int[] Actions;

    /// <summary>
    /// Probabilities in the same order as <see cref="Actions"/>, serialized as "p".
    /// </summary>
    [JsonProperty(PropertyName = "p", NullValueHandling = NullValueHandling.Ignore)]
    internal float[] Probabilities;

    /// <summary>
    /// Single action, serialized as "a".
    /// </summary>
    [JsonProperty(PropertyName = "a", NullValueHandling = NullValueHandling.Ignore)]
    internal int? Action;

    /// <summary>
    /// Number of actions, serialized as "ac".
    /// </summary>
    [JsonProperty(PropertyName = "ac", NullValueHandling = NullValueHandling.Ignore)]
    internal int? ActionCount;

    /// <summary>
    /// Single probability, serialized as "pr".
    /// </summary>
    [JsonProperty(PropertyName = "pr", NullValueHandling = NullValueHandling.Ignore)]
    internal float? Probability;

    /// <summary>
    /// Probabilities ordered 0...n; not serialized.
    /// </summary>
    [JsonIgnore]
    internal float[] ProbabilitiesOrdered;

    /// <summary>
    /// Probability of drop, serialized as "pd"; defaults to 0.
    /// </summary>
    [JsonProperty(PropertyName = "pd", NullValueHandling = NullValueHandling.Ignore)]
    internal float? ProbabilityOfDrop = 0f;

    /// <summary>
    /// Cost, serialized as "c".
    /// </summary>
    [JsonProperty(PropertyName = "c")]
    internal float Cost;

    /// <summary>
    /// Line number of the label; not serialized.
    /// </summary>
    [JsonIgnore]
    internal int LineNr;
}
|
||||
|
||||
public static class Metrics
|
||||
{
|
||||
/// <summary>
/// Lazily reads the label cache file, yielding one <see cref="Label"/> per line.
/// </summary>
/// <remarks>
/// Labels stored without a probability array are expanded from the single
/// action / probability / action count fields. Line numbers start at 0.
/// </remarks>
/// <param name="labelFile">Path of the label cache file (one JSON object per line).</param>
/// <returns>The labels in file order; the file is opened on each enumeration.</returns>
private static IEnumerable<Label> ParseCacheLabels(string labelFile)
{
    using (var reader = new StreamReader(labelFile))
    {
        string line;
        int lineNr = 0;
        while ((line = reader.ReadLine()) != null)
        {
            var label = JsonConvert.DeserializeObject<Label>(line);

            if (label.Probabilities == null)
            {
                // reconstruct action/probabilities based on action/probDeprecated
                int actionCount = (int)label.ActionCount;
                int chosenAction = (int)label.Action;

                label.Probabilities = Enumerable.Repeat((float)label.Probability, actionCount).ToArray();

                // chosen action first, followed by the remaining actions 1..actionCount in order
                label.Actions = new int[actionCount];
                label.Actions[0] = chosenAction;
                for (int i = 1, j = 1; i < actionCount; i++, j++)
                {
                    // skip the chosen action, it already occupies slot 0
                    if (j == chosenAction)
                        j++;

                    label.Actions[i] = j;
                }
            }

            // order probs by action (actions are 1-based indices into the ordered array)
            label.ProbabilitiesOrdered = new float[label.Actions.Length];
            for (int i = 0; i < label.Actions.Length; i++)
                label.ProbabilitiesOrdered[label.Actions[i] - 1] = label.Probabilities[i];

            label.LineNr = lineNr++;
            yield return label;
        }
    }
}
|
||||
|
||||
/// <summary>
/// Returns the labels of <paramref name="data"/>, building the "&lt;data&gt;.labels" cache
/// file first if it does not exist yet.
/// </summary>
/// <param name="data">Path of the data file (one JSON example per line).</param>
/// <returns>The labels read back from the cache file.</returns>
/// <exception cref="InvalidDataException">
/// A line lacks action, cost or probability information, or its "_multi" property is not an array of objects.
/// </exception>
private static IEnumerable<Label> ExtractAndCacheLabels(string data)
{
    var labelFile = data + ".labels";

    if (!File.Exists(labelFile))
    {
        Console.WriteLine($"Building label cache file {labelFile}");
        var jsonSerializer = new JsonSerializer();

        try
        {
            using (var writer = new StreamWriter(labelFile))
            using (var reader = new StreamReader(data))
            {
                string line;
                int lineNr = 0;
                while ((line = reader.ReadLine()) != null)
                {
                    int? action = null;
                    float? cost = null;
                    float[] probabilities = null;
                    int[] actions = null;
                    float? probabilityOfDrop = null;
                    float? probDeprecated = null;
                    int? actionCount = null;

                    using (var jsonReader = new JsonTextReader(new StringReader(line)))
                    {
                        while (jsonReader.Read())
                        {
                            if (jsonReader.TokenType != JsonToken.PropertyName)
                                continue;

                            switch ((string)jsonReader.Value)
                            {
                                case "_label_action":
                                    action = jsonReader.ReadAsInt32();
                                    break;
                                case "_label_cost":
                                    cost = (float)jsonReader.ReadAsDouble();
                                    break;
                                case "_label_probability":
                                    probDeprecated = (float)jsonReader.ReadAsDouble();
                                    break;
                                case "_a":
                                    actions = jsonSerializer.Deserialize<int[]>(jsonReader);
                                    break;
                                case "_p":
                                    probabilities = jsonSerializer.Deserialize<float[]>(jsonReader);
                                    break;
                                case "_ProbabilityOfDrop":
                                    probabilityOfDrop = (float)jsonReader.ReadAsDouble();
                                    break;
                                case "_multi":
                                    // the value must be an array; reject anything else (or a truncated line)
                                    if (!jsonReader.Read() || jsonReader.TokenType != JsonToken.StartArray)
                                    {
                                        throw new InvalidDataException($"Unexpected type for _multi: {jsonReader.TokenType}");
                                    }

                                    // count _multi elements
                                    actionCount = 0;
                                    while (jsonReader.Read() && jsonReader.TokenType == JsonToken.StartObject)
                                    {
                                        jsonReader.Skip();
                                        actionCount++;
                                    }

                                    if (jsonReader.TokenType != JsonToken.EndArray)
                                    {
                                        throw new InvalidDataException($"Unexpected type for _multi: {jsonReader.TokenType}");
                                    }
                                    break;
                            }

                            // everything needed for the array-style label was found, stop parsing this line
                            if (actions != null && cost != null && probabilities != null)
                                break;
                        }
                    }

                    if ((actions == null && action == null) || cost == null || (probabilities == null && probDeprecated == null))
                        throw new InvalidDataException("Missing label in line " + lineNr);

                    var label = new Label
                    {
                        Cost = (float)cost,
                        ProbabilityOfDrop = probabilityOfDrop
                    };

                    if (actions == null || probabilities == null)
                    {
                        // single action/probability form; expanded again when the cache is parsed
                        label.Action = action;
                        label.Probability = probDeprecated;
                        label.ActionCount = actionCount;
                    }
                    else
                    {
                        label.Actions = actions;
                        label.Probabilities = probabilities;
                    }

                    writer.WriteLine(JsonConvert.SerializeObject(label));

                    lineNr++;
                }
            }
        }
        catch
        {
            // do not leave a truncated cache behind: the File.Exists check above
            // would pick it up on the next call and silently use incomplete labels
            File.Delete(labelFile);
            throw;
        }
    }

    return ParseCacheLabels(labelFile);
}
|
||||
|
||||
/// <summary>
/// One record of a prediction file, deserialized from a JSON line.
/// </summary>
internal class Data
{
    /// <summary>
    /// Line number, serialized as "nr".
    /// </summary>
    [JsonProperty("nr")]
    public int LineNr { get; set; }

    /// <summary>
    /// Predicted actions, serialized as "as"; null when only a single action was predicted.
    /// </summary>
    [JsonProperty("as")]
    public int[] Actions { get; set; }

    /// <summary>
    /// Single predicted action, serialized as "a".
    /// </summary>
    [JsonProperty("a")]
    public int Action { get; set; }

    /// <summary>
    /// Probabilities paired with <see cref="Actions"/>, serialized as "p".
    /// </summary>
    [JsonProperty("p")]
    public float[] Probabilities { get; set; }

    /// <summary>
    /// Computes the loss of this prediction against the logged <paramref name="label"/>.
    /// </summary>
    /// <param name="label">Label of the same example.</param>
    /// <returns>
    /// The unbiased cost of <see cref="Action"/> when <see cref="Actions"/> is null; otherwise
    /// the probability-weighted cost sum multiplied by the importance weight sum.
    /// </returns>
    public float GetLoss(Label label)
    {
        if (Actions == null)
            return VowpalWabbitContextualBanditUtil.GetUnbiasedCost((uint)label.Actions[0], (uint)Action, label.Cost, label.Probabilities[0]);

        // A missing drop probability means "never dropped". The parentheses matter:
        // "1 - x ?? 0" would parse as "(1 - x) ?? 0" and yield a zero denominator for null.
        var keepProbability = 1 - (label.ProbabilityOfDrop ?? 0f);

        var weightedActions = Actions.Zip(Probabilities, (action, prob) => new { Action = action, Prob = prob });

        var c = weightedActions
            .Sum(ap => ap.Prob * VowpalWabbitContextualBanditUtil.GetUnbiasedCost((uint)label.Actions[0], (uint)ap.Action + 1, label.Cost, label.Probabilities[0]));

        var p = weightedActions
            .Sum(ap => ap.Prob / (label.Probabilities[ap.Action] * keepProbability));

        // SUM(cost) / SUM(1 / prob)
        return c / (1f / p);
    }
}
|
||||
|
||||
/// <summary>
/// Prints the average loss of every prediction file against the labels of <paramref name="dataFile"/>.
/// </summary>
/// <remarks>
/// Prediction files are processed in parallel. Each one is paired line by line with the
/// labels; pairing stops at the shorter of the two sequences.
/// </remarks>
/// <param name="dataFile">Data file the labels are extracted from (cached next to it).</param>
/// <param name="predictions">Prediction files holding one JSON record per line.</param>
public static void Compute(string dataFile, params string[] predictions)
{
    // TODO: labels are parsed multiple times
    var labelSource = ExtractAndCacheLabels(dataFile);

    Parallel.ForEach(predictions, pred =>
    {
        var loss = File.ReadLines(pred)
            .Select(line => JsonConvert.DeserializeObject<Data>(line))
            .Zip(labelSource, (data, label) => data.GetLoss(label))
            .Average();

        // ",-30" left-aligns the file name in a 30 character column
        // (a ":-30" format specifier is ignored for strings)
        Console.WriteLine($"{pred,-30}: {loss}");
    });
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
using System;
|
||||
using Newtonsoft.Json;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
|
@ -25,11 +26,11 @@ namespace Experimentation
|
|||
// internal ActionScore[] Prediction;
|
||||
//}
|
||||
|
||||
public static void Train(string arguments, string inputFile)
|
||||
public static void Train(string arguments, string inputFile, string predictionFile = null, TimeSpan? reloadInterval = null)
|
||||
{
|
||||
using (var reader = new StreamReader(inputFile))
|
||||
using (var prediction = new StreamWriter(inputFile + ".prediction"))
|
||||
using (var vw = new VowpalWabbitJson(new VowpalWabbitSettings(arguments)
|
||||
using (var prediction = new StreamWriter(predictionFile ?? inputFile + ".prediction"))
|
||||
using (var vw = new VowpalWabbit(new VowpalWabbitSettings(arguments)
|
||||
{
|
||||
Verbose = true
|
||||
}))
|
||||
|
@ -37,13 +38,54 @@ namespace Experimentation
|
|||
string line;
|
||||
int lineNr = 0;
|
||||
int invalidExamples = 0;
|
||||
DateTime? lastTimestamp = null;
|
||||
|
||||
while ((line = reader.ReadLine()) != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
var pred = vw.Learn(line, VowpalWabbitPredictionType.ActionScore);
|
||||
prediction.WriteLine(lineNr + " " + string.Join(",", pred.Select(a_s => $"{a_s.Action}:{a_s.Score}")));
|
||||
bool reload = false;
|
||||
using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
|
||||
{
|
||||
if (reloadInterval != null)
|
||||
{
|
||||
jsonSerializer.RegisterExtension((state, property) =>
|
||||
{
|
||||
if (property.Equals("_timestamp", StringComparison.Ordinal))
|
||||
{
|
||||
var eventTimestamp = state.Reader.ReadAsDateTime();
|
||||
if (lastTimestamp == null)
|
||||
lastTimestamp = eventTimestamp;
|
||||
else if (lastTimestamp + reloadInterval < eventTimestamp)
|
||||
{
|
||||
reload = true;
|
||||
lastTimestamp = eventTimestamp;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
// var pred = vw.Learn(line, VowpalWabbitPredictionType.ActionScore);
|
||||
using (var example = jsonSerializer.ParseAndCreate(line))
|
||||
{
|
||||
var pred = example.Learn(VowpalWabbitPredictionType.ActionScore);
|
||||
|
||||
prediction.WriteLine(JsonConvert.SerializeObject(
|
||||
new
|
||||
{
|
||||
nr = lineNr,
|
||||
@as = pred.Select(x => x.Action),
|
||||
p = pred.Select(x => x.Score)
|
||||
}));
|
||||
}
|
||||
|
||||
if (reload)
|
||||
vw.Reload();
|
||||
}
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
using Newtonsoft.Json.Linq;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Dataflow;
|
||||
using VW;
|
||||
using VW.Serializer;
|
||||
|
||||
namespace Experimentation
|
||||
{
|
||||
/// <summary>
/// Converts JSON examples into their Vowpal Wabbit string representation.
/// </summary>
public static class VowpalWabbitJsonToString
{
    /// <summary>
    /// Reads JSON examples line by line from <paramref name="reader"/> and writes the
    /// corresponding Vowpal Wabbit strings to <paramref name="writer"/>.
    /// </summary>
    /// <remarks>
    /// The first line decides the mode: "--cb_explore_adf" if it has a "_multi" property,
    /// "--cb_explore" otherwise. Returns without writing anything if the input is empty.
    /// The method returns only after every converted line was handed to the writer.
    /// </remarks>
    /// <param name="reader">Source of the JSON examples.</param>
    /// <param name="writer">Destination of the converted examples.</param>
    public static void Convert(StreamReader reader, StreamWriter writer)
    {
        var line = reader.ReadLine();
        if (line == null)
            return;

        var jExample = JObject.Parse(line);
        var settings = jExample.Properties().Any(p => p.Name == "_multi") ? "--cb_explore_adf" : "--cb_explore";

        int lineNr = 1;
        using (var vw = new VowpalWabbit(new VowpalWabbitSettings(settings) {
            EnableStringExampleGeneration = true,
            EnableStringFloatCompact = true,
            EnableThreadSafeExamplePooling = true
        }))
        {
            var serializeBlock = new TransformBlock<Tuple<string, int>, string>(l =>
            {
                using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
                using (var example = jsonSerializer.ParseAndCreate(l.Item1))
                {
                    if (example == null)
                        throw new InvalidDataException($"Invalid example in line {l.Item2}: '{l.Item1}'");

                    var str = example.VowpalWabbitString;
                    // multi-line examples get an extra newline appended
                    if (example is VowpalWabbitMultiLineExampleCollection)
                        str += "\n";

                    return str;
                }
            },
            new ExecutionDataflowBlockOptions
            {
                BoundedCapacity = 1024,
                MaxDegreeOfParallelism = 8
            });

            var writeBlock = new ActionBlock<string>(
                l => writer.WriteLine(l),
                new ExecutionDataflowBlockOptions { MaxDegreeOfParallelism = 1, BoundedCapacity = 128 });
            serializeBlock.LinkTo(writeBlock, new DataflowLinkOptions { PropagateCompletion = true });

            var input = serializeBlock.AsObserver();

            // the first line was already consumed above to pick the settings, so post it first
            do
            {
                input.OnNext(Tuple.Create(line, lineNr));
                lineNr++;
            } while ((line = reader.ReadLine()) != null);

            input.OnCompleted();

            // Wait for the last stage of the pipeline: the serialize block can be done while
            // the write block still holds queued lines, and vw/writer must not be released
            // before those are written. Completion (and faults) propagate through the link.
            writeBlock.Completion.Wait();
        }
    }
}
|
||||
}
|
|
@ -8,8 +8,14 @@
|
|||
<package id="Microsoft.Tpl.Dataflow" version="4.5.24" targetFramework="net452" />
|
||||
<package id="Newtonsoft.Json" version="8.0.2" targetFramework="net452" />
|
||||
<package id="OptimizedPriorityQueue" version="2.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive.Core" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive.Interfaces" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive.Linq" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive.PlatformServices" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive.Windows.Threading" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Spatial" version="5.6.4" targetFramework="net452" />
|
||||
<package id="VowpalWabbit" version="8.2.0.13" targetFramework="net452" />
|
||||
<package id="VowpalWabbit.JSON" version="8.2.0.13" targetFramework="net452" />
|
||||
<package id="VowpalWabbit" version="8.2.0.16" targetFramework="net452" />
|
||||
<package id="VowpalWabbit.JSON" version="8.2.0.16" targetFramework="net452" />
|
||||
<package id="WindowsAzure.Storage" version="7.2.0" targetFramework="net452" />
|
||||
</packages>
|
|
@ -9,6 +9,10 @@
|
|||
<assemblyIdentity name="Newtonsoft.Json" publicKeyToken="30ad4fe6b2a6aeed" culture="neutral" />
|
||||
<bindingRedirect oldVersion="0.0.0.0-8.0.0.0" newVersion="8.0.0.0" />
|
||||
</dependentAssembly>
|
||||
<dependentAssembly>
|
||||
<assemblyIdentity name="VowpalWabbit.Common" publicKeyToken="a76afd1645210483" culture="neutral" />
|
||||
<bindingRedirect oldVersion="0.0.0.0-8.2.0.14" newVersion="8.2.0.14" />
|
||||
</dependentAssembly>
|
||||
</assemblyBinding>
|
||||
</runtime>
|
||||
</configuration>
|
|
@ -65,6 +65,14 @@
|
|||
</Reference>
|
||||
<Reference Include="System" />
|
||||
<Reference Include="System.Core" />
|
||||
<Reference Include="System.Reactive.Core, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reactive.Core.3.0.0\lib\net45\System.Reactive.Core.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="System.Reactive.Interfaces, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reactive.Interfaces.3.0.0\lib\net45\System.Reactive.Interfaces.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="System.Spatial, Version=5.6.4.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Spatial.5.6.4\lib\net40\System.Spatial.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
|
@ -79,20 +87,20 @@
|
|||
<Reference Include="System.Data" />
|
||||
<Reference Include="System.Net.Http" />
|
||||
<Reference Include="System.Xml" />
|
||||
<Reference Include="VowpalWabbit, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="VowpalWabbit.Common, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Common.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit.Common, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Common.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="VowpalWabbit.Core, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Core.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit.Core, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Core.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.13\lib\net45\VowpalWabbit.JSON.dll</HintPath>
|
||||
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
|
||||
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.16\lib\net45\VowpalWabbit.JSON.dll</HintPath>
|
||||
<Private>True</Private>
|
||||
</Reference>
|
||||
</ItemGroup>
|
||||
|
@ -111,12 +119,12 @@
|
|||
</ProjectReference>
|
||||
</ItemGroup>
|
||||
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
|
||||
<Import Project="..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" />
|
||||
<Import Project="..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" />
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets'))" />
|
||||
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets'))" />
|
||||
</Target>
|
||||
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
|
||||
Other similar extension points exist, see Microsoft.Common.targets.
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
using Experimentation;
|
||||
using Microsoft.WindowsAzure.Storage;
|
||||
using Microsoft.WindowsAzure.Storage.Auth;
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Dataflow;
|
||||
using VW;
|
||||
using VW.Serializer;
|
||||
|
||||
namespace ExperimentationConsole
|
||||
{
|
||||
|
@ -14,16 +21,65 @@ namespace ExperimentationConsole
|
|||
{
|
||||
static void Main(string[] args)
|
||||
{
|
||||
var storageAccount = new CloudStorageAccount(new StorageCredentials("", ""), false);
|
||||
|
||||
var outputDirectory = @"c:\temp\abc";
|
||||
var startTimeInclusive = new DateTime(2016, 8, 11, 0, 0, 0);
|
||||
var endTimeExclusive = new DateTime(2016, 8, 14, 0, 0, 0);
|
||||
|
||||
using (var writer = new StreamWriter(Path.Combine(outputDirectory, $"{startTimeInclusive:yyyy-MM-dd_HH}-{endTimeExclusive:yyyy-MM-dd_HH}.json")))
|
||||
try
|
||||
{
|
||||
AzureBlobDownloader.Download(storageAccount, startTimeInclusive, endTimeExclusive, writer, outputDirectory).Wait();
|
||||
var stopwatch = Stopwatch.StartNew();
|
||||
|
||||
var storageAccount = new CloudStorageAccount(new StorageCredentials("storage name", "storage key"), false);
|
||||
|
||||
var outputDirectory = @"c:\temp\";
|
||||
var startTimeInclusive = new DateTime(2016, 8, 11, 0, 0, 0);
|
||||
var endTimeExclusive = new DateTime(2016, 8, 14, 0, 0, 0);
|
||||
var outputFile = Path.Combine(outputDirectory, $"{startTimeInclusive:yyyy-MM-dd_HH}-{endTimeExclusive:yyyy-MM-dd_HH}.json");
|
||||
|
||||
// download and merge blob data
|
||||
using (var writer = new StreamWriter(outputFile))
|
||||
{
|
||||
AzureBlobDownloader.Download(storageAccount, startTimeInclusive, endTimeExclusive, writer, outputDirectory).Wait();
|
||||
}
|
||||
|
||||
// pre-process JSON
|
||||
JsonTransform.TransformIgnoreProperties(outputFile, outputFile + ".small",
|
||||
"Somefeatures");
|
||||
|
||||
outputFile += ".small";
|
||||
// filter broken events for complex
|
||||
JsonTransform.Transform(outputFile, outputFile + ".fixed", (reader, writer) =>
|
||||
{
|
||||
var serializer = JsonSerializer.CreateDefault();
|
||||
var obj = (JObject)serializer.Deserialize(reader);
|
||||
var multi = (JArray)obj.SelectToken("$._multi");
|
||||
if (multi.Count == 10)
|
||||
serializer.Serialize(writer, obj);
|
||||
|
||||
return true;
|
||||
});
|
||||
outputFile += ".fixed";
|
||||
|
||||
using (var reader = new StreamReader(outputFile))
|
||||
using (var writer = new StreamWriter(outputFile + ".vw"))
|
||||
{
|
||||
VowpalWabbitJsonToString.Convert(reader, writer);
|
||||
}
|
||||
|
||||
// VW training
|
||||
OfflineTrainer.Train("--cb_explore_adf --epsilon 0.05 -q AB -q UD",
|
||||
outputFile,
|
||||
predictionFile: outputFile + ".2h.prediction",
|
||||
reloadInterval: TimeSpan.FromHours(2));
|
||||
|
||||
Metrics.Compute(outputFile,
|
||||
outputFile + ".prediction",
|
||||
outputFile + ".2h.prediction");
|
||||
|
||||
Console.WriteLine("\ndone " + stopwatch.Elapsed);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Console.WriteLine($"Exception: {ex.Message}. {ex.StackTrace}");
|
||||
}
|
||||
|
||||
Console.ReadKey();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,8 +7,10 @@
|
|||
<package id="Microsoft.Tpl.Dataflow" version="4.5.24" targetFramework="net452" />
|
||||
<package id="Newtonsoft.Json" version="8.0.2" targetFramework="net452" />
|
||||
<package id="OptimizedPriorityQueue" version="2.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive.Core" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Reactive.Interfaces" version="3.0.0" targetFramework="net452" />
|
||||
<package id="System.Spatial" version="5.6.4" targetFramework="net452" />
|
||||
<package id="VowpalWabbit" version="8.2.0.13" targetFramework="net452" />
|
||||
<package id="VowpalWabbit.JSON" version="8.2.0.13" targetFramework="net452" />
|
||||
<package id="VowpalWabbit" version="8.2.0.16" targetFramework="net452" />
|
||||
<package id="VowpalWabbit.JSON" version="8.2.0.16" targetFramework="net452" />
|
||||
<package id="WindowsAzure.Storage" version="7.2.0" targetFramework="net452" />
|
||||
</packages>
|
Загрузка…
Ссылка в новой задаче