#10: Session
This commit is contained in:
Родитель
dbcb8592f8
Коммит
d4ed04d875
|
@ -0,0 +1,119 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Globalization;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.ProgramSynthesis;
|
||||
using Microsoft.ProgramSynthesis.AST;
|
||||
using Microsoft.ProgramSynthesis.Extraction.Text.Semantics;
|
||||
using Microsoft.ProgramSynthesis.Learning;
|
||||
using Microsoft.ProgramSynthesis.Learning.Logging;
|
||||
using Microsoft.ProgramSynthesis.Learning.Strategies;
|
||||
using Microsoft.ProgramSynthesis.Specifications;
|
||||
using Microsoft.ProgramSynthesis.VersionSpace;
|
||||
using Microsoft.ProgramSynthesis.Wrangling.Constraints;
|
||||
using Microsoft.ProgramSynthesis.Wrangling.Logging;
|
||||
using Microsoft.ProgramSynthesis.Wrangling.Session;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace ProseTutorial {
|
||||
/// <summary>
/// A scored program over the <c>Substrings</c> grammar that maps one
/// <see cref="StringRegion"/> to another.
/// </summary>
public class DSLProgram : Program<StringRegion, StringRegion> {
    /// <summary>
    /// Forwards the program node and its score to the base <c>Program</c> constructor.
    /// </summary>
    /// <param name="programNode">The program node this instance executes.</param>
    /// <param name="score">The score associated with <paramref name="programNode"/>.</param>
    public DSLProgram(ProgramNode programNode, double score) : base(programNode, score) { }

    /// <summary>
    /// Invokes the wrapped program node on a state that binds the grammar's input symbol
    /// to <paramref name="input"/>.
    /// </summary>
    /// <param name="input">The region to run the program on.</param>
    /// <returns>
    /// The invocation result as a <see cref="StringRegion"/>; <c>null</c> when the result
    /// is not a <see cref="StringRegion"/> (the conversion uses <c>as</c>).
    /// </returns>
    public override StringRegion Run(StringRegion input) {
        var inputState = State.Create(Substrings.Language.Grammar.InputSymbol, input);
        var result = ProgramNode.Invoke(inputState);
        return result as StringRegion;
    }
}
|
||||
|
||||
/// <summary>
/// Pairs a <see cref="ProgramSet"/> with the score feature used to turn each of its
/// realized program nodes into a <see cref="DSLProgram"/>.
/// </summary>
public class DSLProgramSetWrapper : Session<DSLProgram, StringRegion, StringRegion>.IProgramSetWrapper {
    /// <summary>
    /// Stores the given program set and score feature; neither argument is validated here.
    /// </summary>
    /// <param name="programSet">The program set to wrap.</param>
    /// <param name="score">The feature evaluated on each program node to obtain its score.</param>
    public DSLProgramSetWrapper(ProgramSet programSet, Feature<double> score) {
        ProgramSet = programSet;
        Score = score;
    }

    /// <summary>The wrapped program set.</summary>
    public ProgramSet ProgramSet { get; }

    /// <summary>The feature whose value becomes each realized program's score.</summary>
    public Feature<double> Score { get; }

    /// <summary>
    /// Projects every realized program node of <see cref="ProgramSet"/> into a
    /// <see cref="DSLProgram"/> scored with <see cref="Score"/>. The projection is
    /// re-evaluated on each enumeration.
    /// </summary>
    public IEnumerable<DSLProgram> RealizedPrograms =>
        from node in ProgramSet.RealizedPrograms
        select new DSLProgram(node, node.GetFeatureValue(Score));
}
|
||||
|
||||
/// <summary>
/// A learning session for the <c>Substrings</c> grammar: it converts the example
/// constraints of a learn request into an <see cref="ExampleSpec"/> and runs a
/// <see cref="SynthesisEngine"/> over it.
/// </summary>
public class DSLSession : Session<DSLProgram, StringRegion, StringRegion> {
    // Results of the most recent LearnTopKCached / LearnAllCached calls.
    private List<DSLProgram> _lastTopK;
    private DSLProgramSetWrapper _lastSet;

    /// <summary>The grammar every session learns over.</summary>
    public static Grammar Grammar => Substrings.Language.Grammar;

    /// <summary>The feature used to score and rank learned programs.</summary>
    public Feature<double> RankingScore { get; } = new Substrings.RankingScore(Grammar);

    /// <summary>The learning logic handed to the deductive synthesis strategy.</summary>
    public DomainLearningLogic LearningLogic { get; } = new Substrings.WitnessFunctions(Grammar);

    /// <summary>
    /// Creates a session, substituting <see cref="CultureInfo.InvariantCulture"/> when
    /// <paramref name="culture"/> is <c>null</c>.
    /// </summary>
    /// <param name="journalStorage">Passed through to the base session unchanged.</param>
    /// <param name="culture">Culture for the base session; <c>null</c> selects the invariant culture.</param>
    /// <param name="logger">Passed through to the base session unchanged.</param>
    public DSLSession(
        IJournalStorage journalStorage = null,
        CultureInfo culture = null,
        ILogger logger = null)
        : base(journalStorage, culture ?? CultureInfo.InvariantCulture, logger) { }

    /// <summary>
    /// Builds a single-threaded engine over <see cref="Grammar"/> that uses enumerative
    /// synthesis and deductive synthesis driven by <see cref="LearningLogic"/>.
    /// </summary>
    /// <param name="log">When <c>true</c>, a new <see cref="LogListener"/> is attached; otherwise none.</param>
    private SynthesisEngine CreateEngine(bool log = false) {
        var grammar = Grammar;
        var config = new SynthesisEngine.Config {
            UseThreads = false,
            Strategies = new ISynthesisStrategy[] {
                new EnumerativeSynthesis(),
                new DeductiveSynthesis(LearningLogic),
            },
            LogListener = log ? new LogListener() : null,
        };
        return new SynthesisEngine(grammar, config);
    }

    /// <summary>
    /// Collects the input/output examples among the request's constraints into an
    /// <see cref="ExampleSpec"/> keyed by the input state. Constraints of other types
    /// are ignored.
    /// </summary>
    /// <param name="request">The learn request whose constraints are read.</param>
    private static ExampleSpec CreateSpec(LearnProgramRequest<DSLProgram, StringRegion, StringRegion> request) {
        var outputsByInputState = request.Constraints
            .OfType<Example<StringRegion, StringRegion>>()
            .ToDictionary(
                example => State.Create(Grammar.InputSymbol, example.Input),
                example => (object) example.Output);
        return new ExampleSpec(outputsByInputState);
    }

    /// <summary>
    /// Learns up to <paramref name="k"/> top-ranked programs for the request's examples
    /// and remembers the resulting list.
    /// </summary>
    /// <param name="request">The request supplying the example constraints.</param>
    /// <param name="rankingMode">Not used by this implementation.</param>
    /// <param name="k">Number of top programs requested from the engine.</param>
    /// <param name="cancel">Cancellation token forwarded to the engine.</param>
    /// <returns>The learned programs, each scored with <see cref="RankingScore"/>.</returns>
    protected override IReadOnlyList<DSLProgram> LearnTopKCached(
        LearnProgramRequest<DSLProgram, StringRegion, StringRegion> request,
        RankingMode rankingMode,
        int k,
        CancellationToken cancel) {
        var synthesizer = CreateEngine();
        var exampleSpec = CreateSpec(request);
        var topPrograms = synthesizer.LearnGrammarTopK(exampleSpec, RankingScore, k, cancel);
        _lastTopK = (from node in topPrograms
                     select new DSLProgram(node, node.GetFeatureValue(RankingScore))).ToList();
        return _lastTopK;
    }

    /// <summary>
    /// Learns the full program set for the request's examples, wraps it together with
    /// <see cref="RankingScore"/>, and remembers the wrapper.
    /// </summary>
    /// <param name="request">The request supplying the example constraints.</param>
    /// <param name="cancel">Cancellation token forwarded to the engine.</param>
    protected override IProgramSetWrapper LearnAllCached(
        LearnProgramRequest<DSLProgram, StringRegion, StringRegion> request,
        CancellationToken cancel) {
        var synthesizer = CreateEngine();
        var exampleSpec = CreateSpec(request);
        var allPrograms = synthesizer.LearnGrammar(exampleSpec, cancel);
        _lastSet = new DSLProgramSetWrapper(allPrograms, RankingScore);
        return _lastSet;
    }

    /// <summary>Not implemented by this session; always throws.</summary>
    /// <exception cref="NotImplementedException">Thrown on every call.</exception>
    public override Task<IReadOnlyList<IQuestion>> GetTopKQuestionsAsync(
        int? k = null,
        double? confidenceThreshold = null,
        IEnumerable<Type> allowedTypes = null,
        CancellationToken cancel = default(CancellationToken)) {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented by this session; always throws.</summary>
    /// <exception cref="NotImplementedException">Thrown on every call.</exception>
    public override Task<IReadOnlyList<SignificantInputCluster<StringRegion>>> GetSignificantInputClustersAsync(
        double? confidenceThreshold = null,
        CancellationToken cancel = default(CancellationToken)) {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented by this session; always throws.</summary>
    /// <exception cref="NotImplementedException">Thrown on every call.</exception>
    public override Task<IReadOnlyList<SignificantInput<StringRegion>>> GetSignificantInputsAsync(
        double? confidenceThreshold = null,
        CancellationToken cancel = default(CancellationToken)) {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented by this session; always throws.</summary>
    /// <exception cref="NotImplementedException">Thrown on every call.</exception>
    public override Task<IReadOnlyList<StringRegion>> ComputeTopKOutputsAsync(
        StringRegion input,
        int k,
        RankingMode rankingMode = null,
        double? confidenceThreshold = null,
        CancellationToken cancel = default(CancellationToken)) {
        throw new NotImplementedException();
    }

    /// <summary>Default-constructed JSON serializer settings, created once per session.</summary>
    protected override JsonSerializerSettings JsonSerializerSettingsInstance { get; } =
        new JsonSerializerSettings();

    /// <summary>The type name this session reports for logging.</summary>
    protected override string LoggingTypeName {
        get { return "Sample"; }
    }
}
|
||||
}
|
|
@ -9,24 +9,23 @@ using Microsoft.ProgramSynthesis.Extraction.Text.Semantics;
|
|||
using Microsoft.ProgramSynthesis.Specifications;
|
||||
using Microsoft.ProgramSynthesis.Specifications.Extensions;
|
||||
using Microsoft.ProgramSynthesis.Utils;
|
||||
using Microsoft.ProgramSynthesis.VersionSpace;
|
||||
using Microsoft.ProgramSynthesis.Wrangling.Constraints;
|
||||
using Microsoft.ProgramSynthesis.Wrangling.Session;
|
||||
using static ProseTutorial.Utils;
|
||||
|
||||
namespace ProseTutorial
|
||||
{
|
||||
internal static class Program
|
||||
{
|
||||
private static void Main(string[] args)
|
||||
{
|
||||
//LoadAndTestSubstrings();
|
||||
LoadAndTestTextExtraction();
|
||||
namespace ProseTutorial {
|
||||
internal static partial class Program {
|
||||
private static void Main(string[] args) {
|
||||
LoadAndTestSubstrings();
|
||||
//LoadAndTestTextExtraction();
|
||||
}
|
||||
|
||||
private static void LoadAndTestSubstrings()
|
||||
{
|
||||
var grammar = ProseTutorial.Substrings.Language.Grammar;
|
||||
private static void LoadAndTestSubstrings() {
|
||||
var grammar = Substrings.Language.Grammar;
|
||||
if (grammar == null) return;
|
||||
|
||||
ProgramNode p = ProgramNode.Parse(@"SubStr(v, PosPair(AbsPos(v, -4), AbsPos(v, -1)))",
|
||||
/*ProgramNode p = ProgramNode.Parse(@"SubStr(v, PosPair(AbsPos(v, -4), AbsPos(v, -1)))",
|
||||
grammar, ASTSerializationFormat.HumanReadable);
|
||||
StringRegion data = RegionSession.CreateStringRegion("Microsoft PROSE SDK");
|
||||
State input = State.Create(grammar.InputSymbol, data);
|
||||
|
@ -35,48 +34,47 @@ namespace ProseTutorial
|
|||
StringRegion sdk = data.Slice(data.End - 3, data.End);
|
||||
Spec spec = ShouldConvert.Given(grammar).To(data, sdk);
|
||||
Learn(grammar, spec,
|
||||
new Substrings.RankingScore(grammar), new Substrings.WitnessFunctions(grammar));
|
||||
new Substrings.RankingScore(grammar), new Substrings.WitnessFunctions(grammar));*/
|
||||
|
||||
TestFlashFillBenchmark(grammar, "emails");
|
||||
TestFlashFillBenchmark("emails");
|
||||
}
|
||||
|
||||
private static void TestFlashFillBenchmark(Grammar grammar, string benchmark, int exampleCount = 2)
|
||||
{
|
||||
private static void TestFlashFillBenchmark(string benchmark, int exampleCount = 2) {
|
||||
string[] lines = File.ReadAllLines($"benchmarks/{benchmark}.tsv");
|
||||
Tuple<string, string>[] data = lines.Select(l =>
|
||||
{
|
||||
(string, string)[] data = lines.Select(l => {
|
||||
var parts = l.Split(new[] { "\t" }, StringSplitOptions.RemoveEmptyEntries);
|
||||
return Tuple.Create(parts[0], parts[1]);
|
||||
return (parts[0], parts[1]);
|
||||
}).ToArray();
|
||||
var examples =
|
||||
data.Take(exampleCount)
|
||||
.ToDictionary(
|
||||
t => State.Create(grammar.InputSymbol, RegionSession.CreateStringRegion(t.Item1)),
|
||||
t => (object)RegionSession.CreateStringRegion(t.Item2));
|
||||
var spec = new ExampleSpec(examples);
|
||||
ProgramNode program = Learn(grammar, spec,
|
||||
new Substrings.RankingScore(grammar),
|
||||
new Substrings.WitnessFunctions(grammar));
|
||||
foreach (Tuple<string, string> row in data.Skip(exampleCount))
|
||||
{
|
||||
State input = State.Create(grammar.InputSymbol,
|
||||
RegionSession.CreateStringRegion(row.Item1));
|
||||
var output = program.Invoke(input);
|
||||
WriteColored(ConsoleColor.DarkCyan, $"{row.Item1} => {output}");
|
||||
|
||||
using (var session = new DSLSession()) {
|
||||
var examples =
|
||||
data.Take(exampleCount)
|
||||
.Select(t => new Example<StringRegion, StringRegion>(
|
||||
RegionSession.CreateStringRegion(t.Item1),
|
||||
RegionSession.CreateStringRegion(t.Item2)));
|
||||
session.AddConstraints(examples);
|
||||
var program = session.LearnTopK(1)[0];
|
||||
|
||||
/*var spec = new ExampleSpec(examples);
|
||||
ProgramNode program = Learn(grammar, spec,
|
||||
new Substrings.RankingScore(grammar),
|
||||
new Substrings.WitnessFunctions(grammar));*/
|
||||
foreach ((string, string) row in data.Skip(exampleCount)) {
|
||||
var output = program.Run(RegionSession.CreateStringRegion(row.Item1));
|
||||
WriteColored(ConsoleColor.DarkCyan, $"{row.Item1} => {output}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void LoadAndTestTextExtraction()
|
||||
{
|
||||
var grammar = ProseTutorial.TextExtraction.Language.Grammar;
|
||||
private static void LoadAndTestTextExtraction() {
|
||||
var grammar = TextExtraction.Language.Grammar;
|
||||
if (grammar == null) return;
|
||||
|
||||
TestTextExtractionBenchmark(grammar, "countries");
|
||||
TestTextExtractionBenchmark(grammar, "popl13-erc");
|
||||
}
|
||||
|
||||
private static void TestTextExtractionBenchmark(Grammar grammar, string benchmark)
|
||||
{
|
||||
private static void TestTextExtractionBenchmark(Grammar grammar, string benchmark) {
|
||||
StringRegion document;
|
||||
List<StringRegion> examples = LoadBenchmark($"benchmarks/{benchmark}.txt", out document);
|
||||
var input = State.Create(grammar.InputSymbol, document);
|
||||
|
|
|
@ -186,9 +186,13 @@
|
|||
<Reference Include="System.Interactive, Version=3.0.0.0, Culture=neutral, PublicKeyToken=94bc3704cddfc263, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Interactive.3.0.0\lib\net45\System.Interactive.dll</HintPath>
|
||||
</Reference>
|
||||
<Reference Include="System.Numerics" />
|
||||
<Reference Include="System.Reflection.Metadata, Version=1.3.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.Reflection.Metadata.1.3.0\lib\portable-net45+win8\System.Reflection.Metadata.dll</HintPath>
|
||||
</Reference>
|
||||
<Reference Include="System.ValueTuple, Version=4.0.1.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=MSIL">
|
||||
<HintPath>..\packages\System.ValueTuple.4.3.0\lib\netstandard1.0\System.ValueTuple.dll</HintPath>
|
||||
</Reference>
|
||||
<Reference Include="System.Xml.Linq" />
|
||||
<Reference Include="System.Data.DataSetExtensions" />
|
||||
<Reference Include="Microsoft.CSharp" />
|
||||
|
@ -197,6 +201,7 @@
|
|||
<Reference Include="System.Xml" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Compile Include="DSLSession.cs" />
|
||||
<Compile Include="Program.cs" />
|
||||
<Compile Include="Properties\AssemblyInfo.cs" />
|
||||
<Compile Include="Utils.cs" />
|
||||
|
|
|
@ -24,4 +24,5 @@
|
|||
<package id="System.Runtime" version="4.3.0" targetFramework="net45" />
|
||||
<package id="System.Runtime.Extensions" version="4.3.0" targetFramework="net45" />
|
||||
<package id="System.Threading" version="4.3.0" targetFramework="net45" />
|
||||
<package id="System.ValueTuple" version="4.3.0" targetFramework="net45" />
|
||||
</packages>
|
Загрузка…
Ссылка в новой задаче