This commit is contained in:
Alex Polozov 2017-05-04 12:16:52 -07:00
Родитель dbcb8592f8
Коммит d4ed04d875
4 изменённых файлов: 162 добавлений и 39 удалений

119
ProseTutorial/DSLSession.cs Normal file
Просмотреть файл

@ -0,0 +1,119 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ProgramSynthesis;
using Microsoft.ProgramSynthesis.AST;
using Microsoft.ProgramSynthesis.Extraction.Text.Semantics;
using Microsoft.ProgramSynthesis.Learning;
using Microsoft.ProgramSynthesis.Learning.Logging;
using Microsoft.ProgramSynthesis.Learning.Strategies;
using Microsoft.ProgramSynthesis.Specifications;
using Microsoft.ProgramSynthesis.VersionSpace;
using Microsoft.ProgramSynthesis.Wrangling.Constraints;
using Microsoft.ProgramSynthesis.Wrangling.Logging;
using Microsoft.ProgramSynthesis.Wrangling.Session;
using Newtonsoft.Json;
namespace ProseTutorial {
public class DSLProgram : Program<StringRegion, StringRegion> {
public DSLProgram(ProgramNode programNode, double score) : base(programNode, score) { }
public override StringRegion Run(StringRegion input) => ProgramNode.Invoke(
State.Create(Substrings.Language.Grammar.InputSymbol, input)) as StringRegion;
}
public class DSLProgramSetWrapper : Session<DSLProgram, StringRegion, StringRegion>.IProgramSetWrapper {
public DSLProgramSetWrapper(ProgramSet programSet, Feature<double> score) {
ProgramSet = programSet;
Score = score;
}
public IEnumerable<DSLProgram> RealizedPrograms => ProgramSet.RealizedPrograms.Select(
p => new DSLProgram(
p, p.GetFeatureValue(Score)));
public ProgramSet ProgramSet { get; }
public Feature<double> Score { get; }
}
public class DSLSession : Session<DSLProgram, StringRegion, StringRegion> {
private List<DSLProgram> _lastTopK;
private DSLProgramSetWrapper _lastSet;
public static Grammar Grammar => Substrings.Language.Grammar;
public Feature<double> RankingScore { get; } = new Substrings.RankingScore(Grammar);
public DomainLearningLogic LearningLogic { get; } = new Substrings.WitnessFunctions(Grammar);
public DSLSession(IJournalStorage journalStorage = null,
CultureInfo culture = null,
ILogger logger = null) : base(
journalStorage, culture ?? CultureInfo.InvariantCulture, logger) { }
private SynthesisEngine CreateEngine(bool log = false) =>
new SynthesisEngine(Grammar,
new SynthesisEngine.Config {
UseThreads = false,
Strategies = new ISynthesisStrategy[] {
new EnumerativeSynthesis(),
new DeductiveSynthesis(LearningLogic),
},
LogListener = log ? new LogListener() : null,
});
private static ExampleSpec CreateSpec(LearnProgramRequest<DSLProgram, StringRegion, StringRegion> request) {
var examples = request.Constraints.OfType<Example<StringRegion, StringRegion>>();
var spec = new ExampleSpec(examples.ToDictionary(e => State.Create(Grammar.InputSymbol, e.Input),
e => (object) e.Output));
return spec;
}
protected override IReadOnlyList<DSLProgram> LearnTopKCached(
LearnProgramRequest<DSLProgram, StringRegion, StringRegion> request, RankingMode rankingMode, int k,
CancellationToken cancel) {
var engine = CreateEngine();
var spec = CreateSpec(request);
var learned = engine.LearnGrammarTopK(spec, RankingScore, k, cancel);
_lastTopK = learned.Select(p => new DSLProgram(p, p.GetFeatureValue(RankingScore))).ToList();
return _lastTopK;
}
protected override IProgramSetWrapper LearnAllCached(
LearnProgramRequest<DSLProgram, StringRegion, StringRegion> request, CancellationToken cancel) {
var engine = CreateEngine();
var spec = CreateSpec(request);
var learned = engine.LearnGrammar(spec, cancel);
_lastSet = new DSLProgramSetWrapper(learned, RankingScore);
return _lastSet;
}
public override Task<IReadOnlyList<IQuestion>> GetTopKQuestionsAsync(
int? k = null, double? confidenceThreshold = null, IEnumerable<Type> allowedTypes = null,
CancellationToken cancel = new CancellationToken()) {
throw new NotImplementedException();
}
public override Task<IReadOnlyList<SignificantInputCluster<StringRegion>>> GetSignificantInputClustersAsync(
double? confidenceThreshold = null,
CancellationToken cancel = new CancellationToken()) {
throw new NotImplementedException();
}
public override Task<IReadOnlyList<SignificantInput<StringRegion>>> GetSignificantInputsAsync(
double? confidenceThreshold = null, CancellationToken cancel = new CancellationToken()) {
throw new NotImplementedException();
}
public override Task<IReadOnlyList<StringRegion>> ComputeTopKOutputsAsync(
StringRegion input, int k, RankingMode rankingMode = null,
double? confidenceThreshold = null, CancellationToken cancel = new CancellationToken()) {
throw new NotImplementedException();
}
protected override JsonSerializerSettings JsonSerializerSettingsInstance { get; } =
new JsonSerializerSettings();
protected override string LoggingTypeName => "Sample";
}
}

Просмотреть файл

@ -9,24 +9,23 @@ using Microsoft.ProgramSynthesis.Extraction.Text.Semantics;
using Microsoft.ProgramSynthesis.Specifications;
using Microsoft.ProgramSynthesis.Specifications.Extensions;
using Microsoft.ProgramSynthesis.Utils;
using Microsoft.ProgramSynthesis.VersionSpace;
using Microsoft.ProgramSynthesis.Wrangling.Constraints;
using Microsoft.ProgramSynthesis.Wrangling.Session;
using static ProseTutorial.Utils;
namespace ProseTutorial
{
internal static class Program
{
private static void Main(string[] args)
{
//LoadAndTestSubstrings();
LoadAndTestTextExtraction();
namespace ProseTutorial {
internal static partial class Program {
private static void Main(string[] args) {
LoadAndTestSubstrings();
//LoadAndTestTextExtraction();
}
private static void LoadAndTestSubstrings()
{
var grammar = ProseTutorial.Substrings.Language.Grammar;
private static void LoadAndTestSubstrings() {
var grammar = Substrings.Language.Grammar;
if (grammar == null) return;
ProgramNode p = ProgramNode.Parse(@"SubStr(v, PosPair(AbsPos(v, -4), AbsPos(v, -1)))",
/*ProgramNode p = ProgramNode.Parse(@"SubStr(v, PosPair(AbsPos(v, -4), AbsPos(v, -1)))",
grammar, ASTSerializationFormat.HumanReadable);
StringRegion data = RegionSession.CreateStringRegion("Microsoft PROSE SDK");
State input = State.Create(grammar.InputSymbol, data);
@ -35,48 +34,47 @@ namespace ProseTutorial
StringRegion sdk = data.Slice(data.End - 3, data.End);
Spec spec = ShouldConvert.Given(grammar).To(data, sdk);
Learn(grammar, spec,
new Substrings.RankingScore(grammar), new Substrings.WitnessFunctions(grammar));
new Substrings.RankingScore(grammar), new Substrings.WitnessFunctions(grammar));*/
TestFlashFillBenchmark(grammar, "emails");
TestFlashFillBenchmark("emails");
}
private static void TestFlashFillBenchmark(Grammar grammar, string benchmark, int exampleCount = 2)
{
private static void TestFlashFillBenchmark(string benchmark, int exampleCount = 2) {
string[] lines = File.ReadAllLines($"benchmarks/{benchmark}.tsv");
Tuple<string, string>[] data = lines.Select(l =>
{
(string, string)[] data = lines.Select(l => {
var parts = l.Split(new[] { "\t" }, StringSplitOptions.RemoveEmptyEntries);
return Tuple.Create(parts[0], parts[1]);
return (parts[0], parts[1]);
}).ToArray();
var examples =
data.Take(exampleCount)
.ToDictionary(
t => State.Create(grammar.InputSymbol, RegionSession.CreateStringRegion(t.Item1)),
t => (object)RegionSession.CreateStringRegion(t.Item2));
var spec = new ExampleSpec(examples);
ProgramNode program = Learn(grammar, spec,
new Substrings.RankingScore(grammar),
new Substrings.WitnessFunctions(grammar));
foreach (Tuple<string, string> row in data.Skip(exampleCount))
{
State input = State.Create(grammar.InputSymbol,
RegionSession.CreateStringRegion(row.Item1));
var output = program.Invoke(input);
WriteColored(ConsoleColor.DarkCyan, $"{row.Item1} => {output}");
using (var session = new DSLSession()) {
var examples =
data.Take(exampleCount)
.Select(t => new Example<StringRegion, StringRegion>(
RegionSession.CreateStringRegion(t.Item1),
RegionSession.CreateStringRegion(t.Item2)));
session.AddConstraints(examples);
var program = session.LearnTopK(1)[0];
/*var spec = new ExampleSpec(examples);
ProgramNode program = Learn(grammar, spec,
new Substrings.RankingScore(grammar),
new Substrings.WitnessFunctions(grammar));*/
foreach ((string, string) row in data.Skip(exampleCount)) {
var output = program.Run(RegionSession.CreateStringRegion(row.Item1));
WriteColored(ConsoleColor.DarkCyan, $"{row.Item1} => {output}");
}
}
}
private static void LoadAndTestTextExtraction()
{
var grammar = ProseTutorial.TextExtraction.Language.Grammar;
private static void LoadAndTestTextExtraction() {
var grammar = TextExtraction.Language.Grammar;
if (grammar == null) return;
TestTextExtractionBenchmark(grammar, "countries");
TestTextExtractionBenchmark(grammar, "popl13-erc");
}
private static void TestTextExtractionBenchmark(Grammar grammar, string benchmark)
{
private static void TestTextExtractionBenchmark(Grammar grammar, string benchmark) {
StringRegion document;
List<StringRegion> examples = LoadBenchmark($"benchmarks/{benchmark}.txt", out document);
var input = State.Create(grammar.InputSymbol, document);

Просмотреть файл

@ -186,9 +186,13 @@
<Reference Include="System.Interactive, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
<HintPath>..\packages\System.Interactive.3.0.0\lib\net45\System.Interactive.dll</HintPath>
</Reference>
<Reference Include="System.Numerics" />
<Reference Include="System.Reflection.Metadata, Version=1.3.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<HintPath>..\packages\System.Reflection.Metadata.1.3.0\lib\portable-net45+win8\System.Reflection.Metadata.dll</HintPath>
</Reference>
<Reference Include="System.ValueTuple, Version=4.0.1.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=MSIL">
<HintPath>..\packages\System.ValueTuple.4.3.0\lib\netstandard1.0\System.ValueTuple.dll</HintPath>
</Reference>
<Reference Include="System.Xml.Linq" />
<Reference Include="System.Data.DataSetExtensions" />
<Reference Include="Microsoft.CSharp" />
@ -197,6 +201,7 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="DSLSession.cs" />
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Utils.cs" />

Просмотреть файл

@ -24,4 +24,5 @@
<package id="System.Runtime" version="4.3.0" targetFramework="net45" />
<package id="System.Runtime.Extensions" version="4.3.0" targetFramework="net45" />
<package id="System.Threading" version="4.3.0" targetFramework="net45" />
<package id="System.ValueTuple" version="4.3.0" targetFramework="net45" />
</packages>