added metrics
This commit is contained in:
Markus Cozowicz 2016-08-17 15:34:02 -04:00
Родитель 4da2e733d0
Коммит 5069f25dab
11 изменённых файлов: 711 добавлений и 42 удалений

Просмотреть файл

@ -67,6 +67,26 @@
</Reference>
<Reference Include="System" />
<Reference Include="System.Core" />
<Reference Include="System.Reactive.Core, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reactive.Core.3.0.0\lib\net45\System.Reactive.Core.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System.Reactive.Interfaces, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reactive.Interfaces.3.0.0\lib\net45\System.Reactive.Interfaces.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System.Reactive.Linq, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reactive.Linq.3.0.0\lib\net45\System.Reactive.Linq.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System.Reactive.PlatformServices, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reactive.PlatformServices.3.0.0\lib\net45\System.Reactive.PlatformServices.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System.Reactive.Windows.Threading, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reactive.Windows.Threading.3.0.0\lib\net45\System.Reactive.Windows.Threading.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System.Spatial, Version=5.6.4.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35, processorArchitecture=MSIL">
<HintPath>..\packages\System.Spatial.5.6.4\lib\net40\System.Spatial.dll</HintPath>
<Private>True</Private>
@ -81,26 +101,30 @@
<Reference Include="System.Data" />
<Reference Include="System.Net.Http" />
<Reference Include="System.Xml" />
<Reference Include="VowpalWabbit, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.dll</HintPath>
<Reference Include="VowpalWabbit, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="VowpalWabbit.Common, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Common.dll</HintPath>
<Reference Include="VowpalWabbit.Common, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Common.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="VowpalWabbit.Core, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Core.dll</HintPath>
<Reference Include="VowpalWabbit.Core, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Core.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.13\lib\net45\VowpalWabbit.JSON.dll</HintPath>
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.16\lib\net45\VowpalWabbit.JSON.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="WindowsBase" />
</ItemGroup>
<ItemGroup>
<Compile Include="AzureBlobDownloader.cs" />
<Compile Include="FileTransformBlock.cs" />
<Compile Include="VowpalWabbitJsonToString.cs" />
<Compile Include="JsonTransform.cs" />
<Compile Include="Metrics.cs" />
<Compile Include="OfflineTrainer.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
@ -109,12 +133,12 @@
<None Include="packages.config" />
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<Import Project="..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" />
<Import Project="..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" />
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets'))" />
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets'))" />
</Target>
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.

Просмотреть файл

@ -0,0 +1,114 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
namespace Experimentation
{
/// <summary>
/// One line of text read from a file, together with its position in that file.
/// </summary>
public class Line
{
/// <summary>The text of the line as returned by the reader.</summary>
public string Content { get; set; }
/// <summary>Index of the line; the readers in FileTransformBlock number lines starting at 0.</summary>
public int Number { get; set; }
}
public static class FileTransformBlock
{
/// <summary>
/// Reads <paramref name="file"/> line by line on a dedicated thread and pushes every line
/// through <paramref name="transform"/>.
/// </summary>
/// <typeparam name="T">Type produced by <paramref name="transform"/>.</typeparam>
/// <param name="file">Path of the text file to read.</param>
/// <param name="transform">Function applied to each line (up to 2 lines are transformed concurrently).</param>
/// <returns>
/// The source side of the transform block. It is completed once the whole file has been posted
/// and is faulted if reading the file throws.
/// </returns>
public static ISourceBlock<T> Create<T>(string file, Func<Line, T> transform)
{
    var block = new TransformBlock<Line, T>(
        transform,
        new ExecutionDataflowBlockOptions
        {
            MaxDegreeOfParallelism = 2,
            BoundedCapacity = 128
        });

    // The file is read on its own thread so the caller gets the block back immediately.
    // NOTE(review): OnNext on the observer is expected to wait while the bounded block is full,
    // which is what throttles the reader - confirm against the TPL Dataflow documentation.
    new Thread(() =>
    {
        var input = block.AsObserver();
        try
        {
            using (var reader = new StreamReader(file))
            {
                string line;
                int lineNr = 0;
                while ((line = reader.ReadLine()) != null)
                    input.OnNext(new Line { Content = line, Number = lineNr++ });

                input.OnCompleted();
            }
        }
        catch (Exception ex)
        {
            // surface I/O failures to whoever consumes the block
            input.OnError(ex);
        }
    }).Start();

    return block;
}
/// <summary>
/// Reads <paramref name="file"/> line by line on a dedicated thread, groups the lines into
/// batches and applies <paramref name="transform"/> to every line of a batch.
/// </summary>
/// <typeparam name="T">Type produced by <paramref name="transform"/>.</typeparam>
/// <param name="file">Path of the text file to read.</param>
/// <param name="transform">Function applied to each line (up to 8 batches are transformed concurrently).</param>
/// <param name="batchSize">Number of lines per batch; the last batch may be smaller. Defaults to 512.</param>
/// <returns>
/// The source side of the transform block, producing one list per batch. It is completed once the
/// whole file has been posted and is faulted if reading the file throws.
/// </returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is not positive.</exception>
public static ISourceBlock<List<T>> CreateBatch<T>(string file, Func<Line, T> transform, int batchSize = 512)
{
    if (batchSize <= 0)
        throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be positive.");

    var block = new TransformBlock<List<Line>, List<T>>(
        batch => batch.Select(transform).ToList(),
        new ExecutionDataflowBlockOptions
        {
            MaxDegreeOfParallelism = 8,
            BoundedCapacity = 128
        });

    // The file is read on its own thread so the caller gets the block back immediately.
    new Thread(() =>
    {
        var input = block.AsObserver();
        try
        {
            using (var reader = new StreamReader(file))
            {
                string line;
                int lineNr = 0;
                var batch = new List<Line>(batchSize);
                while ((line = reader.ReadLine()) != null)
                {
                    batch.Add(new Line { Content = line, Number = lineNr++ });
                    if (batch.Count >= batchSize)
                    {
                        input.OnNext(batch);
                        // hand over ownership of the posted list; start a fresh one
                        batch = new List<Line>(batchSize);
                    }
                }

                // flush the trailing, partially filled batch
                if (batch.Count > 0)
                    input.OnNext(batch);

                input.OnCompleted();
            }
        }
        catch (Exception ex)
        {
            // surface I/O failures to whoever consumes the block
            input.OnError(ex);
        }
    }).Start();

    return block;
}
/// <summary>
/// Groups the items produced by <paramref name="source"/> into batches of <paramref name="n"/> items.
/// </summary>
/// <param name="n">Batch size handed to the <c>BatchBlock</c>.</param>
/// <param name="source">Block whose output is batched; its completion is propagated to the result.</param>
/// <returns>The batching block linked behind <paramref name="source"/>.</returns>
public static ISourceBlock<IList<T>> Batch<T>(int n, ISourceBlock<T> source)
{
    var batcher = new BatchBlock<T>(n);
    var linkOptions = new DataflowLinkOptions { PropagateCompletion = true };

    source.LinkTo(batcher, linkOptions);

    return batcher;
}
}
}

Просмотреть файл

@ -11,7 +11,7 @@ namespace Experimentation
{
public static class JsonTransform
{
public static void TransformIngoreProperties(string fileIn, string fileOut, params string[] propertiesToIgnore)
public static void TransformIgnoreProperties(string fileIn, string fileOut, params string[] propertiesToIgnore)
{
var ignorePropertiesSet = new HashSet<string>(propertiesToIgnore);
Transform(fileIn, fileOut, (reader, writer) =>
@ -29,8 +29,8 @@ namespace Experimentation
public static void Transform(string fileIn, string fileOut, Func<JsonTextReader, JsonTextWriter, bool> transform)
{
using (var reader = new StreamReader(fileIn))
using (var writer = new StreamWriter(fileOut))
using (var reader = new StreamReader(fileIn, Encoding.UTF8))
using (var writer = new StreamWriter(fileOut, false, Encoding.UTF8))
{
var transformBlock = new TransformBlock<string, string>(
evt =>
@ -54,7 +54,7 @@ namespace Experimentation
MaxDegreeOfParallelism = 8 // TODO:parameterize
});
var outputBock = new ActionBlock<string>(l => writer.WriteLine(l),
var outputBock = new ActionBlock<string>(l => { if (!string.IsNullOrEmpty(l)) writer.WriteLine(l); },
new ExecutionDataflowBlockOptions { BoundedCapacity = 1024, MaxDegreeOfParallelism = 1 });
transformBlock.LinkTo(outputBock, new DataflowLinkOptions { PropagateCompletion = true });

341
Experimentation/Metrics.cs Normal file
Просмотреть файл

@ -0,0 +1,341 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Reactive.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using VW;
namespace Experimentation
{
/// <summary>
/// Logged label of a single event, as written to and read from the label cache file
/// (one JSON object per line). Either the array form (Actions/Probabilities) or the
/// single-action form (Action/ActionCount/Probability) is populated by ExtractAndCacheLabels.
/// </summary>
internal class Label
{
/// <summary>
/// Logged actions (JSON "as"). Element 0 is the action passed as the labelled action
/// to the loss computation in Metrics.Data.GetLoss.
/// </summary>
[JsonProperty("as", NullValueHandling = NullValueHandling.Ignore)]
internal int[] Actions;
/// <summary>
/// Probabilities in the same order as <see cref="Actions"/> (JSON "p").
/// </summary>
[JsonProperty("p", NullValueHandling = NullValueHandling.Ignore)]
internal float[] Probabilities;
/// <summary>
/// Single logged action (JSON "a"); written when the action/probability arrays are not both available.
/// </summary>
[JsonProperty("a", NullValueHandling = NullValueHandling.Ignore)]
internal int? Action;
/// <summary>
/// Number of elements counted in the event's "_multi" array (JSON "ac"); single-action form only.
/// </summary>
[JsonProperty("ac", NullValueHandling = NullValueHandling.Ignore)]
internal int? ActionCount;
/// <summary>
/// Probability read from "_label_probability" (JSON "pr"); single-action form only.
/// </summary>
[JsonProperty("pr", NullValueHandling = NullValueHandling.Ignore)]
internal float? Probability;
/// <summary>
/// Probabilities ordered 0...n by action: ParseCacheLabels stores the probability of
/// action a at index a - 1. Not serialized.
/// </summary>
[JsonIgnore]
internal float[] ProbabilitiesOrdered;
/// <summary>
/// Value read from "_ProbabilityOfDrop" (JSON "pd"); defaults to 0 when absent from the cache line.
/// </summary>
[JsonProperty("pd", NullValueHandling = NullValueHandling.Ignore)]
internal float? ProbabilityOfDrop = 0f;
/// <summary>
/// Cost read from "_label_cost" (JSON "c").
/// </summary>
[JsonProperty("c")]
internal float Cost;
/// <summary>
/// Zero-based line number within the cache file, assigned by ParseCacheLabels. Not serialized.
/// </summary>
[JsonIgnore]
internal int LineNr;
}
public static class Metrics
{
/// <summary>
/// Lazily reads the label cache file, one JSON-serialized <see cref="Label"/> per line.
/// Labels stored in the single-action form are expanded into the array form, and
/// <see cref="Label.ProbabilitiesOrdered"/> and <see cref="Label.LineNr"/> are filled in.
/// </summary>
/// <param name="labelFile">Path of the cache file written by ExtractAndCacheLabels.</param>
/// <returns>
/// A deferred sequence; the file is opened (and read again) on every enumeration.
/// </returns>
/// <exception cref="InvalidDataException">
/// Thrown during enumeration when a line holds no label or an inconsistent one.
/// </exception>
private static IEnumerable<Label> ParseCacheLabels(string labelFile)
{
    using (var reader = new StreamReader(labelFile))
    {
        string line;
        int lineNr = 0;
        while ((line = reader.ReadLine()) != null)
        {
            var label = JsonConvert.DeserializeObject<Label>(line);
            if (label == null)
                throw new InvalidDataException($"Empty label in line {lineNr} of {labelFile}");

            if (label.Probabilities == null)
            {
                // reconstruct action/probabilities based on action/probDeprecated
                if (label.Action == null || label.ActionCount == null || label.Probability == null)
                    throw new InvalidDataException($"Incomplete label in line {lineNr} of {labelFile}");

                int chosenAction = label.Action.Value;
                int actionCount = label.ActionCount.Value;
                if (chosenAction < 1 || chosenAction > actionCount)
                    throw new InvalidDataException($"Action {chosenAction} out of range 1..{actionCount} in line {lineNr} of {labelFile}");

                label.Probabilities = Enumerable.Repeat(label.Probability.Value, actionCount).ToArray();

                // Chosen action first, followed by the remaining actions 1..actionCount in order.
                label.Actions = new int[actionCount];
                label.Actions[0] = chosenAction;
                for (int i = 1, j = 1; i < actionCount; i++, j++)
                {
                    // skip the chosen action, it already occupies slot 0
                    if (j == chosenAction)
                        j++;
                    label.Actions[i] = j;
                }
            }

            // order probs by action (actions are 1-based, the array is 0-based)
            label.ProbabilitiesOrdered = new float[label.Actions.Length];
            for (int i = 0; i < label.Actions.Length; i++)
                label.ProbabilitiesOrdered[label.Actions[i] - 1] = label.Probabilities[i];

            label.LineNr = lineNr++;
            yield return label;
        }
    }
}
/// <summary>
/// Returns the labels of the events in <paramref name="data"/> (one JSON event per line).
/// On first use the labels are extracted into the cache file "<paramref name="data"/>.labels";
/// afterwards the cache file is read instead.
/// </summary>
/// <param name="data">Path of the JSON event file.</param>
/// <returns>The deferred label sequence produced by ParseCacheLabels for the cache file.</returns>
/// <exception cref="InvalidDataException">
/// An event has no usable label or its "_multi" property is not an array of objects.
/// </exception>
private static IEnumerable<Label> ExtractAndCacheLabels(string data)
{
    var labelFile = data + ".labels";
    if (!File.Exists(labelFile))
    {
        Console.WriteLine($"Building label cache file {labelFile}");
        var jsonSerializer = new JsonSerializer();
        try
        {
            using (var writer = new StreamWriter(labelFile))
            using (var reader = new StreamReader(data))
            {
                string line;
                int lineNr = 0;
                while ((line = reader.ReadLine()) != null)
                {
                    int? action = null;
                    float? cost = null;
                    float[] probabilities = null;
                    int[] actions = null;
                    float? probabilityOfDrop = null;
                    float? probDeprecated = null;
                    int? actionCount = null;

                    using (var jsonReader = new JsonTextReader(new StringReader(line)))
                    {
                        while (jsonReader.Read())
                        {
                            if (jsonReader.TokenType != JsonToken.PropertyName)
                                continue;

                            // switch on string compares ordinally, like the original Equals(..., Ordinal) chain
                            switch ((string)jsonReader.Value)
                            {
                                case "_label_action":
                                    action = jsonReader.ReadAsInt32();
                                    break;
                                case "_label_cost":
                                    // keep null as null so the "Missing label" check below reports it
                                    cost = (float?)jsonReader.ReadAsDouble();
                                    break;
                                case "_label_probability":
                                    probDeprecated = (float?)jsonReader.ReadAsDouble();
                                    break;
                                case "_a":
                                    // move from the property name onto its value before deserializing
                                    jsonReader.Read();
                                    actions = jsonSerializer.Deserialize<int[]>(jsonReader);
                                    break;
                                case "_p":
                                    jsonReader.Read();
                                    probabilities = jsonSerializer.Deserialize<float[]>(jsonReader);
                                    break;
                                case "_ProbabilityOfDrop":
                                    probabilityOfDrop = (float?)jsonReader.ReadAsDouble();
                                    break;
                                case "_multi":
                                    if (!jsonReader.Read() || jsonReader.TokenType != JsonToken.StartArray)
                                    {
                                        throw new InvalidDataException($"Unexpected type for _multi: {jsonReader.TokenType}");
                                    }

                                    // count _multi elements
                                    actionCount = 0;
                                    while (jsonReader.Read() && jsonReader.TokenType == JsonToken.StartObject)
                                    {
                                        jsonReader.Skip();
                                        actionCount++;
                                    }

                                    if (jsonReader.TokenType != JsonToken.EndArray)
                                    {
                                        throw new InvalidDataException($"Unexpected type for _multi: {jsonReader.TokenType}");
                                    }
                                    break;
                            }

                            // Stop scanning once the array-form label is complete.
                            // NOTE(review): properties that appear later in the line (e.g. _ProbabilityOfDrop)
                            // are not picked up after this early exit - confirm the property order of the input.
                            if (actions != null && cost != null && probabilities != null)
                                break;
                        }
                    }

                    if ((actions == null && action == null) || cost == null || (probabilities == null && probDeprecated == null))
                        throw new InvalidDataException("Missing label in line " + lineNr);

                    var label = new Label
                    {
                        Cost = (float)cost,
                        ProbabilityOfDrop = probabilityOfDrop
                    };

                    if (actions == null || probabilities == null)
                    {
                        // single-action form; expanded again by ParseCacheLabels
                        label.Action = action;
                        label.Probability = probDeprecated;
                        label.ActionCount = actionCount;
                    }
                    else
                    {
                        label.Actions = actions;
                        label.Probabilities = probabilities;
                    }

                    writer.WriteLine(JsonConvert.SerializeObject(label));
                    lineNr++;
                }
            }
        }
        catch
        {
            // Do not leave a truncated cache behind: a later run would find the file
            // and silently use the incomplete labels.
            if (File.Exists(labelFile))
                File.Delete(labelFile);
            throw;
        }
    }

    return ParseCacheLabels(labelFile);
}
/// <summary>
/// One line of a prediction file: either a single predicted action ("a") or a list of
/// actions ("as") with matching probabilities ("p").
/// </summary>
internal class Data
{
    /// <summary>Line number of the event the prediction belongs to (JSON "nr").</summary>
    [JsonProperty("nr")]
    public int LineNr { get; set; }

    /// <summary>Predicted actions (JSON "as"); null when only a single action was predicted.</summary>
    [JsonProperty("as")]
    public int[] Actions { get; set; }

    /// <summary>Single predicted action (JSON "a"); used when <see cref="Actions"/> is null.</summary>
    [JsonProperty("a")]
    public int Action { get; set; }

    /// <summary>Probabilities paired position by position with <see cref="Actions"/> (JSON "p").</summary>
    [JsonProperty("p")]
    public float[] Probabilities { get; set; }

    /// <summary>
    /// Computes the loss of this prediction against the logged <paramref name="label"/>.
    /// </summary>
    /// <param name="label">Logged label of the same event; its first action/probability is the logged choice.</param>
    /// <returns>The loss contribution of this event.</returns>
    public float GetLoss(Label label)
    {
        if (Actions == null)
            return VowpalWabbitContextualBanditUtil.GetUnbiasedCost((uint)label.Actions[0], (uint)Action, label.Cost, label.Probabilities[0]);

        // A missing drop probability means "never dropped". Without the parentheses
        // `1 - x ?? 0` parses as `(1 - x) ?? 0` and yields 0 (division by zero below) for null.
        float keepProbability = 1 - (label.ProbabilityOfDrop ?? 0f);

        var pairs = Actions
            .Zip(Probabilities, (action, prob) => new { Action = action, Prob = prob })
            .ToList();

        // predicted actions are shifted by one before being compared with the logged action
        var c = pairs.Sum(ap => ap.Prob * VowpalWabbitContextualBanditUtil.GetUnbiasedCost((uint)label.Actions[0], (uint)ap.Action + 1, label.Cost, label.Probabilities[0]));

        // NOTE(review): label.Probabilities is ordered like label.Actions, while ap.Action is an action
        // index; label.ProbabilitiesOrdered (indexed by action) may be what was intended here - confirm.
        var p = pairs.Sum(ap => ap.Prob / (label.Probabilities[ap.Action] * keepProbability));

        // SUM(cost) / SUM(1 / prob)
        return c / (1f / p);
    }
}
/// <summary>
/// Prints, for every prediction file, the average loss of its predictions against the labels
/// of <paramref name="dataFile"/>. Prediction files are evaluated in parallel.
/// </summary>
/// <param name="dataFile">JSON event file the labels are extracted from (cached in "dataFile.labels").</param>
/// <param name="predictions">
/// Prediction files with one JSON object per line; line i is paired with label i.
/// </param>
public static void Compute(string dataFile, params string[] predictions)
{
    // TODO: labels are parsed multiple times
    // (labelSource is a deferred sequence, so every prediction file re-reads the label cache)
    var labelSource = ExtractAndCacheLabels(dataFile);

    Parallel.ForEach(predictions, pred =>
    {
        var loss = File.ReadLines(pred)
            .Select(line => JsonConvert.DeserializeObject<Data>(line))
            .Zip(labelSource, (data, label) => data.GetLoss(label))
            .Average();

        // ",-30" left-aligns the file name in a 30 character column;
        // the former ":-30" was a format specifier, which is ignored for strings.
        Console.WriteLine($"{pred,-30}: {loss}");
    });
}
}
}

Просмотреть файл

@ -1,4 +1,5 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
@ -25,11 +26,11 @@ namespace Experimentation
// internal ActionScore[] Prediction;
//}
public static void Train(string arguments, string inputFile)
public static void Train(string arguments, string inputFile, string predictionFile = null, TimeSpan? reloadInterval = null)
{
using (var reader = new StreamReader(inputFile))
using (var prediction = new StreamWriter(inputFile + ".prediction"))
using (var vw = new VowpalWabbitJson(new VowpalWabbitSettings(arguments)
using (var prediction = new StreamWriter(predictionFile ?? inputFile + ".prediction"))
using (var vw = new VowpalWabbit(new VowpalWabbitSettings(arguments)
{
Verbose = true
}))
@ -37,13 +38,54 @@ namespace Experimentation
string line;
int lineNr = 0;
int invalidExamples = 0;
DateTime? lastTimestamp = null;
while ((line = reader.ReadLine()) != null)
{
try
{
var pred = vw.Learn(line, VowpalWabbitPredictionType.ActionScore);
prediction.WriteLine(lineNr + " " + string.Join(",", pred.Select(a_s => $"{a_s.Action}:{a_s.Score}")));
bool reload = false;
using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
{
if (reloadInterval != null)
{
jsonSerializer.RegisterExtension((state, property) =>
{
if (property.Equals("_timestamp", StringComparison.Ordinal))
{
var eventTimestamp = state.Reader.ReadAsDateTime();
if (lastTimestamp == null)
lastTimestamp = eventTimestamp;
else if (lastTimestamp + reloadInterval < eventTimestamp)
{
reload = true;
lastTimestamp = eventTimestamp;
}
return true;
}
return false;
});
}
// var pred = vw.Learn(line, VowpalWabbitPredictionType.ActionScore);
using (var example = jsonSerializer.ParseAndCreate(line))
{
var pred = example.Learn(VowpalWabbitPredictionType.ActionScore);
prediction.WriteLine(JsonConvert.SerializeObject(
new
{
nr = lineNr,
@as = pred.Select(x => x.Action),
p = pred.Select(x => x.Score)
}));
}
if (reload)
vw.Reload();
}
}
catch (Exception)
{

Просмотреть файл

@ -0,0 +1,72 @@
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using VW;
using VW.Serializer;
namespace Experimentation
{
/// <summary>
/// Converts VowpalWabbit JSON examples (one per line) into VowpalWabbit string format.
/// </summary>
public static class VowpalWabbitJsonToString
{
    /// <summary>
    /// Reads JSON examples from <paramref name="reader"/> and writes the string form of each
    /// example as a line to <paramref name="writer"/>. The VowpalWabbit arguments are chosen from the first
    /// line: "--cb_explore_adf" if it has a "_multi" property, "--cb_explore" otherwise.
    /// Returns only after every converted line has been handed to <paramref name="writer"/>.
    /// </summary>
    /// <param name="reader">Source of the JSON lines; nothing is written if it is empty.</param>
    /// <param name="writer">Destination of the converted examples.</param>
    /// <exception cref="AggregateException">
    /// Wraps failures of the pipeline, e.g. the <see cref="InvalidDataException"/> raised for a line
    /// that yields no example.
    /// </exception>
    public static void Convert(StreamReader reader, StreamWriter writer)
    {
        var line = reader.ReadLine();
        if (line == null)
            return;

        // peek at the first example to decide which mode the input uses
        var jExample = JObject.Parse(line);
        var settings = jExample.Properties().Any(p => p.Name == "_multi") ? "--cb_explore_adf" : "--cb_explore";
        int lineNr = 1;

        using (var vw = new VowpalWabbit(new VowpalWabbitSettings(settings) {
            EnableStringExampleGeneration = true,
            EnableStringFloatCompact = true,
            EnableThreadSafeExamplePooling = true
        }))
        {
            var serializeBlock = new TransformBlock<Tuple<string, int>, string>(l =>
            {
                using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
                using (var example = jsonSerializer.ParseAndCreate(l.Item1))
                {
                    if (example == null)
                        throw new InvalidDataException($"Invalid example in line {l.Item2}: '{l.Item1}'");

                    var str = example.VowpalWabbitString;
                    // multi-line examples get an extra newline appended (an empty line after WriteLine)
                    if (example is VowpalWabbitMultiLineExampleCollection)
                        str += "\n";

                    return str;
                }
            },
            new ExecutionDataflowBlockOptions
            {
                BoundedCapacity = 1024,
                MaxDegreeOfParallelism = 8
            });

            // single writer so lines are never interleaved
            var writeBlock = new ActionBlock<string>(
                l => writer.WriteLine(l),
                new ExecutionDataflowBlockOptions { MaxDegreeOfParallelism = 1, BoundedCapacity = 128 });

            serializeBlock.LinkTo(writeBlock, new DataflowLinkOptions { PropagateCompletion = true });

            var input = serializeBlock.AsObserver();
            do
            {
                input.OnNext(Tuple.Create(line, lineNr));
                lineNr++;
            } while ((line = reader.ReadLine()) != null);

            input.OnCompleted();

            // Wait for the LAST stage: completion (and faults) propagate from serializeBlock to
            // writeBlock, and writeBlock may still hold queued lines after serializeBlock is done.
            // Returning earlier would let the caller dispose the writer while lines are pending.
            writeBlock.Completion.Wait();
        }
    }
}
}

Просмотреть файл

@ -8,8 +8,14 @@
<package id="Microsoft.Tpl.Dataflow" version="4.5.24" targetFramework="net452" />
<package id="Newtonsoft.Json" version="8.0.2" targetFramework="net452" />
<package id="OptimizedPriorityQueue" version="2.0.0" targetFramework="net452" />
<package id="System.Reactive" version="3.0.0" targetFramework="net452" />
<package id="System.Reactive.Core" version="3.0.0" targetFramework="net452" />
<package id="System.Reactive.Interfaces" version="3.0.0" targetFramework="net452" />
<package id="System.Reactive.Linq" version="3.0.0" targetFramework="net452" />
<package id="System.Reactive.PlatformServices" version="3.0.0" targetFramework="net452" />
<package id="System.Reactive.Windows.Threading" version="3.0.0" targetFramework="net452" />
<package id="System.Spatial" version="5.6.4" targetFramework="net452" />
<package id="VowpalWabbit" version="8.2.0.13" targetFramework="net452" />
<package id="VowpalWabbit.JSON" version="8.2.0.13" targetFramework="net452" />
<package id="VowpalWabbit" version="8.2.0.16" targetFramework="net452" />
<package id="VowpalWabbit.JSON" version="8.2.0.16" targetFramework="net452" />
<package id="WindowsAzure.Storage" version="7.2.0" targetFramework="net452" />
</packages>

Просмотреть файл

@ -9,6 +9,10 @@
<assemblyIdentity name="Newtonsoft.Json" publicKeyToken="30ad4fe6b2a6aeed" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-8.0.0.0" newVersion="8.0.0.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="VowpalWabbit.Common" publicKeyToken="a76afd1645210483" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-8.2.0.14" newVersion="8.2.0.14" />
</dependentAssembly>
</assemblyBinding>
</runtime>
</configuration>

Просмотреть файл

@ -65,6 +65,14 @@
</Reference>
<Reference Include="System" />
<Reference Include="System.Core" />
<Reference Include="System.Reactive.Core, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reactive.Core.3.0.0\lib\net45\System.Reactive.Core.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System.Reactive.Interfaces, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reactive.Interfaces.3.0.0\lib\net45\System.Reactive.Interfaces.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="System.Spatial, Version=5.6.4.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35, processorArchitecture=MSIL">
<HintPath>..\packages\System.Spatial.5.6.4\lib\net40\System.Spatial.dll</HintPath>
<Private>True</Private>
@ -79,20 +87,20 @@
<Reference Include="System.Data" />
<Reference Include="System.Net.Http" />
<Reference Include="System.Xml" />
<Reference Include="VowpalWabbit, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.dll</HintPath>
<Reference Include="VowpalWabbit, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="VowpalWabbit.Common, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Common.dll</HintPath>
<Reference Include="VowpalWabbit.Common, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Common.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="VowpalWabbit.Core, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.13\lib\net45\VowpalWabbit.Core.dll</HintPath>
<Reference Include="VowpalWabbit.Core, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.8.2.0.16\lib\net45\VowpalWabbit.Core.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.13, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.13\lib\net45\VowpalWabbit.JSON.dll</HintPath>
<Reference Include="VowpalWabbit.JSON, Version=8.2.0.16, Culture=neutral, PublicKeyToken=a76afd1645210483, processorArchitecture=AMD64">
<HintPath>..\packages\VowpalWabbit.JSON.8.2.0.16\lib\net45\VowpalWabbit.JSON.dll</HintPath>
<Private>True</Private>
</Reference>
</ItemGroup>
@ -111,12 +119,12 @@
</ProjectReference>
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<Import Project="..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" />
<Import Project="..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets" Condition="Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" />
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.13\build\VowpalWabbit.targets'))" />
<Error Condition="!Exists('..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\packages\VowpalWabbit.8.2.0.16\build\VowpalWabbit.targets'))" />
</Target>
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.

Просмотреть файл

@ -1,12 +1,19 @@
using Experimentation;
using Microsoft.WindowsAzure.Storage;
using Microsoft.WindowsAzure.Storage.Auth;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using VW;
using VW.Serializer;
namespace ExperimentationConsole
{
@ -14,16 +21,65 @@ namespace ExperimentationConsole
{
static void Main(string[] args)
{
var storageAccount = new CloudStorageAccount(new StorageCredentials("", ""), false);
var outputDirectory = @"c:\temp\abc";
var startTimeInclusive = new DateTime(2016, 8, 11, 0, 0, 0);
var endTimeExclusive = new DateTime(2016, 8, 14, 0, 0, 0);
using (var writer = new StreamWriter(Path.Combine(outputDirectory, $"{startTimeInclusive:yyyy-MM-dd_HH}-{endTimeExclusive:yyyy-MM-dd_HH}.json")))
try
{
AzureBlobDownloader.Download(storageAccount, startTimeInclusive, endTimeExclusive, writer, outputDirectory).Wait();
var stopwatch = Stopwatch.StartNew();
var storageAccount = new CloudStorageAccount(new StorageCredentials("storage name", "storage key"), false);
var outputDirectory = @"c:\temp\";
var startTimeInclusive = new DateTime(2016, 8, 11, 0, 0, 0);
var endTimeExclusive = new DateTime(2016, 8, 14, 0, 0, 0);
var outputFile = Path.Combine(outputDirectory, $"{startTimeInclusive:yyyy-MM-dd_HH}-{endTimeExclusive:yyyy-MM-dd_HH}.json");
// download and merge blob data
using (var writer = new StreamWriter(outputFile))
{
AzureBlobDownloader.Download(storageAccount, startTimeInclusive, endTimeExclusive, writer, outputDirectory).Wait();
}
// pre-process JSON
JsonTransform.TransformIgnoreProperties(outputFile, outputFile + ".small",
"Somefeatures");
outputFile += ".small";
// filter broken events for complex
JsonTransform.Transform(outputFile, outputFile + ".fixed", (reader, writer) =>
{
var serializer = JsonSerializer.CreateDefault();
var obj = (JObject)serializer.Deserialize(reader);
var multi = (JArray)obj.SelectToken("$._multi");
if (multi.Count == 10)
serializer.Serialize(writer, obj);
return true;
});
outputFile += ".fixed";
using (var reader = new StreamReader(outputFile))
using (var writer = new StreamWriter(outputFile + ".vw"))
{
VowpalWabbitJsonToString.Convert(reader, writer);
}
// VW training
OfflineTrainer.Train("--cb_explore_adf --epsilon 0.05 -q AB -q UD",
outputFile,
predictionFile: outputFile + ".2h.prediction",
reloadInterval: TimeSpan.FromHours(2));
Metrics.Compute(outputFile,
outputFile + ".prediction",
outputFile + ".2h.prediction");
Console.WriteLine("\ndone " + stopwatch.Elapsed);
}
catch (Exception ex)
{
Console.WriteLine($"Exception: {ex.Message}. {ex.StackTrace}");
}
Console.ReadKey();
}
}
}

Просмотреть файл

@ -7,8 +7,10 @@
<package id="Microsoft.Tpl.Dataflow" version="4.5.24" targetFramework="net452" />
<package id="Newtonsoft.Json" version="8.0.2" targetFramework="net452" />
<package id="OptimizedPriorityQueue" version="2.0.0" targetFramework="net452" />
<package id="System.Reactive.Core" version="3.0.0" targetFramework="net452" />
<package id="System.Reactive.Interfaces" version="3.0.0" targetFramework="net452" />
<package id="System.Spatial" version="5.6.4" targetFramework="net452" />
<package id="VowpalWabbit" version="8.2.0.13" targetFramework="net452" />
<package id="VowpalWabbit.JSON" version="8.2.0.13" targetFramework="net452" />
<package id="VowpalWabbit" version="8.2.0.16" targetFramework="net452" />
<package id="VowpalWabbit.JSON" version="8.2.0.16" targetFramework="net452" />
<package id="WindowsAzure.Storage" version="7.2.0" targetFramework="net452" />
</packages>