This commit is contained in:
Luong Hoang 2015-06-17 23:53:51 -04:00
Родитель 39779d01ad
Коммит 7b4297cbbd
47 изменённых файлов: 1143 добавлений и 5291 удалений

Просмотреть файл

@ -0,0 +1,44 @@

namespace MultiWorldTesting
{
/// <summary>
/// The bootstrap exploration class.
/// </summary>
/// <remarks>
/// The Bootstrap explorer randomizes over the actions chosen by a set of
/// default policies. This performs well statistically but can be
/// computationally expensive.
/// NOTE(review): every member of this class is currently an unimplemented stub;
/// the constructors and methods ignore their arguments.
/// </remarks>
/// <typeparam name="TContext">The Context type.</typeparam>
public class BootstrapExplorer<TContext> : IExplorer<TContext>, IConsumePolicies<TContext>
{
/// <summary>
/// Initializes a bootstrap explorer with a fixed number of actions.
/// Instances are meant to be passed to MwtExplorer.ChooseAction rather than used on their own.
/// </summary>
/// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
/// <param name="numActions">The number of actions to randomize over.</param>
public BootstrapExplorer(IPolicy<TContext>[] defaultPolicies, uint numActions)
{
// TODO: implement
}
/// <summary>
/// Initializes a bootstrap explorer with variable number of actions.
/// </summary>
/// <remarks>
/// NOTE(review): IVariableActionContext is documented as the context contract for this mode,
/// but no generic constraint on TContext enforces it here - confirm when implementing.
/// </remarks>
/// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
public BootstrapExplorer(IPolicy<TContext>[] defaultPolicies)
{
// TODO: implement
}
/// <summary>
/// Implements IConsumePolicies.UpdatePolicy. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to replace the set of default policies with
/// <paramref name="newPolicies"/> - confirm when implementing.
/// </remarks>
/// <param name="newPolicies">The policies handed to the explorer.</param>
public void UpdatePolicy(IPolicy<TContext>[] newPolicies)
{
// TODO: implement
}
/// <summary>
/// Implements IExplorer.EnableExplore. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to switch randomized exploration on or off - confirm when implementing.
/// </remarks>
/// <param name="explore">The requested exploration flag.</param>
public void EnableExplore(bool explore)
{
// TODO: implement
}
};
}

Просмотреть файл

@ -0,0 +1,45 @@

namespace MultiWorldTesting
{
/// <summary>
/// The epsilon greedy exploration class.
/// </summary>
/// <remarks>
/// This is a good choice if you have no idea which actions should be preferred.
/// Epsilon greedy is also computationally cheap.
/// NOTE(review): every member of this class is currently an unimplemented stub;
/// the constructors and methods ignore their arguments.
/// </remarks>
/// <typeparam name="TContext">The Context type.</typeparam>
public class EpsilonGreedyExplorer<TContext> : IExplorer<TContext>, IConsumePolicy<TContext>
{
/// <summary>
/// Initializes an epsilon greedy explorer with a fixed number of actions.
/// Instances are meant to be passed to MwtExplorer.ChooseAction rather than used on their own.
/// </summary>
/// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
/// <param name="epsilon">The probability of a random exploration.</param>
/// <param name="numActions">The number of actions to randomize over.</param>
public EpsilonGreedyExplorer(IPolicy<TContext> defaultPolicy, float epsilon, uint numActions)
{
// TODO: implement
}
/// <summary>
/// Initializes an epsilon greedy explorer with variable number of actions.
/// </summary>
/// <remarks>
/// NOTE(review): IVariableActionContext is documented as the context contract for this mode,
/// but no generic constraint on TContext enforces it here - confirm when implementing.
/// </remarks>
/// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
/// <param name="epsilon">The probability of a random exploration.</param>
public EpsilonGreedyExplorer(IPolicy<TContext> defaultPolicy, float epsilon)
{
// TODO: implement
}
/// <summary>
/// Implements IConsumePolicy.UpdatePolicy. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to replace the default policy with
/// <paramref name="newPolicy"/> - confirm when implementing.
/// </remarks>
/// <param name="newPolicy">The policy handed to the explorer.</param>
public void UpdatePolicy(IPolicy<TContext> newPolicy)
{
// TODO: implement
}
/// <summary>
/// Implements IExplorer.EnableExplore. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to switch randomized exploration on or off - confirm when implementing.
/// </remarks>
/// <param name="explore">The requested exploration flag.</param>
public void EnableExplore(bool explore)
{
// TODO: implement
}
};
}

60
Explore/Explore.csproj Normal file
Просмотреть файл

@ -0,0 +1,60 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
<PropertyGroup>
<MinimumVisualStudioVersion>10.0</MinimumVisualStudioVersion>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
<ProjectGuid>{6D245816-6016-49B6-9E37-A0BF0D2A736A}</ProjectGuid>
<OutputType>Library</OutputType>
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>MultiWorldTesting</RootNamespace>
<AssemblyName>Explore</AssemblyName>
<DefaultLanguage>en-US</DefaultLanguage>
<FileAlignment>512</FileAlignment>
<ProjectTypeGuids>{786C830F-07A1-408B-BD7F-6EE04809D6DB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}</ProjectTypeGuids>
<TargetFrameworkProfile>Profile259</TargetFrameworkProfile>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<DebugSymbols>true</DebugSymbols>
<DebugType>full</DebugType>
<Optimize>false</Optimize>
<OutputPath>bin\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
<DebugType>pdbonly</DebugType>
<Optimize>true</Optimize>
<OutputPath>bin\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<ItemGroup>
<!-- A reference to the entire .NET Framework is automatically included -->
</ItemGroup>
<ItemGroup>
<Compile Include="BootstrapExplorer.cs" />
<Compile Include="EpsilonGreedyExplorer.cs" />
<Compile Include="Feature.cs" />
<Compile Include="GenericExplorer.cs" />
<Compile Include="Interface.cs" />
<Compile Include="MwtExplorer.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="SimpleContext.cs" />
<Compile Include="SoftmaxExplorer.cs" />
<Compile Include="StringRecorder.cs" />
<Compile Include="TauFirstExplorer.cs" />
</ItemGroup>
<Import Project="$(MSBuildExtensionsPath32)\Microsoft\Portable\$(TargetFrameworkVersion)\Microsoft.Portable.CSharp.targets" />
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.
<Target Name="BeforeBuild">
</Target>
<Target Name="AfterBuild">
</Target>
-->
</Project>

15
Explore/Feature.cs Normal file
Просмотреть файл

@ -0,0 +1,15 @@
using System;
using System.Runtime.InteropServices;
namespace MultiWorldTesting
{
/// <summary>
/// Represents a feature in a sparse array.
/// </summary>
/// <remarks>
/// The struct uses sequential layout, so the fields are laid out in declaration order
/// (Value first, then Id). Do not reorder them without checking every consumer of the layout.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
public struct Feature
{
/// <summary>The value carried by this feature.</summary>
public float Value;
/// <summary>The identifier of this feature (presumably its index in the sparse array - confirm).</summary>
public UInt32 Id;
};
}

Просмотреть файл

@ -0,0 +1,43 @@

namespace MultiWorldTesting
{
/// <summary>
/// The generic exploration class.
/// </summary>
/// <remarks>
/// GenericExplorer provides complete flexibility. You can create any
/// distribution over actions desired, and it will draw from that.
/// NOTE(review): every member of this class is currently an unimplemented stub;
/// the constructors and methods ignore their arguments.
/// </remarks>
/// <typeparam name="TContext">The Context type.</typeparam>
public class GenericExplorer<TContext> : IExplorer<TContext>, IConsumeScorer<TContext>
{
/// <summary>
/// Initializes a generic explorer with a fixed number of actions.
/// Instances are meant to be passed to MwtExplorer.ChooseAction rather than used on their own.
/// </summary>
/// <param name="defaultScorer">A function which outputs the probability of each action.</param>
/// <param name="numActions">The number of actions to randomize over.</param>
public GenericExplorer(IScorer<TContext> defaultScorer, uint numActions)
{
// TODO: implement
}
/// <summary>
/// Initializes a generic explorer with variable number of actions.
/// </summary>
/// <remarks>
/// NOTE(review): IVariableActionContext is documented as the context contract for this mode,
/// but no generic constraint on TContext enforces it here - confirm when implementing.
/// </remarks>
/// <param name="defaultScorer">A function which outputs the probability of each action.</param>
public GenericExplorer(IScorer<TContext> defaultScorer)
{
// TODO: implement
}
/// <summary>
/// Implements IConsumeScorer.UpdateScorer. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to replace the default scorer with
/// <paramref name="newScorer"/> - confirm when implementing.
/// </remarks>
/// <param name="newScorer">The scorer handed to the explorer.</param>
public void UpdateScorer(IScorer<TContext> newScorer)
{
// TODO: implement
}
/// <summary>
/// Implements IExplorer.EnableExplore. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to switch randomized exploration on or off - confirm when implementing.
/// </remarks>
/// <param name="explore">The requested exploration flag.</param>
public void EnableExplore(bool explore)
{
// TODO: implement
}
};
}

96
Explore/Interface.cs Normal file
Просмотреть файл

@ -0,0 +1,96 @@
using System;
using System.Collections.Generic;
namespace MultiWorldTesting
{
/// <summary>
/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
/// </summary>
/// <typeparam name="TContext">The Context type.</typeparam>
/// <remarks>
/// Exploration data is specified as a set of tuples (context, action, probability, key) as described below. An
/// application passes an IRecorder object to the MwtExplorer constructor. See
/// StringRecorder for a sample IRecorder object.
/// </remarks>
public interface IRecorder<TContext>
{
/// <summary>
/// Records the exploration data associated with a given decision.
/// This implementation should be thread-safe if multithreading is needed.
/// </summary>
/// <param name="context">A user-defined context for the decision.</param>
/// <param name="action">Chosen by an exploration algorithm given context.</param>
/// <param name="probability">The probability of the chosen action given context.</param>
/// <param name="uniqueKey">A user-defined identifier for the decision.</param>
void Record(TContext context, uint action, float probability, string uniqueKey);
};
/// <summary>
/// Exposes a method for choosing an action given a generic context. IPolicy objects are
/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
/// </summary>
/// <typeparam name="TContext">The Context type.</typeparam>
public interface IPolicy<TContext>
{
/// <summary>
/// Determines the action to take for a given context.
/// Implementations should be thread-safe if multithreading is needed.
/// </summary>
/// <param name="context">A user-defined context for the decision.</param>
/// <returns>Index of the action to take (1-based).</returns>
uint ChooseAction(TContext context);
};
/// <summary>
/// Exposes a method for specifying a score (weight) for each action given a generic context.
/// </summary>
/// <typeparam name="TContext">The Context type.</typeparam>
public interface IScorer<TContext>
{
/// <summary>
/// Determines the score of each action for a given context.
/// Implementations should be thread-safe if multithreading is needed.
/// </summary>
/// <param name="context">A user-defined context for the decision.</param>
/// <returns>
/// Vector of scores indexed by action (1-based).
/// NOTE(review): the returned list itself is 0-indexed; presumably element i holds the
/// score of action i + 1 - confirm against the consuming explorers.
/// </returns>
List<float> ScoreActions(TContext context);
};
/// <summary>
/// Represents a context interface with a variable number of actions, which is
/// enforced if the exploration algorithm is initialized in variable-number-of-actions mode.
/// </summary>
public interface IVariableActionContext
{
/// <summary>
/// Gets the number of actions for the current context.
/// </summary>
/// <returns>The number of actions available for the current context.</returns>
UInt32 GetNumberOfActions();
};
/// <summary>
/// Common interface of the exploration algorithms; this is the type accepted by
/// MwtExplorer.ChooseAction.
/// </summary>
/// <typeparam name="TContext">The Context type. It is not used by any member declared here.</typeparam>
public interface IExplorer<TContext>
{
/// <summary>
/// Receives the exploration flag.
/// Presumably switches randomized exploration on or off - confirm with the implementations.
/// </summary>
/// <param name="explore">The requested exploration flag.</param>
void EnableExplore(bool explore);
};
/// <summary>
/// Implemented by exploration algorithms that are driven by a single default policy.
/// </summary>
/// <typeparam name="TContext">The Context type.</typeparam>
public interface IConsumePolicy<TContext>
{
/// <summary>
/// Hands a new policy to the consumer.
/// Presumably it replaces the current default policy - confirm with the implementations.
/// </summary>
/// <param name="newPolicy">The policy to hand over.</param>
void UpdatePolicy(IPolicy<TContext> newPolicy);
};
/// <summary>
/// Implemented by exploration algorithms that are driven by a set of default policies.
/// </summary>
/// <typeparam name="TContext">The Context type.</typeparam>
public interface IConsumePolicies<TContext>
{
/// <summary>
/// Hands a new set of policies to the consumer.
/// Presumably it replaces the current default policies - confirm with the implementations.
/// </summary>
/// <param name="newPolicies">The policies to hand over.</param>
void UpdatePolicy(IPolicy<TContext>[] newPolicies);
};
/// <summary>
/// Implemented by exploration algorithms that are driven by a default scorer.
/// </summary>
/// <typeparam name="TContext">The Context type.</typeparam>
public interface IConsumeScorer<TContext>
{
/// <summary>
/// Hands a new scorer to the consumer.
/// Presumably it replaces the current default scorer - confirm with the implementations.
/// </summary>
/// <param name="newScorer">The scorer to hand over.</param>
void UpdateScorer(IScorer<TContext> newScorer);
};
/// <summary>
/// Marks a context type that can be rendered as a string; StringRecorder constrains its
/// context type parameter to this interface.
/// </summary>
/// <remarks>
/// The member has the same signature as object.ToString(), so a class that does not declare
/// its own ToString() satisfies this interface through the inherited implementation.
/// </remarks>
public interface IStringContext
{
/// <summary>
/// Returns the string form of the context.
/// </summary>
/// <returns>A string describing the context.</returns>
string ToString();
};
}

34
Explore/MwtExplorer.cs Normal file
Просмотреть файл

@ -0,0 +1,34 @@

namespace MultiWorldTesting
{
/// <summary>
/// The top level MwtExplorer class. Using this makes sure that the
/// right bits are recorded and good random actions are chosen.
/// </summary>
/// <remarks>
/// NOTE(review): both members are currently unimplemented stubs.
/// </remarks>
/// <typeparam name="TContext">The Context type.</typeparam>
public class MwtExplorer<TContext>
{
/// <summary>
/// Constructor. Currently a stub that ignores both arguments.
/// </summary>
/// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
/// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
public MwtExplorer(string appId, IRecorder<TContext> recorder)
{
// TODO: implement
}
/// <summary>
/// ChooseAction should be a drop-in replacement for any existing policy function.
/// </summary>
/// <param name="explorer">An existing exploration algorithm (any IExplorer implementation) which uses the default policy as a callback.</param>
/// <param name="uniqueKey">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
/// <param name="context">The context upon which a decision is made. See SimpleContext for an example.</param>
/// <returns>
/// An unsigned 32-bit integer representing the 1-based chosen action.
/// NOTE(review): the current stub always returns 0, which is not a valid 1-based action.
/// </returns>
public uint ChooseAction(IExplorer<TContext> explorer, string uniqueKey, TContext context)
{
// TODO: implement
return 0;
}
};
}

Просмотреть файл

@ -0,0 +1,30 @@
using System.Resources;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// General Information about an assembly is controlled through the following
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
[assembly: AssemblyTitle("Explore")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("")]
[assembly: AssemblyProduct("Explore")]
[assembly: AssemblyCopyright("Copyright © 2015")]
[assembly: AssemblyTrademark("")]
[assembly: AssemblyCulture("")]
[assembly: NeutralResourcesLanguage("en")]
// Version information for an assembly consists of the following four values:
//
// Major Version
// Minor Version
// Build Number
// Revision
//
// You can specify all the values or you can default the Build and Revision Numbers
// by using the '*' as shown below:
// [assembly: AssemblyVersion("1.0.*")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: AssemblyFileVersion("1.0.0.0")]

23
Explore/SimpleContext.cs Normal file
Просмотреть файл

@ -0,0 +1,23 @@

namespace MultiWorldTesting
{
/// <summary>
/// A sample context class that stores a vector of Features.
/// </summary>
public class SimpleContext : IStringContext
{
    /// <summary>
    /// Creates a context around the given feature array.
    /// The array is stored by reference, not copied.
    /// </summary>
    /// <param name="features">The features describing this context.</param>
    public SimpleContext(Feature[] features)
    {
        this.Features = features;
    }

    /// <summary>
    /// Renders the context as space-separated "id:value" pairs, one per feature,
    /// formatted with the invariant culture so the output does not depend on the
    /// machine's locale.
    /// </summary>
    /// <remarks>
    /// Declared as an override so that calls made through an <c>object</c> reference
    /// (string formatting, Console.WriteLine, ...) reach this implementation too; the
    /// override also satisfies IStringContext.ToString().
    /// NOTE(review): the "id:value" format is a choice made here - confirm it against
    /// whatever StringRecorder is expected to emit once that class is implemented.
    /// </remarks>
    /// <returns>
    /// The feature listing; an empty string (never null) when there are no features.
    /// </returns>
    public override string ToString()
    {
        // ToString() overrides should never return null; fall back to an empty string.
        if (this.Features == null)
        {
            return string.Empty;
        }

        var text = new System.Text.StringBuilder();
        var invariant = System.Globalization.CultureInfo.InvariantCulture;
        for (int i = 0; i < this.Features.Length; i++)
        {
            if (i > 0)
            {
                text.Append(' ');
            }
            text.Append(this.Features[i].Id.ToString(invariant));
            text.Append(':');
            text.Append(this.Features[i].Value.ToString(invariant));
        }
        return text.ToString();
    }

    /// <summary>
    /// Gets the feature array supplied to the constructor (the same array instance).
    /// </summary>
    /// <returns>The stored features.</returns>
    public Feature[] GetFeatures() { return Features; }

    // Backing storage for the features; internal so other types in this assembly can read it directly.
    internal Feature[] Features;
};
}

Просмотреть файл

@ -0,0 +1,46 @@

namespace MultiWorldTesting
{
/// <summary>
/// The softmax exploration class.
/// </summary>
/// <remarks>
/// In some cases, different actions have different scores, and you
/// would prefer to choose actions with large scores. Softmax allows
/// you to do that.
/// NOTE(review): every member of this class is currently an unimplemented stub;
/// the constructors and methods ignore their arguments.
/// </remarks>
/// <typeparam name="TContext">The Context type.</typeparam>
public class SoftmaxExplorer<TContext> : IExplorer<TContext>, IConsumeScorer<TContext>
{
/// <summary>
/// Initializes a softmax explorer with a fixed number of actions.
/// Instances are meant to be passed to MwtExplorer.ChooseAction rather than used on their own.
/// </summary>
/// <param name="defaultScorer">A function which outputs a score for each action.</param>
/// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
/// <param name="numActions">The number of actions to randomize over.</param>
public SoftmaxExplorer(IScorer<TContext> defaultScorer, float lambda, uint numActions)
{
// TODO: implement
}
/// <summary>
/// Initializes a softmax explorer with variable number of actions.
/// </summary>
/// <remarks>
/// NOTE(review): IVariableActionContext is documented as the context contract for this mode,
/// but no generic constraint on TContext enforces it here - confirm when implementing.
/// </remarks>
/// <param name="defaultScorer">A function which outputs a score for each action.</param>
/// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
public SoftmaxExplorer(IScorer<TContext> defaultScorer, float lambda)
{
// TODO: implement
}
/// <summary>
/// Implements IConsumeScorer.UpdateScorer. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to replace the default scorer with
/// <paramref name="newScorer"/> - confirm when implementing.
/// </remarks>
/// <param name="newScorer">The scorer handed to the explorer.</param>
public void UpdateScorer(IScorer<TContext> newScorer)
{
// TODO: implement
}
/// <summary>
/// Implements IExplorer.EnableExplore. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to switch randomized exploration on or off - confirm when implementing.
/// </remarks>
/// <param name="explore">The requested exploration flag.</param>
public void EnableExplore(bool explore)
{
// TODO: implement
}
};
}

42
Explore/StringRecorder.cs Normal file
Просмотреть файл

@ -0,0 +1,42 @@

namespace MultiWorldTesting
{
/// <summary>
/// A sample recorder class that converts the exploration tuple into string format.
/// </summary>
/// <remarks>
/// The context type must implement IStringContext.
/// NOTE(review): every member of this class is currently an unimplemented stub.
/// </remarks>
/// <typeparam name="TContext">The Context type.</typeparam>
public class StringRecorder<TContext> : IRecorder<TContext>
where TContext : IStringContext
{
/// <summary>
/// Creates a recorder. Currently a stub that performs no initialization.
/// </summary>
public StringRecorder()
{
// TODO: implement
}
/// <summary>
/// Records the exploration data associated with a given decision.
/// This implementation should be thread-safe if multithreading is needed.
/// NOTE(review): currently a no-op stub - nothing is recorded yet.
/// </summary>
/// <param name="context">A user-defined context for the decision.</param>
/// <param name="action">Chosen by an exploration algorithm given context.</param>
/// <param name="probability">The probability of the chosen action given context.</param>
/// <param name="uniqueKey">A user-defined identifier for the decision.</param>
public void Record(TContext context, uint action, float probability, string uniqueKey)
{
// TODO: implement
}
/// <summary>
/// Gets the content of the recording so far as a string and optionally clears internal content.
/// </summary>
/// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
/// <returns>
/// A string with recording content.
/// NOTE(review): the current stub ignores <paramref name="flush"/> and always returns null.
/// </returns>
public string GetRecording(bool flush = false)
{
// TODO: implement
return null;
}
};
}

Просмотреть файл

@ -0,0 +1,45 @@

namespace MultiWorldTesting
{
/// <summary>
/// The tau-first exploration class.
/// </summary>
/// <remarks>
/// The tau-first explorer collects precisely tau uniform random
/// exploration events, and then uses the default policy.
/// NOTE(review): every member of this class is currently an unimplemented stub;
/// the constructors and methods ignore their arguments.
/// </remarks>
/// <typeparam name="TContext">The Context type.</typeparam>
public class TauFirstExplorer<TContext> : IExplorer<TContext>, IConsumePolicy<TContext>
{
/// <summary>
/// Initializes a tau-first explorer with a fixed number of actions.
/// Instances are meant to be passed to MwtExplorer.ChooseAction rather than used on their own.
/// </summary>
/// <param name="defaultPolicy">A default policy after randomization finishes.</param>
/// <param name="tau">The number of events to be uniform over.</param>
/// <param name="numActions">The number of actions to randomize over.</param>
public TauFirstExplorer(IPolicy<TContext> defaultPolicy, uint tau, uint numActions)
{
// TODO: implement
}
/// <summary>
/// Initializes a tau-first explorer with variable number of actions.
/// </summary>
/// <remarks>
/// NOTE(review): IVariableActionContext is documented as the context contract for this mode,
/// but no generic constraint on TContext enforces it here - confirm when implementing.
/// </remarks>
/// <param name="defaultPolicy">A default policy after randomization finishes.</param>
/// <param name="tau">The number of events to be uniform over.</param>
public TauFirstExplorer(IPolicy<TContext> defaultPolicy, uint tau)
{
// TODO: implement
}
/// <summary>
/// Implements IConsumePolicy.UpdatePolicy. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to replace the default policy with
/// <paramref name="newPolicy"/> - confirm when implementing.
/// </remarks>
/// <param name="newPolicy">The policy handed to the explorer.</param>
public void UpdatePolicy(IPolicy<TContext> newPolicy)
{
// TODO: implement
}
/// <summary>
/// Implements IExplorer.EnableExplore. Currently a no-op stub.
/// </summary>
/// <remarks>
/// Presumably intended to switch randomized exploration on or off - confirm when implementing.
/// </remarks>
/// <param name="explore">The requested exploration flag.</param>
public void EnableExplore(bool explore)
{
// TODO: implement
}
};
}

Просмотреть файл

@ -1,4 +0,0 @@
This is the root of the exploration library and client-side decision service.
explore_sample.cpp shows how to use the exploration library which is a
header-only include in C++.

Просмотреть файл

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8" ?>
<configuration>
<startup>
<supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.5" />
</startup>
<?xml version="1.0" encoding="utf-8" ?>
<configuration>
<startup>
<supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.5" />
</startup>
</configuration>

Просмотреть файл

@ -8,21 +8,21 @@ namespace cs_test
{
class ExploreOnlySample
{
/// <summary>
/// Example of a custom context.
/// <summary>
/// Example of a custom context.
/// </summary>
class MyContext { }
/// <summary>
/// Example of a custom recorder which implements the IRecorder<MyContext>,
/// declaring that this recorder only interacts with MyContext objects.
class MyContext { }
/// <summary>
/// Example of a custom recorder which implements the IRecorder<MyContext>,
/// declaring that this recorder only interacts with MyContext objects.
/// </summary>
class MyRecorder : IRecorder<MyContext>
{
public void Record(MyContext context, UInt32 action, float probability, string uniqueKey)
{
// Stores the tuple internally in a vector that could be used later for other purposes.
interactions.Add(new Interaction<MyContext>()
{
// Stores the tuple internally in a vector that could be used later for other purposes.
interactions.Add(new Interaction<MyContext>()
{
Context = context,
Action = action,
@ -39,9 +39,9 @@ namespace cs_test
private List<Interaction<MyContext>> interactions = new List<Interaction<MyContext>>();
}
/// <summary>
/// Example of a custom policy which implements the IPolicy<MyContext>,
/// declaring that this policy only interacts with MyContext objects.
/// <summary>
/// Example of a custom policy which implements the IPolicy<MyContext>,
/// declaring that this policy only interacts with MyContext objects.
/// </summary>
class MyPolicy : IPolicy<MyContext>
{
@ -53,30 +53,30 @@ namespace cs_test
}
public uint ChooseAction(MyContext context)
{
// Always returns the same action regardless of context
{
// Always returns the same action regardless of context
return 5;
}
private int index;
}
/// <summary>
/// Example of a custom policy which implements the IPolicy<SimpleContext>,
/// declaring that this policy only interacts with SimpleContext objects.
/// <summary>
/// Example of a custom policy which implements the IPolicy<SimpleContext>,
/// declaring that this policy only interacts with SimpleContext objects.
/// </summary>
class StringPolicy : IPolicy<SimpleContext>
{
public uint ChooseAction(SimpleContext context)
{
// Always returns the same action regardless of context
{
// Always returns the same action regardless of context
return 1;
}
}
/// <summary>
/// Example of a custom scorer which implements the IScorer<MyContext>,
/// declaring that this scorer only interacts with MyContext objects.
/// <summary>
/// Example of a custom scorer which implements the IScorer<MyContext>,
/// declaring that this scorer only interacts with MyContext objects.
/// </summary>
class MyScorer : IScorer<MyContext>
{
@ -91,9 +91,9 @@ namespace cs_test
private uint numActions;
}
/// <summary>
/// Represents a tuple <context, action, probability, key>.
/// </summary>
/// <summary>
/// Represents a tuple <context, action, probability, key>.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
struct Interaction<Ctx>
{
@ -109,12 +109,12 @@ namespace cs_test
if (exploration_type == "greedy")
{
// Initialize Epsilon-Greedy explore algorithm using built-in StringRecorder and SimpleContext types
// Creates a recorder of built-in StringRecorder type for string serialization
StringRecorder<SimpleContext> recorder = new StringRecorder<SimpleContext>();
// Creates an MwtExplorer instance using the recorder above
// Initialize Epsilon-Greedy explore algorithm using built-in StringRecorder and SimpleContext types
// Creates a recorder of built-in StringRecorder type for string serialization
StringRecorder<SimpleContext> recorder = new StringRecorder<SimpleContext>();
// Creates an MwtExplorer instance using the recorder above
MwtExplorer<SimpleContext> mwtt = new MwtExplorer<SimpleContext>("mwt", recorder);
// Creates a policy that interacts with SimpleContext type
@ -123,20 +123,20 @@ namespace cs_test
uint numActions = 10;
float epsilon = 0.2f;
// Creates an Epsilon-Greedy explorer using the specified settings
EpsilonGreedyExplorer<SimpleContext> explorer = new EpsilonGreedyExplorer<SimpleContext>(policy, epsilon, numActions);
// Creates a context of built-in SimpleContext type
EpsilonGreedyExplorer<SimpleContext> explorer = new EpsilonGreedyExplorer<SimpleContext>(policy, epsilon, numActions);
// Creates a context of built-in SimpleContext type
SimpleContext context = new SimpleContext(new Feature[] {
new Feature() { Id = 1, Value = 0.5f },
new Feature() { Id = 4, Value = 1.3f },
new Feature() { Id = 9, Value = -0.5f },
});
// Performs exploration by passing an instance of the Epsilon-Greedy exploration algorithm into MwtExplorer
// using a sample string to uniquely identify this event
string uniqueKey = "eventid";
uint action = mwtt.ChooseAction(explorer, uniqueKey, context);
});
// Performs exploration by passing an instance of the Epsilon-Greedy exploration algorithm into MwtExplorer
// using a sample string to uniquely identify this event
string uniqueKey = "eventid";
uint action = mwtt.ChooseAction(explorer, uniqueKey, context);
Console.WriteLine(recorder.GetRecording());
return;

Просмотреть файл

@ -1,86 +1,66 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
<ProjectGuid>{7081D542-AE64-485D-9087-79194B958499}</ProjectGuid>
<OutputType>Exe</OutputType>
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>ExploreSample</RootNamespace>
<AssemblyName>ExploreSample</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x86'">
<DebugSymbols>true</DebugSymbols>
<OutputPath>bin\x86\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DebugType>full</DebugType>
<PlatformTarget>x86</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x86'">
<OutputPath>bin\x86\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<Optimize>true</Optimize>
<DebugType>pdbonly</DebugType>
<PlatformTarget>x86</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'">
<DebugSymbols>true</DebugSymbols>
<OutputPath>bin\x64\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DebugType>full</DebugType>
<PlatformTarget>x64</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'">
<OutputPath>bin\x64\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<Optimize>true</Optimize>
<DebugType>pdbonly</DebugType>
<PlatformTarget>x64</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<ItemGroup>
<Reference Include="System" />
<Reference Include="System.Core" />
<Reference Include="System.Xml.Linq" />
<Reference Include="System.Data.DataSetExtensions" />
<Reference Include="Microsoft.CSharp" />
<Reference Include="System.Data" />
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="ExploreOnlySample.cs" />
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
<ItemGroup>
<None Include="App.config" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\clr\explore_clr.vcxproj">
<Project>{8400da16-1f46-4a31-a126-bbe16f62bfd7}</Project>
<Name>explore_clr</Name>
</ProjectReference>
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.
<Target Name="BeforeBuild">
</Target>
<Target Name="AfterBuild">
</Target>
-->
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
<ProjectGuid>{7081D542-AE64-485D-9087-79194B958499}</ProjectGuid>
<OutputType>Exe</OutputType>
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>ExploreSample</RootNamespace>
<AssemblyName>ExploreSample</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|AnyCPU'">
<DebugSymbols>true</DebugSymbols>
<OutputPath>bin\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DebugType>full</DebugType>
<PlatformTarget>AnyCPU</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|AnyCPU'">
<OutputPath>bin\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<Optimize>true</Optimize>
<DebugType>pdbonly</DebugType>
<PlatformTarget>AnyCPU</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
<Prefer32Bit>true</Prefer32Bit>
</PropertyGroup>
<ItemGroup>
<Reference Include="System" />
<Reference Include="System.Core" />
<Reference Include="System.Xml.Linq" />
<Reference Include="System.Data.DataSetExtensions" />
<Reference Include="Microsoft.CSharp" />
<Reference Include="System.Data" />
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="ExploreOnlySample.cs" />
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
<ItemGroup>
<None Include="App.config" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Explore\Explore.csproj">
<Project>{6d245816-6016-49b6-9e37-a0bf0d2a736a}</Project>
<Name>Explore</Name>
</ProjectReference>
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.
<Target Name="BeforeBuild">
</Target>
<Target Name="AfterBuild">
</Target>
-->
</Project>

Просмотреть файл

@ -1,16 +1,16 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using MultiWorldTesting;
namespace ExploreSample
{
class Program
{
public static void Main(string[] args)
{
cs_test.ExploreOnlySample.Run();
}
}
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using MultiWorldTesting;
namespace ExploreSample
{
class Program
{
public static void Main(string[] args)
{
cs_test.ExploreOnlySample.Run();
}
}
}

Просмотреть файл

@ -1,36 +1,36 @@
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// General Information about an assembly is controlled through the following
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
[assembly: AssemblyTitle("ExploreSample")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("")]
[assembly: AssemblyProduct("ExploreSample")]
[assembly: AssemblyCopyright("Copyright © 2014")]
[assembly: AssemblyTrademark("")]
[assembly: AssemblyCulture("")]
// Setting ComVisible to false makes the types in this assembly not visible
// to COM components. If you need to access a type in this assembly from
// COM, set the ComVisible attribute to true on that type.
[assembly: ComVisible(false)]
// The following GUID is for the ID of the typelib if this project is exposed to COM
[assembly: Guid("767d4e7c-6acc-4b46-9eac-e86ab079625a")]
// Version information for an assembly consists of the following four values:
//
// Major Version
// Minor Version
// Build Number
// Revision
//
// You can specify all the values or you can default the Build and Revision Numbers
// by using the '*' as shown below:
// [assembly: AssemblyVersion("1.0.*")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: AssemblyFileVersion("1.0.0.0")]
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// General Information about an assembly is controlled through the following
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
[assembly: AssemblyTitle("ExploreSample")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("")]
[assembly: AssemblyProduct("ExploreSample")]
[assembly: AssemblyCopyright("Copyright © 2014")]
[assembly: AssemblyTrademark("")]
[assembly: AssemblyCulture("")]
// Setting ComVisible to false makes the types in this assembly not visible
// to COM components. If you need to access a type in this assembly from
// COM, set the ComVisible attribute to true on that type.
[assembly: ComVisible(false)]
// The following GUID is for the ID of the typelib if this project is exposed to COM
[assembly: Guid("767d4e7c-6acc-4b46-9eac-e86ab079625a")]
// Version information for an assembly consists of the following four values:
//
// Major Version
// Minor Version
// Build Number
// Revision
//
// You can specify all the values or you can default the Build and Revision Numbers
// by using the '*' as shown below:
// [assembly: AssemblyVersion("1.0.*")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: AssemblyFileVersion("1.0.0.0")]

Просмотреть файл

@ -1,112 +1,94 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
<ProjectGuid>{CB0C6B20-560C-4822-8EF6-DA999A64B542}</ProjectGuid>
<OutputType>Library</OutputType>
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>ExploreTests</RootNamespace>
<AssemblyName>ExploreTests</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<ProjectTypeGuids>{3AC096D0-A1C2-E12C-1390-A8335801FDAB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}</ProjectTypeGuids>
<VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion>
<VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath>
<ReferencePath>$(ProgramFiles)\Common Files\microsoft shared\VSTT\$(VisualStudioVersion)\UITestExtensionPackages</ReferencePath>
<IsCodedUITest>False</IsCodedUITest>
<TestProjectType>UnitTest</TestProjectType>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'">
<DebugSymbols>true</DebugSymbols>
<OutputPath>bin\x64\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DebugType>full</DebugType>
<PlatformTarget>x64</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'">
<OutputPath>bin\x64\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<Optimize>true</Optimize>
<DebugType>pdbonly</DebugType>
<PlatformTarget>x64</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x86'">
<DebugSymbols>true</DebugSymbols>
<OutputPath>bin\x86\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DebugType>full</DebugType>
<PlatformTarget>x86</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x86'">
<OutputPath>bin\x86\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<Optimize>true</Optimize>
<DebugType>pdbonly</DebugType>
<PlatformTarget>x86</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
</PropertyGroup>
<ItemGroup>
<Reference Include="System" />
</ItemGroup>
<Choose>
<When Condition="('$(VisualStudioVersion)' == '10.0' or '$(VisualStudioVersion)' == '') and '$(TargetFrameworkVersion)' == 'v3.5'">
<ItemGroup>
<Reference Include="Microsoft.VisualStudio.QualityTools.UnitTestFramework, Version=10.1.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL" />
</ItemGroup>
</When>
<Otherwise>
<ItemGroup>
<Reference Include="Microsoft.VisualStudio.QualityTools.UnitTestFramework" />
</ItemGroup>
</Otherwise>
</Choose>
<ItemGroup>
<Compile Include="MWTExploreTests.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\clr\explore_clr.vcxproj">
<Project>{8400da16-1f46-4a31-a126-bbe16f62bfd7}</Project>
<Name>explore_clr</Name>
</ProjectReference>
</ItemGroup>
<Choose>
<When Condition="'$(VisualStudioVersion)' == '10.0' And '$(IsCodedUITest)' == 'True'">
<ItemGroup>
<Reference Include="Microsoft.VisualStudio.QualityTools.CodedUITestFramework, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
<Reference Include="Microsoft.VisualStudio.TestTools.UITest.Common, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
<Reference Include="Microsoft.VisualStudio.TestTools.UITest.Extension, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
<Reference Include="Microsoft.VisualStudio.TestTools.UITesting, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
</ItemGroup>
</When>
</Choose>
<Import Project="$(VSToolsPath)\TeamTest\Microsoft.TestTools.targets" Condition="Exists('$(VSToolsPath)\TeamTest\Microsoft.TestTools.targets')" />
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<PropertyGroup>
<PostBuildEvent>
</PostBuildEvent>
</PropertyGroup>
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
<ProjectGuid>{CB0C6B20-560C-4822-8EF6-DA999A64B542}</ProjectGuid>
<OutputType>Library</OutputType>
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>ExploreTests</RootNamespace>
<AssemblyName>ExploreTests</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<ProjectTypeGuids>{3AC096D0-A1C2-E12C-1390-A8335801FDAB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}</ProjectTypeGuids>
<VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion>
<VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath>
<ReferencePath>$(ProgramFiles)\Common Files\microsoft shared\VSTT\$(VisualStudioVersion)\UITestExtensionPackages</ReferencePath>
<IsCodedUITest>False</IsCodedUITest>
<TestProjectType>UnitTest</TestProjectType>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|AnyCPU'">
<DebugSymbols>true</DebugSymbols>
<OutputPath>bin\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<DebugType>full</DebugType>
<PlatformTarget>AnyCPU</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|AnyCPU'">
<OutputPath>bin\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<Optimize>true</Optimize>
<DebugType>pdbonly</DebugType>
<PlatformTarget>AnyCPU</PlatformTarget>
<ErrorReport>prompt</ErrorReport>
<CodeAnalysisRuleSet>MinimumRecommendedRules.ruleset</CodeAnalysisRuleSet>
</PropertyGroup>
<ItemGroup>
<Reference Include="System" />
</ItemGroup>
<Choose>
<When Condition="('$(VisualStudioVersion)' == '10.0' or '$(VisualStudioVersion)' == '') and '$(TargetFrameworkVersion)' == 'v3.5'">
<ItemGroup>
<Reference Include="Microsoft.VisualStudio.QualityTools.UnitTestFramework, Version=10.1.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL" />
</ItemGroup>
</When>
<Otherwise>
<ItemGroup>
<Reference Include="Microsoft.VisualStudio.QualityTools.UnitTestFramework" />
</ItemGroup>
</Otherwise>
</Choose>
<ItemGroup>
<Compile Include="MWTExploreTests.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Explore\Explore.csproj">
<Project>{6d245816-6016-49b6-9e37-a0bf0d2a736a}</Project>
<Name>Explore</Name>
</ProjectReference>
</ItemGroup>
<Choose>
<When Condition="'$(VisualStudioVersion)' == '10.0' And '$(IsCodedUITest)' == 'True'">
<ItemGroup>
<Reference Include="Microsoft.VisualStudio.QualityTools.CodedUITestFramework, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
<Reference Include="Microsoft.VisualStudio.TestTools.UITest.Common, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
<Reference Include="Microsoft.VisualStudio.TestTools.UITest.Extension, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
<Reference Include="Microsoft.VisualStudio.TestTools.UITesting, Version=10.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
<Private>False</Private>
</Reference>
</ItemGroup>
</When>
</Choose>
<Import Project="$(VSToolsPath)\TeamTest\Microsoft.TestTools.targets" Condition="Exists('$(VSToolsPath)\TeamTest\Microsoft.TestTools.targets')" />
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<PropertyGroup>
<PostBuildEvent>
</PostBuildEvent>
</PropertyGroup>
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.
<Target Name="BeforeBuild">
</Target>
<Target Name="AfterBuild">
</Target>
-->
-->
</Project>

Просмотреть файл

@ -1,9 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'">
<EnableUnmanagedDebugging>true</EnableUnmanagedDebugging>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x86'">
<EnableUnmanagedDebugging>true</EnableUnmanagedDebugging>
</PropertyGroup>
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|AnyCPU'">
<EnableUnmanagedDebugging>true</EnableUnmanagedDebugging>
</PropertyGroup>
</Project>

Просмотреть файл

@ -17,238 +17,238 @@ namespace ExploreTests
/// <summary>
/// Builds an epsilon-greedy explorer with epsilon 0 and an explicit action count of 10,
/// then hands it to <c>EpsilonGreedyWithContext</c>, which performs the assertions.
/// </summary>
[TestMethod]
public void EpsilonGreedy()
{
    uint actionCount = 10;
    float explorationRate = 0f;

    TestPolicy<TestContext> defaultPolicy = new TestPolicy<TestContext>();
    TestContext context = new TestContext();
    EpsilonGreedyExplorer<TestContext> greedyExplorer =
        new EpsilonGreedyExplorer<TestContext>(defaultPolicy, explorationRate, actionCount);

    EpsilonGreedyWithContext(actionCount, context, defaultPolicy, greedyExplorer);
}
/// <summary>
/// Builds an epsilon-greedy explorer (epsilon 0) without an action count and pairs it with a
/// <c>TestVarContext</c> constructed with 10 actions; the assertions are performed by
/// <c>EpsilonGreedyWithContext</c>.
/// </summary>
[TestMethod]
public void EpsilonGreedyFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    float explorationRate = 0f;

    TestPolicy<TestVarContext> defaultPolicy = new TestPolicy<TestVarContext>();
    TestVarContext context = new TestVarContext(actionCount);
    EpsilonGreedyExplorer<TestVarContext> greedyExplorer =
        new EpsilonGreedyExplorer<TestVarContext>(defaultPolicy, explorationRate);

    EpsilonGreedyWithContext(actionCount, context, defaultPolicy, greedyExplorer);
}
private static void EpsilonGreedyWithContext<TContext>(uint numActions, TContext testContext, TestPolicy<TContext> policy, IExplorer<TContext> explorer)
where TContext : TestContext
{
string uniqueKey = "ManagedTestId";
TestRecorder<TContext> recorder = new TestRecorder<TContext>();
MwtExplorer<TContext> mwtt = new MwtExplorer<TContext>("mwt", recorder);
testContext.Id = 100;
uint expectedAction = policy.ChooseAction(testContext);
uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
Assert.AreEqual(expectedAction, chosenAction);
chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
Assert.AreEqual(expectedAction, chosenAction);
var interactions = recorder.GetAllInteractions();
Assert.AreEqual(2, interactions.Count);
Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
// Verify that policy action is chosen all the time
explorer.EnableExplore(false);
for (int i = 0; i < 1000; i++)
{
chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
Assert.AreEqual(expectedAction, chosenAction);
}
uint numActions = 10;
float epsilon = 0f;
var policy = new TestPolicy<TestContext>();
var testContext = new TestContext();
var explorer = new EpsilonGreedyExplorer<TestContext>(policy, epsilon, numActions);
EpsilonGreedyWithContext(numActions, testContext, policy, explorer);
}
/// <summary>
/// Builds an epsilon-greedy explorer (epsilon 0) without an action count and pairs it with a
/// <c>TestVarContext</c> constructed with 10 actions; the assertions are performed by
/// <c>EpsilonGreedyWithContext</c>.
/// </summary>
[TestMethod]
public void EpsilonGreedyFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    float explorationRate = 0f;

    TestPolicy<TestVarContext> defaultPolicy = new TestPolicy<TestVarContext>();
    TestVarContext context = new TestVarContext(actionCount);
    EpsilonGreedyExplorer<TestVarContext> greedyExplorer =
        new EpsilonGreedyExplorer<TestVarContext>(defaultPolicy, explorationRate);

    EpsilonGreedyWithContext(actionCount, context, defaultPolicy, greedyExplorer);
}
/// <summary>
/// Shared assertions for the epsilon-greedy tests: two decisions on the same key and context
/// must equal the policy's own choice and be recorded, and after <c>EnableExplore(false)</c>
/// a further 1000 decisions must still equal the policy's choice.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Action count used by the caller; not read by this helper.</param>
/// <param name="testContext">Context passed to every decision; its <c>Id</c> is set to 100 here.</param>
/// <param name="policy">Policy whose chosen action is the expected result.</param>
/// <param name="explorer">Explorer under test.</param>
private static void EpsilonGreedyWithContext<TContext>(uint numActions, TContext testContext, TestPolicy<TContext> policy, IExplorer<TContext> explorer)
    where TContext : TestContext
{
    const string decisionKey = "ManagedTestId";
    var log = new TestRecorder<TContext>();
    var mwt = new MwtExplorer<TContext>("mwt", log);

    testContext.Id = 100;
    uint policyAction = policy.ChooseAction(testContext);

    // Two consecutive decisions, each checked against the policy's action.
    for (int attempt = 0; attempt < 2; attempt++)
    {
        uint picked = mwt.ChooseAction(explorer, decisionKey, testContext);
        Assert.AreEqual(policyAction, picked);
    }

    // Both decisions above are expected to have been recorded, the first with our context id.
    var recorded = log.GetAllInteractions();
    Assert.AreEqual(2, recorded.Count);
    Assert.AreEqual(testContext.Id, recorded[0].Context.Id);

    // Verify that policy action is chosen all the time
    explorer.EnableExplore(false);
    int remaining = 1000;
    while (remaining > 0)
    {
        uint picked = mwt.ChooseAction(explorer, decisionKey, testContext);
        Assert.AreEqual(policyAction, picked);
        remaining--;
    }
}
[TestMethod]
public void TauFirst()
{
uint numActions = 10;
uint tau = 0;
TestContext testContext = new TestContext() { Id = 100 };
var policy = new TestPolicy<TestContext>();
var explorer = new TauFirstExplorer<TestContext>(policy, tau, numActions);
uint numActions = 10;
uint tau = 0;
TestContext testContext = new TestContext() { Id = 100 };
var policy = new TestPolicy<TestContext>();
var explorer = new TauFirstExplorer<TestContext>(policy, tau, numActions);
TauFirstWithContext(numActions, testContext, policy, explorer);
}
/// <summary>
/// Builds a tau-first explorer (tau 0) without an action count and pairs it with a
/// <c>TestVarContext</c> constructed with 10 actions and id 100; the assertions are
/// performed by <c>TauFirstWithContext</c>.
/// </summary>
[TestMethod]
public void TauFirstFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    uint explorationRounds = 0;

    TestVarContext context = new TestVarContext(actionCount);
    context.Id = 100;

    TestPolicy<TestVarContext> defaultPolicy = new TestPolicy<TestVarContext>();
    TauFirstExplorer<TestVarContext> tauExplorer =
        new TauFirstExplorer<TestVarContext>(defaultPolicy, explorationRounds);

    TauFirstWithContext(actionCount, context, defaultPolicy, tauExplorer);
}
/// <summary>
/// Shared assertions for the tau-first tests: the first decision must equal the policy's
/// choice and leave the recorder empty, and after <c>EnableExplore(false)</c> a further
/// 1000 decisions must still equal the policy's choice.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Action count used by the caller; not read by this helper.</param>
/// <param name="testContext">Context passed to every decision.</param>
/// <param name="policy">Policy whose chosen action is the expected result.</param>
/// <param name="explorer">Explorer under test.</param>
private static void TauFirstWithContext<TContext>(uint numActions, TContext testContext, TestPolicy<TContext> policy, IExplorer<TContext> explorer)
    where TContext : TestContext
{
    const string decisionKey = "ManagedTestId";
    var log = new TestRecorder<TContext>();
    var mwt = new MwtExplorer<TContext>("mwt", log);

    uint policyAction = policy.ChooseAction(testContext);

    uint picked = mwt.ChooseAction(explorer, decisionKey, testContext);
    Assert.AreEqual(policyAction, picked);

    // The test expects nothing to have been recorded for the decision above.
    var recorded = log.GetAllInteractions();
    Assert.AreEqual(0, recorded.Count);

    // Verify that policy action is chosen all the time
    explorer.EnableExplore(false);
    int remaining = 1000;
    while (remaining > 0)
    {
        picked = mwt.ChooseAction(explorer, decisionKey, testContext);
        Assert.AreEqual(policyAction, picked);
        remaining--;
    }
}
/// <summary>
/// Builds a tau-first explorer (tau 0) without an action count and pairs it with a
/// <c>TestVarContext</c> constructed with 10 actions and id 100; the assertions are
/// performed by <c>TauFirstWithContext</c>.
/// </summary>
[TestMethod]
public void TauFirstFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    uint explorationRounds = 0;

    TestVarContext context = new TestVarContext(actionCount);
    context.Id = 100;

    TestPolicy<TestVarContext> defaultPolicy = new TestPolicy<TestVarContext>();
    TauFirstExplorer<TestVarContext> tauExplorer =
        new TauFirstExplorer<TestVarContext>(defaultPolicy, explorationRounds);

    TauFirstWithContext(actionCount, context, defaultPolicy, tauExplorer);
}
/// <summary>
/// Shared assertions for the tau-first tests: the first decision must equal the policy's
/// choice and leave the recorder empty, and after <c>EnableExplore(false)</c> a further
/// 1000 decisions must still equal the policy's choice.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Action count used by the caller; not read by this helper.</param>
/// <param name="testContext">Context passed to every decision.</param>
/// <param name="policy">Policy whose chosen action is the expected result.</param>
/// <param name="explorer">Explorer under test.</param>
private static void TauFirstWithContext<TContext>(uint numActions, TContext testContext, TestPolicy<TContext> policy, IExplorer<TContext> explorer)
    where TContext : TestContext
{
    const string decisionKey = "ManagedTestId";
    var log = new TestRecorder<TContext>();
    var mwt = new MwtExplorer<TContext>("mwt", log);

    uint policyAction = policy.ChooseAction(testContext);

    uint picked = mwt.ChooseAction(explorer, decisionKey, testContext);
    Assert.AreEqual(policyAction, picked);

    // The test expects nothing to have been recorded for the decision above.
    var recorded = log.GetAllInteractions();
    Assert.AreEqual(0, recorded.Count);

    // Verify that policy action is chosen all the time
    explorer.EnableExplore(false);
    int remaining = 1000;
    while (remaining > 0)
    {
        picked = mwt.ChooseAction(explorer, decisionKey, testContext);
        Assert.AreEqual(policyAction, picked);
        remaining--;
    }
}
[TestMethod]
public void Bootstrap()
{
uint numActions = 10;
uint numbags = 2;
TestContext testContext1 = new TestContext() { Id = 99 };
TestContext testContext2 = new TestContext() { Id = 100 };
var policies = new TestPolicy<TestContext>[numbags];
for (int i = 0; i < numbags; i++)
{
policies[i] = new TestPolicy<TestContext>(i * 2);
}
var explorer = new BootstrapExplorer<TestContext>(policies, numActions);
uint numActions = 10;
uint numbags = 2;
TestContext testContext1 = new TestContext() { Id = 99 };
TestContext testContext2 = new TestContext() { Id = 100 };
var policies = new TestPolicy<TestContext>[numbags];
for (int i = 0; i < numbags; i++)
{
policies[i] = new TestPolicy<TestContext>(i * 2);
}
var explorer = new BootstrapExplorer<TestContext>(policies, numActions);
BootstrapWithContext(numActions, testContext1, testContext2, policies, explorer);
}
/// <summary>
/// Builds a bootstrap explorer over two test policies without an action count and pairs it
/// with two <c>TestVarContext</c> instances (ids 99 and 100, 10 actions each); the assertions
/// are performed by <c>BootstrapWithContext</c>.
/// </summary>
[TestMethod]
public void BootstrapFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    uint bagCount = 2;

    TestVarContext firstContext = new TestVarContext(actionCount);
    firstContext.Id = 99;
    TestVarContext secondContext = new TestVarContext(actionCount);
    secondContext.Id = 100;

    // Policy at position k is constructed with index 2 * k.
    TestPolicy<TestVarContext>[] bags = new TestPolicy<TestVarContext>[bagCount];
    for (int bag = 0; bag < bags.Length; bag++)
    {
        bags[bag] = new TestPolicy<TestVarContext>(bag * 2);
    }

    BootstrapExplorer<TestVarContext> bootstrap = new BootstrapExplorer<TestVarContext>(bags);

    BootstrapWithContext(actionCount, firstContext, secondContext, bags, bootstrap);
}
/// <summary>
/// Shared assertions for the bootstrap tests: one decision per context must equal the action
/// the first policy picks for <paramref name="testContext1"/>, both decisions must be recorded
/// in order, and after <c>EnableExplore(false)</c> a further 1000 decisions on the first
/// context must still equal that action.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Action count used by the caller; not read by this helper.</param>
/// <param name="testContext1">Context of the first decision and of the 1000 follow-up decisions.</param>
/// <param name="testContext2">Context of the second decision.</param>
/// <param name="policies">Policies given to the explorer; only element 0 is consulted here.</param>
/// <param name="explorer">Explorer under test.</param>
private static void BootstrapWithContext<TContext>(uint numActions, TContext testContext1, TContext testContext2, TestPolicy<TContext>[] policies, IExplorer<TContext> explorer)
    where TContext : TestContext
{
    const string decisionKey = "ManagedTestId";
    var log = new TestRecorder<TContext>();
    var mwt = new MwtExplorer<TContext>("mwt", log);

    uint policyAction = policies[0].ChooseAction(testContext1);

    uint picked = mwt.ChooseAction(explorer, decisionKey, testContext1);
    Assert.AreEqual(policyAction, picked);

    picked = mwt.ChooseAction(explorer, decisionKey, testContext2);
    Assert.AreEqual(policyAction, picked);

    // Both decisions must be in the log, in the order they were made.
    var recorded = log.GetAllInteractions();
    Assert.AreEqual(2, recorded.Count);
    Assert.AreEqual(testContext1.Id, recorded[0].Context.Id);
    Assert.AreEqual(testContext2.Id, recorded[1].Context.Id);

    // Verify that policy action is chosen all the time
    explorer.EnableExplore(false);
    int remaining = 1000;
    while (remaining > 0)
    {
        picked = mwt.ChooseAction(explorer, decisionKey, testContext1);
        Assert.AreEqual(policyAction, picked);
        remaining--;
    }
}
/// <summary>
/// Builds a bootstrap explorer over two test policies without an action count and pairs it
/// with two <c>TestVarContext</c> instances (ids 99 and 100, 10 actions each); the assertions
/// are performed by <c>BootstrapWithContext</c>.
/// </summary>
[TestMethod]
public void BootstrapFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    uint bagCount = 2;

    TestVarContext firstContext = new TestVarContext(actionCount);
    firstContext.Id = 99;
    TestVarContext secondContext = new TestVarContext(actionCount);
    secondContext.Id = 100;

    // Policy at position k is constructed with index 2 * k.
    TestPolicy<TestVarContext>[] bags = new TestPolicy<TestVarContext>[bagCount];
    for (int bag = 0; bag < bags.Length; bag++)
    {
        bags[bag] = new TestPolicy<TestVarContext>(bag * 2);
    }

    BootstrapExplorer<TestVarContext> bootstrap = new BootstrapExplorer<TestVarContext>(bags);

    BootstrapWithContext(actionCount, firstContext, secondContext, bags, bootstrap);
}
/// <summary>
/// Shared assertions for the bootstrap tests: one decision per context must equal the action
/// the first policy picks for <paramref name="testContext1"/>, both decisions must be recorded
/// in order, and after <c>EnableExplore(false)</c> a further 1000 decisions on the first
/// context must still equal that action.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Action count used by the caller; not read by this helper.</param>
/// <param name="testContext1">Context of the first decision and of the 1000 follow-up decisions.</param>
/// <param name="testContext2">Context of the second decision.</param>
/// <param name="policies">Policies given to the explorer; only element 0 is consulted here.</param>
/// <param name="explorer">Explorer under test.</param>
private static void BootstrapWithContext<TContext>(uint numActions, TContext testContext1, TContext testContext2, TestPolicy<TContext>[] policies, IExplorer<TContext> explorer)
    where TContext : TestContext
{
    const string decisionKey = "ManagedTestId";
    var log = new TestRecorder<TContext>();
    var mwt = new MwtExplorer<TContext>("mwt", log);

    uint policyAction = policies[0].ChooseAction(testContext1);

    uint picked = mwt.ChooseAction(explorer, decisionKey, testContext1);
    Assert.AreEqual(policyAction, picked);

    picked = mwt.ChooseAction(explorer, decisionKey, testContext2);
    Assert.AreEqual(policyAction, picked);

    // Both decisions must be in the log, in the order they were made.
    var recorded = log.GetAllInteractions();
    Assert.AreEqual(2, recorded.Count);
    Assert.AreEqual(testContext1.Id, recorded[0].Context.Id);
    Assert.AreEqual(testContext2.Id, recorded[1].Context.Id);

    // Verify that policy action is chosen all the time
    explorer.EnableExplore(false);
    int remaining = 1000;
    while (remaining > 0)
    {
        picked = mwt.ChooseAction(explorer, decisionKey, testContext1);
        Assert.AreEqual(policyAction, picked);
        remaining--;
    }
}
[TestMethod]
public void Softmax()
{
uint numActions = 10;
float lambda = 0.5f;
uint numActionsCover = 100;
float C = 5;
var scorer = new TestScorer<TestContext>(numActions);
var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);
uint numDecisions = (uint)(numActions * Math.Log(numActions * 1.0) + Math.Log(numActionsCover * 1.0 / numActions) * C * numActions);
var contexts = new TestContext[numDecisions];
for (int i = 0; i < numDecisions; i++)
{
contexts[i] = new TestContext { Id = i };
}
{
uint numActions = 10;
float lambda = 0.5f;
uint numActionsCover = 100;
float C = 5;
var scorer = new TestScorer<TestContext>(numActions);
var explorer = new SoftmaxExplorer<TestContext>(scorer, lambda, numActions);
uint numDecisions = (uint)(numActions * Math.Log(numActions * 1.0) + Math.Log(numActionsCover * 1.0 / numActions) * C * numActions);
var contexts = new TestContext[numDecisions];
for (int i = 0; i < numDecisions; i++)
{
contexts[i] = new TestContext { Id = i };
}
SoftmaxWithContext(numActions, explorer, contexts);
}
/// <summary>
/// Builds a softmax explorer (lambda 0.5) without an action count and a batch of
/// <c>TestVarContext</c> instances (10 actions each, ids 0..n-1); the assertions are
/// performed by <c>SoftmaxWithContext</c>.
/// </summary>
[TestMethod]
public void SoftmaxFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    float lambda = 0.5f;
    uint coverCount = 100;
    float coverFactor = 5;

    TestScorer<TestVarContext> scorer = new TestScorer<TestVarContext>(actionCount);
    SoftmaxExplorer<TestVarContext> softmax = new SoftmaxExplorer<TestVarContext>(scorer, lambda);

    // Batch size: n*ln(n) + ln(cover/n)*factor*n, truncated to an unsigned integer.
    // NOTE(review): presumably chosen so every action is likely to be drawn at least once - confirm.
    uint decisionCount = (uint)(actionCount * Math.Log(actionCount * 1.0) + Math.Log(coverCount * 1.0 / actionCount) * coverFactor * actionCount);

    TestVarContext[] batch = new TestVarContext[decisionCount];
    for (int id = 0; id < batch.Length; id++)
    {
        TestVarContext context = new TestVarContext(actionCount);
        context.Id = id;
        batch[id] = context;
    }

    SoftmaxWithContext(actionCount, softmax, batch);
}
/// <summary>
/// Shared assertions for the softmax tests: makes one decision per context under a random
/// key, requires every action id in 1..<paramref name="numActions"/> to have been chosen at
/// least once, and requires the recorder to hold one interaction per context whose context
/// id equals its position.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Number of distinct action ids to tally.</param>
/// <param name="explorer">Explorer under test.</param>
/// <param name="contexts">Contexts to decide on; element k is expected to have <c>Id == k</c>.</param>
private static void SoftmaxWithContext<TContext>(uint numActions, IExplorer<TContext> explorer, TContext[] contexts)
    where TContext : TestContext
{
    var log = new TestRecorder<TContext>();
    var mwt = new MwtExplorer<TContext>("mwt", log);

    uint[] hitsPerAction = new uint[numActions];
    Random keySource = new Random();

    foreach (TContext context in contexts)
    {
        string decisionKey = keySource.NextDouble().ToString();
        uint picked = mwt.ChooseAction(explorer, decisionKey, context);
        hitsPerAction[picked - 1]++; // action id is one-based
    }

    // Every action must have been picked at least once.
    foreach (uint hits in hitsPerAction)
    {
        Assert.IsTrue(hits > 0);
    }

    var recorded = log.GetAllInteractions();
    Assert.AreEqual(contexts.Length, recorded.Count);

    int position = 0;
    while (position < contexts.Length)
    {
        Assert.AreEqual(position, recorded[position].Context.Id);
        position++;
    }
}
/// <summary>
/// Builds a softmax explorer (lambda 0.5) without an action count and a batch of
/// <c>TestVarContext</c> instances (10 actions each, ids 0..n-1); the assertions are
/// performed by <c>SoftmaxWithContext</c>.
/// </summary>
[TestMethod]
public void SoftmaxFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;
    float lambda = 0.5f;
    uint coverCount = 100;
    float coverFactor = 5;

    TestScorer<TestVarContext> scorer = new TestScorer<TestVarContext>(actionCount);
    SoftmaxExplorer<TestVarContext> softmax = new SoftmaxExplorer<TestVarContext>(scorer, lambda);

    // Batch size: n*ln(n) + ln(cover/n)*factor*n, truncated to an unsigned integer.
    // NOTE(review): presumably chosen so every action is likely to be drawn at least once - confirm.
    uint decisionCount = (uint)(actionCount * Math.Log(actionCount * 1.0) + Math.Log(coverCount * 1.0 / actionCount) * coverFactor * actionCount);

    TestVarContext[] batch = new TestVarContext[decisionCount];
    for (int id = 0; id < batch.Length; id++)
    {
        TestVarContext context = new TestVarContext(actionCount);
        context.Id = id;
        batch[id] = context;
    }

    SoftmaxWithContext(actionCount, softmax, batch);
}
/// <summary>
/// Shared assertions for the softmax tests: makes one decision per context under a random
/// key, requires every action id in 1..<paramref name="numActions"/> to have been chosen at
/// least once, and requires the recorder to hold one interaction per context whose context
/// id equals its position.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Number of distinct action ids to tally.</param>
/// <param name="explorer">Explorer under test.</param>
/// <param name="contexts">Contexts to decide on; element k is expected to have <c>Id == k</c>.</param>
private static void SoftmaxWithContext<TContext>(uint numActions, IExplorer<TContext> explorer, TContext[] contexts)
    where TContext : TestContext
{
    var log = new TestRecorder<TContext>();
    var mwt = new MwtExplorer<TContext>("mwt", log);

    uint[] hitsPerAction = new uint[numActions];
    Random keySource = new Random();

    foreach (TContext context in contexts)
    {
        string decisionKey = keySource.NextDouble().ToString();
        uint picked = mwt.ChooseAction(explorer, decisionKey, context);
        hitsPerAction[picked - 1]++; // action id is one-based
    }

    // Every action must have been picked at least once.
    foreach (uint hits in hitsPerAction)
    {
        Assert.IsTrue(hits > 0);
    }

    var recorded = log.GetAllInteractions();
    Assert.AreEqual(contexts.Length, recorded.Count);

    int position = 0;
    while (position < contexts.Length)
    {
        Assert.AreEqual(position, recorded[position].Context.Id);
        position++;
    }
}
[TestMethod]
@ -276,125 +276,125 @@ namespace ExploreTests
// Scores are not equal therefore probabilities should not be uniform
Assert.AreNotEqual(interactions[i].Probability, 1.0f / numActions);
Assert.AreEqual(100 + i, interactions[i].Context.Id);
}
// Verify that policy action is chosen all the time
TestContext context = new TestContext { Id = 100 };
List<float> scores = scorer.ScoreActions(context);
float maxScore = 0;
uint highestScoreAction = 0;
for (int i = 0; i < scores.Count; i++)
{
if (maxScore < scores[i])
{
maxScore = scores[i];
highestScoreAction = (uint)i + 1;
}
}
explorer.EnableExplore(false);
for (int i = 0; i < 1000; i++)
{
uint chosenAction = mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = (int)i });
Assert.AreEqual(highestScoreAction, chosenAction);
}
// Verify that policy action is chosen all the time
TestContext context = new TestContext { Id = 100 };
List<float> scores = scorer.ScoreActions(context);
float maxScore = 0;
uint highestScoreAction = 0;
for (int i = 0; i < scores.Count; i++)
{
if (maxScore < scores[i])
{
maxScore = scores[i];
highestScoreAction = (uint)i + 1;
}
}
explorer.EnableExplore(false);
for (int i = 0; i < 1000; i++)
{
uint chosenAction = mwtt.ChooseAction(explorer, rand.NextDouble().ToString(), new TestContext() { Id = (int)i });
Assert.AreEqual(highestScoreAction, chosenAction);
}
}
[TestMethod]
public void Generic()
{
uint numActions = 10;
TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions);
TestContext testContext = new TestContext() { Id = 100 };
var explorer = new GenericExplorer<TestContext>(scorer, numActions);
uint numActions = 10;
TestScorer<TestContext> scorer = new TestScorer<TestContext>(numActions);
TestContext testContext = new TestContext() { Id = 100 };
var explorer = new GenericExplorer<TestContext>(scorer, numActions);
GenericWithContext(numActions, testContext, explorer);
}
/// <summary>
/// Builds a generic explorer without an action count and pairs it with a
/// <c>TestVarContext</c> constructed with 10 actions and id 100; the assertions are
/// performed by <c>GenericWithContext</c>.
/// </summary>
[TestMethod]
public void GenericFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;

    TestScorer<TestVarContext> scorer = new TestScorer<TestVarContext>(actionCount);

    TestVarContext context = new TestVarContext(actionCount);
    context.Id = 100;

    GenericExplorer<TestVarContext> generic = new GenericExplorer<TestVarContext>(scorer);

    GenericWithContext(actionCount, context, generic);
}
/// <summary>
/// Shared assertions for the generic-explorer tests: makes a single decision and requires the
/// recorder to hold exactly one interaction carrying the supplied context's id.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Action count used by the caller; not read by this helper.</param>
/// <param name="testContext">Context passed to the decision.</param>
/// <param name="explorer">Explorer under test.</param>
private static void GenericWithContext<TContext>(uint numActions, TContext testContext, IExplorer<TContext> explorer)
    where TContext : TestContext
{
    string uniqueKey = "ManagedTestId";
    var recorder = new TestRecorder<TContext>();
    var mwtt = new MwtExplorer<TContext>("mwt", recorder);

    // The chosen action itself is not asserted on; the call is made for the interaction it records.
    mwtt.ChooseAction(explorer, uniqueKey, testContext);

    var interactions = recorder.GetAllInteractions();
    Assert.AreEqual(1, interactions.Count);
    Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
}
/// <summary>
/// Verifies that each of the five explorer kinds, when constructed without an action count and
/// asked to decide on a plain <c>TestContext</c>, makes <c>ChooseAction</c> throw an
/// <see cref="ArgumentException"/> whose <c>ParamName</c> is "ctx" (compared case-insensitively).
/// </summary>
[TestMethod]
public void UsageBadVariableActionContext()
{
    int numExceptionsCaught = 0;
    int numExceptionsExpected = 5;

    // Runs one scenario and counts it only if it fails with an ArgumentException naming "ctx".
    Action<Action> tryCatchArgumentException = action =>
    {
        try
        {
            action();
        }
        catch (ArgumentException ex)
        {
            // ParamName is null when the exception was created without a parameter name;
            // string.Equals handles that without throwing, so a wrong exception shows up as a
            // count mismatch below instead of a NullReferenceException raised from this handler.
            if (string.Equals(ex.ParamName, "ctx", StringComparison.OrdinalIgnoreCase))
            {
                numExceptionsCaught++;
            }
        }
    };

    // Epsilon-greedy explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var policy = new TestPolicy<TestContext>();
        var explorer = new EpsilonGreedyExplorer<TestContext>(policy, 0.2f);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Tau-first explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var policy = new TestPolicy<TestContext>();
        var explorer = new TauFirstExplorer<TestContext>(policy, 10);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Bootstrap explorer over two policies.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var policies = new TestPolicy<TestContext>[2];
        for (int i = 0; i < 2; i++)
        {
            policies[i] = new TestPolicy<TestContext>(i * 2);
        }
        var explorer = new BootstrapExplorer<TestContext>(policies);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Softmax explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var scorer = new TestScorer<TestContext>(10);
        var explorer = new SoftmaxExplorer<TestContext>(scorer, 0.5f);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Generic explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var scorer = new TestScorer<TestContext>(10);
        var explorer = new GenericExplorer<TestContext>(scorer);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    Assert.AreEqual(numExceptionsExpected, numExceptionsCaught);
}
/// <summary>
/// Builds a generic explorer without an action count and pairs it with a
/// <c>TestVarContext</c> constructed with 10 actions and id 100; the assertions are
/// performed by <c>GenericWithContext</c>.
/// </summary>
[TestMethod]
public void GenericFixedActionUsingVariableActionInterface()
{
    uint actionCount = 10;

    TestScorer<TestVarContext> scorer = new TestScorer<TestVarContext>(actionCount);

    TestVarContext context = new TestVarContext(actionCount);
    context.Id = 100;

    GenericExplorer<TestVarContext> generic = new GenericExplorer<TestVarContext>(scorer);

    GenericWithContext(actionCount, context, generic);
}
/// <summary>
/// Shared assertions for the generic-explorer tests: makes a single decision and requires the
/// recorder to hold exactly one interaction carrying the supplied context's id.
/// </summary>
/// <typeparam name="TContext">Context type; must derive from <c>TestContext</c>.</typeparam>
/// <param name="numActions">Action count used by the caller; not read by this helper.</param>
/// <param name="testContext">Context passed to the decision.</param>
/// <param name="explorer">Explorer under test.</param>
private static void GenericWithContext<TContext>(uint numActions, TContext testContext, IExplorer<TContext> explorer)
    where TContext : TestContext
{
    string uniqueKey = "ManagedTestId";
    var recorder = new TestRecorder<TContext>();
    var mwtt = new MwtExplorer<TContext>("mwt", recorder);

    // The chosen action itself is not asserted on; the call is made for the interaction it records.
    mwtt.ChooseAction(explorer, uniqueKey, testContext);

    var interactions = recorder.GetAllInteractions();
    Assert.AreEqual(1, interactions.Count);
    Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
}
/// <summary>
/// Verifies that each of the five explorer kinds, when constructed without an action count and
/// asked to decide on a plain <c>TestContext</c>, makes <c>ChooseAction</c> throw an
/// <see cref="ArgumentException"/> whose <c>ParamName</c> is "ctx" (compared case-insensitively).
/// </summary>
[TestMethod]
public void UsageBadVariableActionContext()
{
    int numExceptionsCaught = 0;
    int numExceptionsExpected = 5;

    // Runs one scenario and counts it only if it fails with an ArgumentException naming "ctx".
    Action<Action> tryCatchArgumentException = action =>
    {
        try
        {
            action();
        }
        catch (ArgumentException ex)
        {
            // ParamName is null when the exception was created without a parameter name;
            // string.Equals handles that without throwing, so a wrong exception shows up as a
            // count mismatch below instead of a NullReferenceException raised from this handler.
            if (string.Equals(ex.ParamName, "ctx", StringComparison.OrdinalIgnoreCase))
            {
                numExceptionsCaught++;
            }
        }
    };

    // Epsilon-greedy explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var policy = new TestPolicy<TestContext>();
        var explorer = new EpsilonGreedyExplorer<TestContext>(policy, 0.2f);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Tau-first explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var policy = new TestPolicy<TestContext>();
        var explorer = new TauFirstExplorer<TestContext>(policy, 10);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Bootstrap explorer over two policies.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var policies = new TestPolicy<TestContext>[2];
        for (int i = 0; i < 2; i++)
        {
            policies[i] = new TestPolicy<TestContext>(i * 2);
        }
        var explorer = new BootstrapExplorer<TestContext>(policies);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Softmax explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var scorer = new TestScorer<TestContext>(10);
        var explorer = new SoftmaxExplorer<TestContext>(scorer, 0.5f);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    // Generic explorer.
    tryCatchArgumentException(() =>
    {
        var mwt = new MwtExplorer<TestContext>("test", new TestRecorder<TestContext>());
        var scorer = new TestScorer<TestContext>(10);
        var explorer = new GenericExplorer<TestContext>(scorer);
        mwt.ChooseAction(explorer, "key", new TestContext());
    });

    Assert.AreEqual(numExceptionsExpected, numExceptionsCaught);
}
[TestInitialize]
@ -425,21 +425,21 @@ namespace ExploreTests
get { return id; }
set { id = value; }
}
}
/// <summary>
/// Test context that also reports how many actions are available for it.
/// </summary>
class TestVarContext : TestContext, IVariableActionContext
{
    /// <summary>Number of actions reported by <see cref="GetNumberOfActions"/>.</summary>
    public uint NumberOfActions { get; set; }

    /// <summary>Creates a context that reports the given number of actions.</summary>
    /// <param name="numberOfActions">Initial value of <see cref="NumberOfActions"/>.</param>
    public TestVarContext(uint numberOfActions)
    {
        this.NumberOfActions = numberOfActions;
    }

    /// <summary>Returns the current value of <see cref="NumberOfActions"/>.</summary>
    public uint GetNumberOfActions()
    {
        return this.NumberOfActions;
    }
}
/// <summary>
/// Test context that also reports how many actions are available for it.
/// </summary>
class TestVarContext : TestContext, IVariableActionContext
{
    /// <summary>Number of actions reported by <see cref="GetNumberOfActions"/>.</summary>
    public uint NumberOfActions { get; set; }

    /// <summary>Creates a context that reports the given number of actions.</summary>
    /// <param name="numberOfActions">Initial value of <see cref="NumberOfActions"/>.</param>
    public TestVarContext(uint numberOfActions)
    {
        this.NumberOfActions = numberOfActions;
    }

    /// <summary>Returns the current value of <see cref="NumberOfActions"/>.</summary>
    public uint GetNumberOfActions()
    {
        return this.NumberOfActions;
    }
}
class TestRecorder<Ctx> : IRecorder<Ctx>
@ -461,8 +461,8 @@ namespace ExploreTests
}
private List<TestInteraction<Ctx>> interactions = new List<TestInteraction<Ctx>>();
}
}
class TestPolicy<TContext> : IPolicy<TContext>
{
public TestPolicy() : this(-1) { }
@ -470,8 +470,8 @@ namespace ExploreTests
public TestPolicy(int index)
{
this.index = index;
}
}
public uint ChooseAction(TContext context)
{
return 5;

Просмотреть файл

Просмотреть файл

@ -1,182 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{8400DA16-1F46-4A31-A126-BBE16F62BFD7}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>vw_explore_clr_wrapper</RootNamespace>
<ProjectName>explore_clr</ProjectName>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
<CLRSupport>true</CLRSupport>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
<CLRSupport>true</CLRSupport>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
<CLRSupport>true</CLRSupport>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
<CLRSupport>true</CLRSupport>
</PropertyGroup>
<PropertyGroup Condition="'$(Platform)'=='x64'">
<BoostIncludeDir>c:\boost\x64\include\boost-1_56</BoostIncludeDir>
<BoostLibDir>c:\boost\x64\lib</BoostLibDir>
<ZlibIncludeDir>..\..\..\zlib-1.2.8</ZlibIncludeDir>
<ZlibLibDir>$(ZlibIncludeDir)\contrib\vstudio\vc11\x64\ZlibStat$(Configuration)</ZlibLibDir>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\static;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<GenerateXMLDocumentationFiles>true</GenerateXMLDocumentationFiles>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\static;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<GenerateXMLDocumentationFiles>true</GenerateXMLDocumentationFiles>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\static;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<GenerateXMLDocumentationFiles>true</GenerateXMLDocumentationFiles>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_WINDOWS;_USRDLL;VW_EXPLORE_CLR_WRAPPER_EXPORTS;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<SDLCheck>true</SDLCheck>
<AdditionalIncludeDirectories>..\static;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<GenerateXMLDocumentationFiles>true</GenerateXMLDocumentationFiles>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<Text Include="ReadMe.txt" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="explore_interface.h" />
<ClInclude Include="explore_interop.h" />
<ClInclude Include="explore_clr_wrapper.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="explore_clr_wrapper.cpp" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\static\explore_static.vcxproj">
<Project>{ace47e98-488c-4cdf-b9f1-36337b2855ad}</Project>
</ProjectReference>
</ItemGroup>
<ItemGroup>
<Reference Include="System.Xml" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

Просмотреть файл

@ -1,36 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hpp;hxx;hm;inl;inc;xsd</Extensions>
</Filter>
<Filter Include="Header Files\Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<Text Include="ReadMe.txt" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="explore_clr_wrapper.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="explore_interop.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="explore_interface.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="explore_clr_wrapper.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
</Project>

Просмотреть файл

@ -1,18 +0,0 @@
// vw_explore_clr_wrapper.cpp : Defines the exported functions for the DLL application.
//
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include "explore_clr_wrapper.h"
using namespace System;
using namespace System::Collections;
using namespace System::Collections::Generic;
using namespace System::Runtime::InteropServices;
using namespace msclr::interop;
using namespace NativeMultiWorldTesting;
namespace MultiWorldTesting {
}

Просмотреть файл

@ -1,591 +0,0 @@
#pragma once
#include "explore_interop.h"
/*!
* \addtogroup MultiWorldTestingCsharp
* @{
*/
namespace MultiWorldTesting {
/// <summary>
/// The epsilon greedy exploration class.
/// </summary>
/// <remarks>
/// This is a good choice if you have no idea which actions should be preferred.
/// Epsilon greedy is also computationally cheap.
/// </remarks>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public ref class EpsilonGreedyExplorer : public IExplorer<Ctx>, public IConsumePolicy<Ctx>, public PolicyCallback<Ctx>
{
public:
    /// <summary>
    /// Initializes an epsilon greedy explorer with a fixed number of actions.
    /// Instances are meant to be used with the MwtExplorer.
    /// </summary>
    /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
    /// <param name="epsilon">The probability of a random exploration.</param>
    /// <param name="numActions">The number of actions to randomize over.</param>
    EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon, UInt32 numActions)
    {
        this->defaultPolicy = defaultPolicy;
        m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon, (u32)numActions);
    }

    /// <summary>
    /// Initializes an epsilon greedy explorer with variable number of actions.
    /// </summary>
    /// <param name="defaultPolicy">A default function which outputs an action given a context.</param>
    /// <param name="epsilon">The probability of a random exploration.</param>
    /// <exception cref="ArgumentException">Ctx does not implement IVariableActionContext.</exception>
    EpsilonGreedyExplorer(IPolicy<Ctx>^ defaultPolicy, float epsilon)
    {
        if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
        {
            throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
        }
        this->defaultPolicy = defaultPolicy;
        m_explorer = new NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>(*GetNativePolicy(), epsilon);
    }

    /// <summary>
    /// Destructor (Dispose): releases the native explorer deterministically.
    /// </summary>
    ~EpsilonGreedyExplorer()
    {
        this->!EpsilonGreedyExplorer();
    }

    /// <summary>
    /// Replaces the default policy consulted by the policy callback.
    /// </summary>
    /// <param name="newPolicy">The policy to use from now on.</param>
    virtual void UpdatePolicy(IPolicy<Ctx>^ newPolicy)
    {
        this->defaultPolicy = newPolicy;
    }

    /// <summary>
    /// Forwards the exploration on/off switch to the native explorer.
    /// </summary>
    /// <param name="explore">Value passed to the native Enable_Explore.</param>
    virtual void EnableExplore(bool explore)
    {
        m_explorer->Enable_Explore(explore);
    }

protected:
    /// <summary>
    /// Finalizer: frees the native explorer when the object was never disposed.
    /// </summary>
    !EpsilonGreedyExplorer()
    {
        // m_explorer is still null if a constructor threw; deleting null is a no-op.
        // Resetting the pointer makes a repeated dispose harmless.
        delete m_explorer;
        m_explorer = nullptr;
    }

internal:
    // Returns the action chosen by the current default policy. The index argument is
    // ignored because this explorer holds a single policy.
    virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
    {
        return defaultPolicy->ChooseAction(context);
    }

    // Exposes the native explorer so MwtExplorer::ChooseAction can pass it to native code.
    NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* Get()
    {
        return m_explorer;
    }

private:
    IPolicy<Ctx>^ defaultPolicy;
    NativeMultiWorldTesting::EpsilonGreedyExplorer<NativeContext>* m_explorer;
};
/// <summary>
/// The tau-first exploration class.
/// </summary>
/// <remarks>
/// The tau-first explorer collects precisely tau uniform random
/// exploration events, and then uses the default policy.
/// </remarks>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public ref class TauFirstExplorer : public IExplorer<Ctx>, public IConsumePolicy<Ctx>, public PolicyCallback<Ctx>
{
public:
    /// <summary>
    /// Initializes a tau-first explorer with a fixed number of actions.
    /// Instances are meant to be used with the MwtExplorer.
    /// </summary>
    /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
    /// <param name="tau">The number of events to be uniform over.</param>
    /// <param name="numActions">The number of actions to randomize over.</param>
    TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau, UInt32 numActions)
    {
        this->defaultPolicy = defaultPolicy;
        m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau, (u32)numActions);
    }

    /// <summary>
    /// Initializes a tau-first explorer with variable number of actions.
    /// </summary>
    /// <param name="defaultPolicy">A default policy after randomization finishes.</param>
    /// <param name="tau">The number of events to be uniform over.</param>
    /// <exception cref="ArgumentException">Ctx does not implement IVariableActionContext.</exception>
    TauFirstExplorer(IPolicy<Ctx>^ defaultPolicy, UInt32 tau)
    {
        if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
        {
            throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
        }
        this->defaultPolicy = defaultPolicy;
        m_explorer = new NativeMultiWorldTesting::TauFirstExplorer<NativeContext>(*GetNativePolicy(), tau);
    }

    /// <summary>
    /// Replaces the default policy consulted by the policy callback.
    /// </summary>
    /// <param name="newPolicy">The policy to use from now on.</param>
    virtual void UpdatePolicy(IPolicy<Ctx>^ newPolicy)
    {
        this->defaultPolicy = newPolicy;
    }

    /// <summary>
    /// Forwards the exploration on/off switch to the native explorer.
    /// </summary>
    /// <param name="explore">Value passed to the native Enable_Explore.</param>
    virtual void EnableExplore(bool explore)
    {
        m_explorer->Enable_Explore(explore);
    }

    /// <summary>
    /// Destructor (Dispose): releases the native explorer deterministically.
    /// </summary>
    ~TauFirstExplorer()
    {
        this->!TauFirstExplorer();
    }

protected:
    /// <summary>
    /// Finalizer: frees the native explorer when the object was never disposed.
    /// </summary>
    !TauFirstExplorer()
    {
        // m_explorer is still null if a constructor threw; deleting null is a no-op.
        // Resetting the pointer makes a repeated dispose harmless.
        delete m_explorer;
        m_explorer = nullptr;
    }

internal:
    // Returns the action chosen by the current default policy. The index argument is
    // ignored because this explorer holds a single policy.
    virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
    {
        return defaultPolicy->ChooseAction(context);
    }

    // Exposes the native explorer so MwtExplorer::ChooseAction can pass it to native code.
    NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* Get()
    {
        return m_explorer;
    }

private:
    IPolicy<Ctx>^ defaultPolicy;
    NativeMultiWorldTesting::TauFirstExplorer<NativeContext>* m_explorer;
};
/// <summary>
/// The softmax exploration class.
/// </summary>
/// <remarks>
/// In some cases, different actions have different scores, and you
/// would prefer to choose actions with large scores. Softmax allows
/// you to do that.
/// </remarks>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public ref class SoftmaxExplorer : public IExplorer<Ctx>, public IConsumeScorer<Ctx>, public ScorerCallback<Ctx>
{
public:
    /// <summary>
    /// Initializes a softmax explorer with a fixed number of actions.
    /// Instances are meant to be used with the MwtExplorer.
    /// </summary>
    /// <param name="defaultScorer">A function which outputs a score for each action.</param>
    /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
    /// <param name="numActions">The number of actions to randomize over.</param>
    SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda, UInt32 numActions)
    {
        this->defaultScorer = defaultScorer;
        m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda, (u32)numActions);
    }

    /// <summary>
    /// Initializes a softmax explorer with variable number of actions.
    /// </summary>
    /// <param name="defaultScorer">A function which outputs a score for each action.</param>
    /// <param name="lambda">lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.</param>
    /// <exception cref="ArgumentException">Ctx does not implement IVariableActionContext.</exception>
    SoftmaxExplorer(IScorer<Ctx>^ defaultScorer, float lambda)
    {
        if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
        {
            throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
        }
        this->defaultScorer = defaultScorer;
        m_explorer = new NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>(*GetNativeScorer(), lambda);
    }

    /// <summary>
    /// Replaces the default scorer consulted by the scorer callback.
    /// </summary>
    /// <param name="newScorer">The scorer to use from now on.</param>
    virtual void UpdateScorer(IScorer<Ctx>^ newScorer)
    {
        this->defaultScorer = newScorer;
    }

    /// <summary>
    /// Forwards the exploration on/off switch to the native explorer.
    /// </summary>
    /// <param name="explore">Value passed to the native Enable_Explore.</param>
    virtual void EnableExplore(bool explore)
    {
        m_explorer->Enable_Explore(explore);
    }

    /// <summary>
    /// Destructor (Dispose): releases the native explorer deterministically.
    /// </summary>
    ~SoftmaxExplorer()
    {
        this->!SoftmaxExplorer();
    }

protected:
    /// <summary>
    /// Finalizer: frees the native explorer when the object was never disposed.
    /// </summary>
    !SoftmaxExplorer()
    {
        // m_explorer is still null if a constructor threw; deleting null is a no-op.
        // Resetting the pointer makes a repeated dispose harmless.
        delete m_explorer;
        m_explorer = nullptr;
    }

internal:
    // Returns the scores produced by the current default scorer for the given context.
    virtual List<float>^ InvokeScorerCallback(Ctx context) override
    {
        return defaultScorer->ScoreActions(context);
    }

    // Exposes the native explorer so MwtExplorer::ChooseAction can pass it to native code.
    NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* Get()
    {
        return m_explorer;
    }

private:
    IScorer<Ctx>^ defaultScorer;
    NativeMultiWorldTesting::SoftmaxExplorer<NativeContext>* m_explorer;
};
/// <summary>
/// The generic exploration class.
/// </summary>
/// <remarks>
/// GenericExplorer provides complete flexibility. You can create any
/// distribution over actions desired, and it will draw from that.
/// </remarks>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public ref class GenericExplorer : public IExplorer<Ctx>, public IConsumeScorer<Ctx>, public ScorerCallback<Ctx>
{
public:
    /// <summary>
    /// Initializes a generic explorer with a fixed number of actions.
    /// Instances are meant to be used with the MwtExplorer.
    /// </summary>
    /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
    /// <param name="numActions">The number of actions to randomize over.</param>
    GenericExplorer(IScorer<Ctx>^ defaultScorer, UInt32 numActions)
    {
        this->defaultScorer = defaultScorer;
        m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer(), (u32)numActions);
    }

    /// <summary>
    /// Initializes a generic explorer with variable number of actions.
    /// </summary>
    /// <param name="defaultScorer">A function which outputs the probability of each action.</param>
    /// <exception cref="ArgumentException">Ctx does not implement IVariableActionContext.</exception>
    GenericExplorer(IScorer<Ctx>^ defaultScorer)
    {
        if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
        {
            throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
        }
        this->defaultScorer = defaultScorer;
        m_explorer = new NativeMultiWorldTesting::GenericExplorer<NativeContext>(*GetNativeScorer());
    }

    /// <summary>
    /// Replaces the default scorer consulted by the scorer callback.
    /// </summary>
    /// <param name="newScorer">The scorer to use from now on.</param>
    virtual void UpdateScorer(IScorer<Ctx>^ newScorer)
    {
        this->defaultScorer = newScorer;
    }

    /// <summary>
    /// Forwards the exploration on/off switch to the native explorer.
    /// </summary>
    /// <param name="explore">Value passed to the native Enable_Explore.</param>
    virtual void EnableExplore(bool explore)
    {
        m_explorer->Enable_Explore(explore);
    }

    /// <summary>
    /// Destructor (Dispose): releases the native explorer deterministically.
    /// </summary>
    ~GenericExplorer()
    {
        this->!GenericExplorer();
    }

protected:
    /// <summary>
    /// Finalizer: frees the native explorer when the object was never disposed.
    /// </summary>
    !GenericExplorer()
    {
        // m_explorer is still null if a constructor threw; deleting null is a no-op.
        // Resetting the pointer makes a repeated dispose harmless.
        delete m_explorer;
        m_explorer = nullptr;
    }

internal:
    // Returns the values produced by the current default scorer for the given context.
    virtual List<float>^ InvokeScorerCallback(Ctx context) override
    {
        return defaultScorer->ScoreActions(context);
    }

    // Exposes the native explorer so MwtExplorer::ChooseAction can pass it to native code.
    NativeMultiWorldTesting::GenericExplorer<NativeContext>* Get()
    {
        return m_explorer;
    }

private:
    IScorer<Ctx>^ defaultScorer;
    NativeMultiWorldTesting::GenericExplorer<NativeContext>* m_explorer;
};
/// <summary>
/// The bootstrap exploration class.
/// </summary>
/// <remarks>
/// The Bootstrap explorer randomizes over the actions chosen by a set of
/// default policies. This performs well statistically but can be
/// computationally expensive.
/// </remarks>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public ref class BootstrapExplorer : public IExplorer<Ctx>, public IConsumePolicies<Ctx>, public PolicyCallback<Ctx>
{
public:
    /// <summary>
    /// Initializes a bootstrap explorer with a fixed number of actions.
    /// Instances are meant to be used with the MwtExplorer.
    /// </summary>
    /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
    /// <param name="numActions">The number of actions to randomize over.</param>
    /// <exception cref="ArgumentNullException">defaultPolicies is null.</exception>
    BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies, UInt32 numActions)
    {
        if (defaultPolicies == nullptr)
        {
            // The (paramName, message) overload is required: the single-string
            // constructor would treat the message as the parameter name.
            throw gcnew ArgumentNullException("defaultPolicies", "The specified array of default policy functions cannot be null.");
        }
        this->defaultPolicies = defaultPolicies;
        m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length), (u32)numActions);
    }

    /// <summary>
    /// Initializes a bootstrap explorer with variable number of actions.
    /// </summary>
    /// <param name="defaultPolicies">A set of default policies to be uniform random over.</param>
    /// <exception cref="ArgumentException">Ctx does not implement IVariableActionContext.</exception>
    /// <exception cref="ArgumentNullException">defaultPolicies is null.</exception>
    BootstrapExplorer(cli::array<IPolicy<Ctx>^>^ defaultPolicies)
    {
        if (!(IVariableActionContext::typeid->IsAssignableFrom(Ctx::typeid)))
        {
            throw gcnew ArgumentException("The specified context type does not implement variable-action interface.", "Ctx");
        }
        if (defaultPolicies == nullptr)
        {
            throw gcnew ArgumentNullException("defaultPolicies", "The specified array of default policy functions cannot be null.");
        }
        this->defaultPolicies = defaultPolicies;
        m_explorer = new NativeMultiWorldTesting::BootstrapExplorer<NativeContext>(*GetNativePolicies((u32)defaultPolicies->Length));
    }

    /// <summary>
    /// Replaces the set of default policies consulted by the policy callback.
    /// </summary>
    /// <param name="newPolicies">The policies to use from now on.</param>
    virtual void UpdatePolicy(cli::array<IPolicy<Ctx>^>^ newPolicies)
    {
        // NOTE(review): the native policy count was fixed at construction time; a
        // replacement array of a different length is not re-registered here. Confirm
        // that callers always pass an array of the original length.
        this->defaultPolicies = newPolicies;
    }

    /// <summary>
    /// Forwards the exploration on/off switch to the native explorer.
    /// </summary>
    /// <param name="explore">Value passed to the native Enable_Explore.</param>
    virtual void EnableExplore(bool explore)
    {
        m_explorer->Enable_Explore(explore);
    }

    /// <summary>
    /// Destructor (Dispose): releases the native explorer deterministically.
    /// </summary>
    ~BootstrapExplorer()
    {
        this->!BootstrapExplorer();
    }

protected:
    /// <summary>
    /// Finalizer: frees the native explorer when the object was never disposed.
    /// </summary>
    !BootstrapExplorer()
    {
        // m_explorer is still null if a constructor threw; deleting null is a no-op.
        // Resetting the pointer makes a repeated dispose harmless.
        delete m_explorer;
        m_explorer = nullptr;
    }

internal:
    // Returns the action chosen by the policy at the given position of the current
    // policy array; an out-of-range position is reported as InvalidDataException.
    virtual UInt32 InvokePolicyCallback(Ctx context, int index) override
    {
        if (index < 0 || index >= defaultPolicies->Length)
        {
            throw gcnew InvalidDataException("Internal error: Index of interop bag is out of range.");
        }
        return defaultPolicies[index]->ChooseAction(context);
    }

    // Exposes the native explorer so MwtExplorer::ChooseAction can pass it to native code.
    NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* Get()
    {
        return m_explorer;
    }

private:
    cli::array<IPolicy<Ctx>^>^ defaultPolicies;
    NativeMultiWorldTesting::BootstrapExplorer<NativeContext>* m_explorer;
};
/// <summary>
/// The top level MwtExplorer class. Using this makes sure that the
/// right bits are recorded and good random actions are chosen.
/// </summary>
/// <remarks>
/// Each call to ChooseAction builds a temporary native MwtExplorer from the application
/// id and dispatches to the native explorer wrapped by the managed explorer argument.
/// </remarks>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public ref class MwtExplorer : public RecorderCallback<Ctx>
{
public:
/// <summary>
/// Constructor.
/// </summary>
/// <param name="appId">This should be unique to each experiment to avoid correlation bugs.</param>
/// <param name="recorder">A user-specified class for recording the appropriate bits for use in evaluation and learning.</param>
MwtExplorer(String^ appId, IRecorder<Ctx>^ recorder)
{
this->appId = appId;
this->recorder = recorder;
}
/// <summary>
/// Choose_Action should be drop-in replacement for any existing policy function.
/// </summary>
/// <param name="explorer">An existing exploration algorithm (one of the above) which uses the default policy as a callback.</param>
/// <param name="unique_key">A unique identifier for the experimental unit. This could be a user id, a session id, etc...</param>
/// <param name="context">The context upon which a decision is made. See SimpleContext above for an example.</param>
/// <returns>An unsigned 32-bit integer representing the 1-based chosen action.</returns>
UInt32 ChooseAction(IExplorer<Ctx>^ explorer, String^ unique_key, Ctx context)
{
// The application id is passed to the native explorer as its salt.
String^ salt = this->appId;
NativeMultiWorldTesting::MwtExplorer<NativeContext> mwt(marshal_as<std::string>(salt), *GetNativeRecorder());
// Normal handles are sufficient here since native code will only hold references and not access the object's data
// https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4
// One handle each for this object, the context and the explorer; all three are
// released in the finally block below.
GCHandle selfHandle = GCHandle::Alloc(this);
IntPtr selfPtr = (IntPtr)selfHandle;
GCHandle contextHandle = GCHandle::Alloc(context);
IntPtr contextPtr = (IntPtr)contextHandle;
GCHandle explorerHandle = GCHandle::Alloc(explorer);
IntPtr explorerPtr = (IntPtr)explorerHandle;
try
{
// The native context carries the three handle pointers plus the callback used to
// query the number of actions.
NativeContext native_context(selfPtr.ToPointer(), explorerPtr.ToPointer(), contextPtr.ToPointer(),
this->GetNumActionsCallback());
u32 action = 0;
// Dispatch on the exact runtime type: only the five explorer classes below are
// recognized, and a type derived from one of them does not match.
// NOTE(review): for any other IExplorer implementation no branch runs and 0 is
// returned, which is outside the documented 1-based action range - confirm whether
// this should be reported as an error instead.
if (explorer->GetType() == EpsilonGreedyExplorer<Ctx>::typeid)
{
EpsilonGreedyExplorer<Ctx>^ epsilonGreedyExplorer = (EpsilonGreedyExplorer<Ctx>^)explorer;
action = mwt.Choose_Action(*epsilonGreedyExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
}
else if (explorer->GetType() == TauFirstExplorer<Ctx>::typeid)
{
TauFirstExplorer<Ctx>^ tauFirstExplorer = (TauFirstExplorer<Ctx>^)explorer;
action = mwt.Choose_Action(*tauFirstExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
}
else if (explorer->GetType() == SoftmaxExplorer<Ctx>::typeid)
{
SoftmaxExplorer<Ctx>^ softmaxExplorer = (SoftmaxExplorer<Ctx>^)explorer;
action = mwt.Choose_Action(*softmaxExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
}
else if (explorer->GetType() == GenericExplorer<Ctx>::typeid)
{
GenericExplorer<Ctx>^ genericExplorer = (GenericExplorer<Ctx>^)explorer;
action = mwt.Choose_Action(*genericExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
}
else if (explorer->GetType() == BootstrapExplorer<Ctx>::typeid)
{
BootstrapExplorer<Ctx>^ bootstrapExplorer = (BootstrapExplorer<Ctx>^)explorer;
action = mwt.Choose_Action(*bootstrapExplorer->Get(), marshal_as<std::string>(unique_key), native_context);
}
return action;
}
finally
{
// Release the handles in reverse order of allocation, whether or not the call threw.
if (explorerHandle.IsAllocated)
{
explorerHandle.Free();
}
if (contextHandle.IsAllocated)
{
contextHandle.Free();
}
if (selfHandle.IsAllocated)
{
selfHandle.Free();
}
}
}
internal:
// Forwards one exploration tuple to the user-supplied recorder.
virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) override
{
recorder->Record(context, action, probability, unique_key);
}
private:
// Recorder supplied at construction; receives every recorded decision.
IRecorder<Ctx>^ recorder;
// Application id supplied at construction; used as the native explorer's salt.
String^ appId;
};
/// <summary>
/// Represents a feature in a sparse array.
/// </summary>
/// <remarks>
/// The layout is sequential, so the declaration order of the fields is significant.
/// SimpleContext copies each element into a native feature as { Value, Id }.
/// </remarks>
[StructLayout(LayoutKind::Sequential)]
public value struct Feature
{
// The feature's value.
float Value;
// The feature's identifier within the sparse array.
UInt32 Id;
};
/// <summary>
/// A sample recorder class that converts the exploration tuple into string format.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx> where Ctx : IStringContext
public ref class StringRecorder : public IRecorder<Ctx>, public ToStringCallback<Ctx>
{
public:
    /// <summary>
    /// Creates the recorder together with its native counterpart.
    /// </summary>
    StringRecorder()
    {
        m_string_recorder = new NativeMultiWorldTesting::StringRecorder<NativeStringContext>();
    }

    /// <summary>
    /// Destructor (Dispose): releases the native recorder deterministically.
    /// </summary>
    ~StringRecorder()
    {
        this->!StringRecorder();
    }

    /// <summary>
    /// Records one exploration tuple through the native string recorder.
    /// </summary>
    /// <param name="context">A user-defined context for the decision.</param>
    /// <param name="action">The action chosen for the context.</param>
    /// <param name="probability">The probability of the chosen action given context.</param>
    /// <param name="uniqueKey">A user-defined identifier for the decision.</param>
    virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey)
    {
        // Normal handles are sufficient here since native code will only hold references and not access the object's data
        // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4
        GCHandle contextHandle = GCHandle::Alloc(context);
        try
        {
            IntPtr contextPtr = (IntPtr)contextHandle;
            NativeStringContext native_context(contextPtr.ToPointer(), GetCallback());
            m_string_recorder->Record(native_context, (u32)action, probability, marshal_as<string>(uniqueKey));
        }
        finally
        {
            // Free the handle even when marshalling or the native call throws;
            // otherwise the context would stay rooted.
            contextHandle.Free();
        }
    }

    /// <summary>
    /// Gets the content of the recording so far as a string and clears internal content.
    /// </summary>
    /// <returns>
    /// A string with recording content.
    /// </returns>
    String^ GetRecording()
    {
        // Workaround for C++-CLI bug which does not allow default value for parameter
        return GetRecording(true);
    }

    /// <summary>
    /// Gets the content of the recording so far as a string and optionally clears internal content.
    /// </summary>
    /// <param name="flush">A boolean value indicating whether to clear the internal content.</param>
    /// <returns>
    /// A string with recording content.
    /// </returns>
    String^ GetRecording(bool flush)
    {
        return gcnew String(m_string_recorder->Get_Recording(flush).c_str());
    }

protected:
    /// <summary>
    /// Finalizer: frees the native recorder when the object was never disposed.
    /// </summary>
    !StringRecorder()
    {
        // Resetting the pointer makes a repeated dispose harmless.
        delete m_string_recorder;
        m_string_recorder = nullptr;
    }

private:
    NativeMultiWorldTesting::StringRecorder<NativeStringContext>* m_string_recorder;
};
/// <summary>
/// A sample context class that stores a vector of Features.
/// </summary>
public ref class SimpleContext : public IStringContext
{
public:
    /// <summary>
    /// Creates a context from the given features; their values are copied into a
    /// native feature vector that backs the native SimpleContext.
    /// </summary>
    /// <param name="features">The features describing this context.</param>
    SimpleContext(cli::array<Feature>^ features)
    {
        Features = features;
        // TODO: add another constructor overload for native SimpleContext to avoid copying feature values
        m_features = new vector<NativeMultiWorldTesting::Feature>();
        for (int i = 0; i < features->Length; i++)
        {
            m_features->push_back({ features[i].Value, features[i].Id });
        }
        m_native_context = new NativeMultiWorldTesting::SimpleContext(*m_features);
    }

    /// <summary>
    /// Returns the string form produced by the native context.
    /// </summary>
    String^ ToString() override
    {
        return gcnew String(m_native_context->To_String().c_str());
    }

    /// <summary>
    /// Destructor (Dispose): releases the native context and feature vector deterministically.
    /// </summary>
    ~SimpleContext()
    {
        this->!SimpleContext();
    }

protected:
    /// <summary>
    /// Finalizer: frees the native objects when the context was never disposed.
    /// </summary>
    !SimpleContext()
    {
        // Destroy the native context first: it was constructed from *m_features and
        // may still refer to that vector.
        delete m_native_context;
        m_native_context = nullptr;
        // The feature vector is allocated in the constructor and owned by this object.
        delete m_features;
        m_features = nullptr;
    }

public:
    /// <summary>
    /// Returns the managed feature array supplied at construction.
    /// </summary>
    cli::array<Feature>^ GetFeatures() { return Features; }

internal:
    cli::array<Feature>^ Features;

private:
    vector<NativeMultiWorldTesting::Feature>* m_features;
    NativeMultiWorldTesting::SimpleContext* m_native_context;
};
}
/*! @} End of Doxygen Groups*/

Просмотреть файл

@ -1,128 +0,0 @@
#pragma once
using namespace System;
using namespace System::Collections::Generic;
/** \defgroup MultiWorldTestingCsharp
\brief C# implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
*/
/*!
* \addtogroup MultiWorldTestingCsharp
* @{
*/
//! Interface for C# version of Multiworld Testing library.
//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/cs_test/ExploreOnlySample.cs
namespace MultiWorldTesting {
/// <summary>
/// Represents a recorder that exposes a method to record exploration data based on generic contexts.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
/// <remarks>
/// Exploration data is specified as a set of tuples (context, action, probability, key) as described below. An
/// application passes an IRecorder object to the @MwtExplorer constructor. See
/// @StringRecorder for a sample IRecorder object.
/// </remarks>
generic <class Ctx>
public interface class IRecorder
{
public:
/// <summary>
/// Records the exploration data associated with a given decision.
/// This implementation should be thread-safe if multithreading is needed.
/// </summary>
/// <param name="context">A user-defined context for the decision.</param>
/// <param name="action">The action chosen by an exploration algorithm given the context.</param>
/// <param name="probability">The probability of the chosen action given the context.</param>
/// <param name="uniqueKey">A user-defined identifier for the decision.</param>
virtual void Record(Ctx context, UInt32 action, float probability, String^ uniqueKey) = 0;
};
/// <summary>
/// Exposes a method for choosing an action given a generic context. IPolicy objects are
/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public interface class IPolicy
{
public:
/// <summary>
/// Determines the action to take for a given context.
/// Implementations should be thread-safe if multithreading is needed.
/// </summary>
/// <param name="context">A user-defined context for the decision.</param>
/// <returns>Index of the action to take (1-based).</returns>
virtual UInt32 ChooseAction(Ctx context) = 0;
};
/// <summary>
/// Exposes a method for specifying a score (weight) for each action given a generic context.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public interface class IScorer
{
public:
/// <summary>
/// Determines the score of each action for a given context.
/// Implementations should be thread-safe if multithreading is needed.
/// </summary>
/// <param name="context">A user-defined context for the decision.</param>
/// <returns>
/// List of scores indexed by action (actions are 1-based; presumably the first
/// list element holds the score of action 1 - TODO confirm against callers).
/// </returns>
virtual List<float>^ ScoreActions(Ctx context) = 0;
};
/// <summary>
/// Represents a context interface with variable number of actions which is
/// enforced if exploration algorithm is initialized in variable number of actions mode.
/// </summary>
public interface class IVariableActionContext
{
public:
/// <summary>
/// Gets the number of actions for the current context.
/// Called once per decision by the interop layer when the context is queried
/// for its action count.
/// </summary>
/// <returns>The number of actions available for the current context.</returns>
virtual UInt32 GetNumberOfActions() = 0;
};
/// <summary>
/// Common interface of exploration algorithms operating on a generic context.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public interface class IExplorer
{
public:
/// <summary>
/// Turns exploration on or off for this explorer.
/// </summary>
/// <param name="explore">
/// True to enable exploration, false to disable it. What an explorer does while
/// exploration is disabled is defined by the implementation, not by this interface.
/// </param>
virtual void EnableExplore(bool explore) = 0;
};
/// <summary>
/// Implemented by types that are driven by a single policy and allow that
/// policy to be replaced.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public interface class IConsumePolicy
{
public:
/// <summary>
/// Supplies a new policy to be used in place of the current one.
/// </summary>
/// <param name="newPolicy">The replacement policy.</param>
virtual void UpdatePolicy(IPolicy<Ctx>^ newPolicy) = 0;
};
/// <summary>
/// Implemented by types that are driven by a set of policies and allow that
/// set to be replaced.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public interface class IConsumePolicies
{
public:
/// <summary>
/// Supplies a new set of policies to be used in place of the current ones.
/// </summary>
/// <param name="newPolicies">The replacement policies.</param>
virtual void UpdatePolicy(cli::array<IPolicy<Ctx>^>^ newPolicies) = 0;
};
/// <summary>
/// Implemented by types that are driven by a scorer and allow that scorer
/// to be replaced.
/// </summary>
/// <typeparam name="Ctx">The Context type.</typeparam>
generic <class Ctx>
public interface class IConsumeScorer
{
public:
/// <summary>
/// Supplies a new scorer to be used in place of the current one.
/// </summary>
/// <param name="newScorer">The replacement scorer.</param>
virtual void UpdateScorer(IScorer<Ctx>^ newScorer) = 0;
};
/// <summary>
/// Represents a context that can render itself as a string.
/// </summary>
public interface class IStringContext
{
public:
/// <summary>
/// Produces the string representation of this context.
/// </summary>
/// <returns>The context rendered as a string.</returns>
virtual String^ ToString() = 0;
};
}
/*! @} End of Doxygen Groups*/

Просмотреть файл

@ -1,416 +0,0 @@
#pragma once
#define MANAGED_CODE
#include "explore_interface.h"
#include "MWTExplorer.h"
#include <msclr\marshal_cppstd.h>
using namespace System;
using namespace System::Collections::Generic;
using namespace System::IO;
using namespace System::Runtime::InteropServices;
using namespace System::Xml::Serialization;
using namespace msclr::interop;
namespace MultiWorldTesting {
// Context callback
private delegate UInt32 ClrContextGetNumActionsCallback(IntPtr contextPtr);
typedef u32 Native_Context_Get_Num_Actions_Callback(void* context);
// Policy callback
private delegate UInt32 ClrPolicyCallback(IntPtr explorerPtr, IntPtr contextPtr, int index);
typedef u32 Native_Policy_Callback(void* explorer, void* context, int index);
// Scorer callback
private delegate void ClrScorerCallback(IntPtr explorerPtr, IntPtr contextPtr, IntPtr scores, IntPtr size);
typedef void Native_Scorer_Callback(void* explorer, void* context, float* scores[], u32* size);
// Recorder callback
private delegate void ClrRecorderCallback(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKey);
typedef void Native_Recorder_Callback(void* mwt, void* context, u32 action, float probability, void* unique_key);
// ToString callback
private delegate void ClrToStringCallback(IntPtr contextPtr, IntPtr stringValue);
typedef void Native_To_String_Callback(void* explorer, void* string_value);
// NativeContext travels through interop space and contains instances of Mwt, Explorer, Context
// used for triggering callback for Policy, Scorer, Recorder
class NativeContext : public NativeMultiWorldTesting::IVariableActionContext
{
public:
NativeContext(void* clr_mwt, void* clr_explorer, void* clr_context,
Native_Context_Get_Num_Actions_Callback* callback_num_actions)
{
m_clr_mwt = clr_mwt;
m_clr_explorer = clr_explorer;
m_clr_context = clr_context;
m_callback_num_actions = callback_num_actions;
}
u32 Get_Number_Of_Actions()
{
return m_callback_num_actions(m_clr_context);
}
void* Get_Clr_Mwt()
{
return m_clr_mwt;
}
void* Get_Clr_Context()
{
return m_clr_context;
}
void* Get_Clr_Explorer()
{
return m_clr_explorer;
}
private:
void* m_clr_mwt;
void* m_clr_context;
void* m_clr_explorer;
private:
Native_Context_Get_Num_Actions_Callback* m_callback_num_actions;
};
// Native-side stand-in for a managed context that knows how to print itself.
class NativeStringContext
{
public:
    NativeStringContext(void* clr_context, Native_To_String_Callback* func)
        : m_clr_context(clr_context), m_func(func)
    {
    }

    // Lets the callback write the context's string form into a local string
    // and returns that string by value.
    string To_String()
    {
        string result;
        m_func(m_clr_context, &result);
        return result;
    }

private:
    void* m_clr_context;
    Native_To_String_Callback* const m_func;
};
// NativeRecorder listens to the callback event and reroutes it to the managed Recorder instance
class NativeRecorder : public NativeMultiWorldTesting::IRecorder<NativeContext>
{
public:
    NativeRecorder(Native_Recorder_Callback* native_func) : m_func(native_func)
    {
    }

    // Wraps the key in a managed String, hands it to the callback through a GCHandle,
    // and releases the handle again once the callback has returned (or thrown).
    void Record(NativeContext& context, u32 action, float probability, string unique_key)
    {
        // Normal handles are sufficient here since native code will only hold references and not access the object's data
        // https://www.microsoftpressstore.com/articles/article.aspx?p=2224054&seqNum=4
        GCHandle keyHandle = GCHandle::Alloc(gcnew String(unique_key.c_str()));
        try
        {
            void* keyPtr = GCHandle::ToIntPtr(keyHandle).ToPointer();
            m_func(context.Get_Clr_Mwt(), context.Get_Clr_Context(), action, probability, keyPtr);
        }
        finally
        {
            if (keyHandle.IsAllocated)
            {
                keyHandle.Free();
            }
        }
    }

private:
    Native_Recorder_Callback* const m_func;
};
// NativePolicy listens to the callback event and reroutes it to the managed Policy instance
class NativePolicy : public NativeMultiWorldTesting::IPolicy<NativeContext>
{
public:
    // index is passed through unchanged to the callback on every call; it defaults to -1.
    NativePolicy(Native_Policy_Callback* func, int index = -1)
        : m_func(func), m_index(index)
    {
    }

    // Delegates the decision to the callback, passing the managed explorer and
    // context pointers carried by the NativeContext together with this policy's index.
    u32 Choose_Action(NativeContext& context)
    {
        void* explorer = context.Get_Clr_Explorer();
        void* clr_context = context.Get_Clr_Context();
        return m_func(explorer, clr_context, m_index);
    }

private:
    Native_Policy_Callback* const m_func;
    int m_index;
};
// NativeScorer forwards scoring requests to the managed Scorer instance
class NativeScorer : public NativeMultiWorldTesting::IScorer<NativeContext>
{
public:
    NativeScorer(Native_Scorer_Callback* func) : m_func(func)
    {
    }

    // Calls back into managed code, which hands over a new[]-allocated array and its
    // length; the values are copied into a vector and the array is always released.
    vector<float> Score_Actions(NativeContext& context)
    {
        float* raw_scores = nullptr;
        u32 count = 0;
        try
        {
            m_func(context.Get_Clr_Explorer(), context.Get_Clr_Context(), &raw_scores, &count);
            // A null array comes with a zero count, which yields an empty vector
            return vector<float>(raw_scores, raw_scores + count);
        }
        finally
        {
            delete[] raw_scores;
        }
    }

private:
    Native_Scorer_Callback* const m_func;
};
// Triggers callback to the Context instance
generic <class Ctx>
public ref class ContextCallback
{
internal:
    // Wraps InteropInvokeNumActions in a delegate and caches the native function
    // pointer obtained from it. The delegate is stored in a field because the
    // function pointer is only usable while the delegate object is alive.
    ContextCallback()
    {
        contextNumActionsCallback = gcnew ClrContextGetNumActionsCallback(&ContextCallback<Ctx>::InteropInvokeNumActions);
        void* rawCallback = Marshal::GetFunctionPointerForDelegate(contextNumActionsCallback).ToPointer();
        m_num_actions_callback = static_cast<Native_Context_Get_Num_Actions_Callback*>(rawCallback);
    }

    // Native entry point that ends up in InteropInvokeNumActions.
    Native_Context_Get_Num_Actions_Callback* GetNumActionsCallback()
    {
        return m_num_actions_callback;
    }

    // Resolves the GCHandle behind contextPtr and asks the context for its action count.
    static UInt32 InteropInvokeNumActions(IntPtr contextPtr)
    {
        GCHandle handle = GCHandle::FromIntPtr(contextPtr);
        IVariableActionContext^ variableContext = (IVariableActionContext^)handle.Target;
        return variableContext->GetNumberOfActions();
    }

private:
    initonly ClrContextGetNumActionsCallback^ contextNumActionsCallback;
    Native_Context_Get_Num_Actions_Callback* m_num_actions_callback;
};
// Triggers callback to the Policy instance to choose an action
generic <class Ctx>
public ref class PolicyCallback abstract
{
internal:
    // Implemented by derived types to run the managed policy.
    // index identifies which policy to invoke: GetNativePolicies hands out indices
    // 0..count-1, while the single policy from GetNativePolicy uses NativePolicy's default of -1.
    virtual UInt32 InvokePolicyCallback(Ctx context, int index) = 0;

    // Creates the delegate for InteropInvoke and caches its native function pointer.
    // The delegate is kept in a field because the pointer is only usable while it is alive.
    PolicyCallback()
    {
        policyCallback = gcnew ClrPolicyCallback(&PolicyCallback<Ctx>::InteropInvoke);
        IntPtr policyCallbackPtr = Marshal::GetFunctionPointerForDelegate(policyCallback);
        m_callback = static_cast<Native_Policy_Callback*>(policyCallbackPtr.ToPointer());
        m_native_policy = nullptr;
        m_native_policies = nullptr;
    }

    // Deterministic cleanup (Dispose): shares the finalizer's logic.
    ~PolicyCallback()
    {
        this->!PolicyCallback();
    }

protected:
    // Finalizer: releases the native objects even when the instance is never disposed,
    // which previously leaked them. Only native memory is touched here. Pointers are
    // reset so that running both the destructor and the finalizer is harmless.
    !PolicyCallback()
    {
        delete m_native_policy;
        m_native_policy = nullptr;
        delete m_native_policies;
        m_native_policies = nullptr;
    }

internal:
    // Lazily creates (once) and returns the single native policy shim.
    NativePolicy* GetNativePolicy()
    {
        if (m_native_policy == nullptr)
        {
            m_native_policy = new NativePolicy(m_callback);
        }
        return m_native_policy;
    }

    // Lazily creates (once) and returns `count` native policy shims, the i-th carrying index i.
    // The vector is built on the first call only; later calls return the cached
    // vector regardless of the count passed in.
    vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* GetNativePolicies(int count)
    {
        if (m_native_policies == nullptr)
        {
            m_native_policies = new vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>();
            for (int i = 0; i < count; i++)
            {
                m_native_policies->push_back(unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>(new NativePolicy(m_callback, i)));
            }
        }
        return m_native_policies;
    }

    // Entry point reached from native code: resolves the GCHandles for the callback
    // owner and the context, then dispatches to InvokePolicyCallback.
    static UInt32 InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, int index)
    {
        GCHandle callbackHandle = (GCHandle)callbackPtr;
        PolicyCallback<Ctx>^ callback = (PolicyCallback<Ctx>^)callbackHandle.Target;
        GCHandle contextHandle = (GCHandle)contextPtr;
        Ctx context = (Ctx)contextHandle.Target;
        return callback->InvokePolicyCallback(context, index);
    }

private:
    initonly ClrPolicyCallback^ policyCallback;

private:
    NativePolicy* m_native_policy;
    vector<unique_ptr<NativeMultiWorldTesting::IPolicy<NativeContext>>>* m_native_policies;
    Native_Policy_Callback* m_callback;
};
// Triggers callback to the Recorder instance to record interaction data
generic <class Ctx>
public ref class RecorderCallback abstract : public ContextCallback<Ctx>
{
internal:
    // Implemented by derived types to pass one recorded decision to the managed recorder.
    virtual void InvokeRecorderCallback(Ctx context, UInt32 action, float probability, String^ unique_key) = 0;

    // Creates the delegate for InteropInvoke and the native recorder shim that calls it.
    // The delegate is kept in a field because the function pointer derived from it is
    // only usable while the delegate is alive.
    RecorderCallback()
    {
        recorderCallback = gcnew ClrRecorderCallback(&RecorderCallback<Ctx>::InteropInvoke);
        IntPtr recorderCallbackPtr = Marshal::GetFunctionPointerForDelegate(recorderCallback);
        Native_Recorder_Callback* callback = static_cast<Native_Recorder_Callback*>(recorderCallbackPtr.ToPointer());
        m_native_recorder = new NativeRecorder(callback);
    }

    // Deterministic cleanup (Dispose): shares the finalizer's logic.
    ~RecorderCallback()
    {
        this->!RecorderCallback();
    }

protected:
    // Finalizer: releases the native recorder even when the instance is never disposed,
    // which previously leaked it. The pointer is reset so that running both the
    // destructor and the finalizer is harmless.
    !RecorderCallback()
    {
        delete m_native_recorder;
        m_native_recorder = nullptr;
    }

internal:
    NativeRecorder* GetNativeRecorder()
    {
        return m_native_recorder;
    }

    // Entry point reached from native code: resolves the GCHandles for the callback
    // owner, the context and the key string, then dispatches to InvokeRecorderCallback.
    static void InteropInvoke(IntPtr mwtPtr, IntPtr contextPtr, UInt32 action, float probability, IntPtr uniqueKeyPtr)
    {
        GCHandle mwtHandle = (GCHandle)mwtPtr;
        RecorderCallback<Ctx>^ callback = (RecorderCallback<Ctx>^)mwtHandle.Target;
        GCHandle contextHandle = (GCHandle)contextPtr;
        Ctx context = (Ctx)contextHandle.Target;
        GCHandle uniqueKeyHandle = (GCHandle)uniqueKeyPtr;
        String^ uniqueKey = (String^)uniqueKeyHandle.Target;
        callback->InvokeRecorderCallback(context, action, probability, uniqueKey);
    }

private:
    initonly ClrRecorderCallback^ recorderCallback;

private:
    NativeRecorder* m_native_recorder;
};
// Triggers callback to the Scorer instance to score the actions for a context
generic <class Ctx>
public ref class ScorerCallback abstract
{
internal:
    // Implemented by derived types to obtain the scores from the managed scorer.
    virtual List<float>^ InvokeScorerCallback(Ctx context) = 0;

    // Creates the delegate for InteropInvoke and the native scorer shim that calls it.
    // The delegate is kept in a field because the function pointer derived from it is
    // only usable while the delegate is alive.
    ScorerCallback()
    {
        scorerCallback = gcnew ClrScorerCallback(&ScorerCallback<Ctx>::InteropInvoke);
        IntPtr scorerCallbackPtr = Marshal::GetFunctionPointerForDelegate(scorerCallback);
        Native_Scorer_Callback* callback = static_cast<Native_Scorer_Callback*>(scorerCallbackPtr.ToPointer());
        m_native_scorer = new NativeScorer(callback);
    }

    // Deterministic cleanup (Dispose): shares the finalizer's logic.
    ~ScorerCallback()
    {
        this->!ScorerCallback();
    }

protected:
    // Finalizer: releases the native scorer even when the instance is never disposed,
    // which previously leaked it. The pointer is reset so that running both the
    // destructor and the finalizer is harmless.
    !ScorerCallback()
    {
        delete m_native_scorer;
        m_native_scorer = nullptr;
    }

internal:
    NativeScorer* GetNativeScorer()
    {
        return m_native_scorer;
    }

    // Entry point reached from native code. Runs the managed scorer and copies its
    // result into a new[]-allocated float array written to *scoresPtr, with the
    // element count written to *sizePtr. The caller (NativeScorer) releases the array
    // with delete[]. When the scorer returns null or an empty list, neither output
    // is written, so the caller keeps its own initial values.
    static void InteropInvoke(IntPtr callbackPtr, IntPtr contextPtr, IntPtr scoresPtr, IntPtr sizePtr)
    {
        GCHandle callbackHandle = (GCHandle)callbackPtr;
        ScorerCallback<Ctx>^ callback = (ScorerCallback<Ctx>^)callbackHandle.Target;
        GCHandle contextHandle = (GCHandle)contextPtr;
        Ctx context = (Ctx)contextHandle.Target;
        List<float>^ scoreList = callback->InvokeScorerCallback(context);
        if (scoreList == nullptr || scoreList->Count == 0)
        {
            return;
        }
        u32* num_scores = (u32*)sizePtr.ToPointer();
        *num_scores = (u32)scoreList->Count;
        float* scores = new float[*num_scores];
        for (u32 i = 0; i < *num_scores; i++)
        {
            scores[i] = scoreList[i];
        }
        float** native_scores = (float**)scoresPtr.ToPointer();
        *native_scores = scores;
    }

private:
    initonly ClrScorerCallback^ scorerCallback;

private:
    NativeScorer* m_native_scorer;
};
// Triggers callback to the Context instance to perform ToString() operation
generic <class Ctx> where Ctx : IStringContext
public ref class ToStringCallback
{
internal:
    // Wraps InteropInvoke in a delegate and caches the native function pointer
    // obtained from it. The delegate is stored in a field because the function
    // pointer is only usable while the delegate object is alive.
    ToStringCallback()
    {
        toStringCallback = gcnew ClrToStringCallback(&ToStringCallback<Ctx>::InteropInvoke);
        void* rawCallback = Marshal::GetFunctionPointerForDelegate(toStringCallback).ToPointer();
        m_callback = static_cast<Native_To_String_Callback*>(rawCallback);
    }

    // Native entry point that ends up in InteropInvoke.
    Native_To_String_Callback* GetCallback()
    {
        return m_callback;
    }

    // Resolves the context behind contextPtr, converts its ToString() result to a
    // native string and stores it in the std::string that stringPtr points to.
    static void InteropInvoke(IntPtr contextPtr, IntPtr stringPtr)
    {
        GCHandle handle = GCHandle::FromIntPtr(contextPtr);
        Ctx managedContext = (Ctx)handle.Target;
        string* destination = (string*)stringPtr.ToPointer();
        *destination = marshal_as<string>(managedContext->ToString());
    }

private:
    initonly ClrToStringCallback^ toStringCallback;
    Native_To_String_Callback* m_callback;
};
}

Просмотреть файл

@ -1,73 +0,0 @@
// explore.cpp : Timing code to measure performance of MWT Explorer library
#include "MWTExplorer.h"
#include <chrono>
#include <tuple>
#include <iostream>
using namespace std;
using namespace std::chrono;
using namespace MultiWorldTesting;
// Trivial policy used by the timing run: it ignores the context and always picks action 1.
class MySimplePolicy : public IPolicy<SimpleContext>
{
public:
    u32 Choose_Action(SimpleContext& context)
    {
        const u32 fixed_action = 1;
        return fixed_action;
    }
};
const u32 num_actions = 10;

// Times the MWT Explorer library and prints the averages to stdout.
// Every iteration measures two phases separately:
//   - "Init": constructing the recorder, MwtExplorer, policy and epsilon-greedy explorer;
//   - "Choose Action": building a SimpleContext and calling Choose_Action on it.
// The first num_warmup iterations are executed but excluded from the totals.
void Clock_Explore()
{
    float epsilon = .2f;
    string unique_key = "key";
    int num_features = 1000;
    int num_iter = 10000;
    int num_warmup = 100;
    int num_interactions = 1;

    // pre-create features
    vector<Feature> features;
    features.reserve(num_features);
    for (int i = 0; i < num_features; i++)
    {
        // Explicit float literal and cast: brace-initialization does not allow the
        // implicit double->float / int->u32 narrowing the aggregate would otherwise need.
        Feature f = { 0.5f, static_cast<u32>(i + 1) };
        features.push_back(f);
    }

    long long time_init = 0, time_choose = 0;
    for (int iter = 0; iter < num_iter + num_warmup; iter++)
    {
        // Phase 1: object construction
        high_resolution_clock::time_point t1 = high_resolution_clock::now();
        StringRecorder<SimpleContext> recorder;
        MwtExplorer<SimpleContext> mwt("test", recorder);
        MySimplePolicy default_policy;
        EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
        high_resolution_clock::time_point t2 = high_resolution_clock::now();
        time_init += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();

        // Phase 2: context creation plus the actual decisions
        t1 = high_resolution_clock::now();
        SimpleContext appContext(features);
        for (int i = 0; i < num_interactions; i++)
        {
            mwt.Choose_Action(explorer, unique_key, appContext);
        }
        t2 = high_resolution_clock::now();
        time_choose += iter < num_warmup ? 0 : duration_cast<chrono::microseconds>(t2 - t1).count();
    }

    cout << "# iterations: " << num_iter << ", # interactions: " << num_interactions << ", # context features: " << num_features << endl;
    cout << "--- PER ITERATION ---" << endl;
    cout << "Init: " << (double)time_init / num_iter << " micro" << endl;
    cout << "Choose Action: " << (double)time_choose / (num_iter * num_interactions) << " micro" << endl;
    cout << "--- TOTAL TIME ---: " << (time_init + time_choose) << " micro" << endl;
}
// Entry point of the timing tool: runs the benchmark once; command-line arguments are ignored.
int main(int /*argc*/, char* /*argv*/[])
{
    Clock_Explore();
    return 0;
}

Просмотреть файл

@ -1,76 +1,32 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 2013
VisualStudioVersion = 12.0.30723.0
VisualStudioVersion = 12.0.31101.0
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore", "explore.vcxproj", "{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}"
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreTests", "Test\ExploreTests.csproj", "{CB0C6B20-560C-4822-8EF6-DA999A64B542}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_static", "static\explore_static.vcxproj", "{ACE47E98-488C-4CDF-B9F1-36337B2855AD}"
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreSample", "Sample\ExploreSample.csproj", "{7081D542-AE64-485D-9087-79194B958499}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_clr", "clr\explore_clr.vcxproj", "{8400DA16-1F46-4A31-A126-BBE16F62BFD7}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "explore_tests", "tests\explore_tests.vcxproj", "{5AE3AA40-BEB0-4979-8166-3B885172C430}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreTests", "tests\ExploreTests.csproj", "{CB0C6B20-560C-4822-8EF6-DA999A64B542}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ExploreSample", "ExploreSample\ExploreSample.csproj", "{7081D542-AE64-485D-9087-79194B958499}"
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Explore", "Explore\Explore.csproj", "{6D245816-6016-49B6-9E37-A0BF0D2A736A}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Win32 = Debug|Win32
Debug|x64 = Debug|x64
Release|Win32 = Release|Win32
Release|x64 = Release|x64
Debug|Any CPU = Debug|Any CPU
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|Win32.ActiveCfg = Debug|Win32
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|Win32.Build.0 = Debug|Win32
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|x64.ActiveCfg = Debug|x64
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Debug|x64.Build.0 = Debug|x64
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|Win32.ActiveCfg = Release|Win32
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|Win32.Build.0 = Release|Win32
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|x64.ActiveCfg = Release|x64
{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}.Release|x64.Build.0 = Release|x64
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|Win32.ActiveCfg = Debug|Win32
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|Win32.Build.0 = Debug|Win32
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|x64.ActiveCfg = Debug|x64
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Debug|x64.Build.0 = Debug|x64
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|Win32.ActiveCfg = Release|Win32
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|Win32.Build.0 = Release|Win32
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|x64.ActiveCfg = Release|x64
{ACE47E98-488C-4CDF-B9F1-36337B2855AD}.Release|x64.Build.0 = Release|x64
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|Win32.ActiveCfg = Debug|Win32
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|Win32.Build.0 = Debug|Win32
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|x64.ActiveCfg = Debug|x64
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Debug|x64.Build.0 = Debug|x64
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|Win32.ActiveCfg = Release|Win32
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|Win32.Build.0 = Release|Win32
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|x64.ActiveCfg = Release|x64
{8400DA16-1F46-4A31-A126-BBE16F62BFD7}.Release|x64.Build.0 = Release|x64
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|Win32.ActiveCfg = Debug|Win32
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|Win32.Build.0 = Debug|Win32
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|x64.ActiveCfg = Debug|x64
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Debug|x64.Build.0 = Debug|x64
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|Win32.ActiveCfg = Release|Win32
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|Win32.Build.0 = Release|Win32
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|x64.ActiveCfg = Release|x64
{5AE3AA40-BEB0-4979-8166-3B885172C430}.Release|x64.Build.0 = Release|x64
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Win32.ActiveCfg = Debug|x86
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Win32.Build.0 = Debug|x86
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|x64.ActiveCfg = Debug|x64
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|x64.Build.0 = Debug|x64
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Win32.ActiveCfg = Release|x86
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Win32.Build.0 = Release|x86
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|x64.ActiveCfg = Release|x64
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|x64.Build.0 = Release|x64
{7081D542-AE64-485D-9087-79194B958499}.Debug|Win32.ActiveCfg = Debug|x86
{7081D542-AE64-485D-9087-79194B958499}.Debug|Win32.Build.0 = Debug|x86
{7081D542-AE64-485D-9087-79194B958499}.Debug|x64.ActiveCfg = Debug|x64
{7081D542-AE64-485D-9087-79194B958499}.Debug|x64.Build.0 = Debug|x64
{7081D542-AE64-485D-9087-79194B958499}.Release|Win32.ActiveCfg = Release|x86
{7081D542-AE64-485D-9087-79194B958499}.Release|Win32.Build.0 = Release|x86
{7081D542-AE64-485D-9087-79194B958499}.Release|x64.ActiveCfg = Release|x64
{7081D542-AE64-485D-9087-79194B958499}.Release|x64.Build.0 = Release|x64
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Debug|Any CPU.Build.0 = Debug|Any CPU
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Any CPU.ActiveCfg = Release|Any CPU
{CB0C6B20-560C-4822-8EF6-DA999A64B542}.Release|Any CPU.Build.0 = Release|Any CPU
{7081D542-AE64-485D-9087-79194B958499}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7081D542-AE64-485D-9087-79194B958499}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7081D542-AE64-485D-9087-79194B958499}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7081D542-AE64-485D-9087-79194B958499}.Release|Any CPU.Build.0 = Release|Any CPU
{6D245816-6016-49B6-9E37-A0BF0D2A736A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6D245816-6016-49B6-9E37-A0BF0D2A736A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6D245816-6016-49B6-9E37-A0BF0D2A736A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6D245816-6016-49B6-9E37-A0BF0D2A736A}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE

Просмотреть файл

@ -1,174 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{FFC1BEDC-8D26-4456-93D8-F0ED091E6CE4}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>vw_explore</RootNamespace>
<ProjectName>explore</ProjectName>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Platform)'=='x64'">
<BoostIncludeDir>c:\boost\x64\include\boost-1_56</BoostIncludeDir>
<BoostLibDir>c:\boost\x64\lib</BoostLibDir>
<ZlibIncludeDir>..\..\zlib-1.2.8</ZlibIncludeDir>
<ZlibLibDir>$(ZlibIncludeDir)\contrib\vstudio\vc10\x64\ZlibStat$(Configuration)</ZlibLibDir>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<PrecompiledHeader>
</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AdditionalIncludeDirectories>static;</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>
</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AdditionalIncludeDirectories>static;</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>
</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>static;</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>
</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
<AdditionalIncludeDirectories>static;</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<Text Include="ReadMe.txt" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="explore.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="explore_sample.cpp" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="static\explore_static.vcxproj">
<Project>{ace47e98-488c-4cdf-b9f1-36337b2855ad}</Project>
</ProjectReference>
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

Просмотреть файл

@ -1,28 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hpp;hxx;hm;inl;inc;xsd</Extensions>
</Filter>
<Filter Include="Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<Text Include="ReadMe.txt" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="explore.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="explore_sample.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
</Project>

Просмотреть файл

@ -1,212 +0,0 @@
// vw_explore.cpp : Defines the entry point for the console application.
//
#include "MWTExplorer.h"
#include <chrono>
#include <tuple>
#include <iostream>
using namespace std;
using namespace std::chrono;
using namespace MultiWorldTesting;
/// Example of a custom context. It carries no data; it only serves as the
/// context type argument for the sample policy, scorer and recorder in this file.
class MyContext
{
};
/// Example of a custom policy which implements the IPolicy<MyContext>,
/// declaring that this policy only interacts with MyContext objects.
class MyPolicy : public IPolicy<MyContext>
{
public:
    /// Picks action 1 no matter what the context is.
    u32 Choose_Action(MyContext& context)
    {
        const u32 fixed_action = 1;
        return fixed_action;
    }
};
/// Example of a custom policy which implements the IPolicy<SimpleContext>,
/// declaring that this policy only interacts with SimpleContext objects.
class MySimplePolicy : public IPolicy<SimpleContext>
{
public:
    /// Picks action 1 no matter what the context is.
    u32 Choose_Action(SimpleContext& context)
    {
        const u32 fixed_action = 1;
        return fixed_action;
    }
};
/// Sample scorer for MyContext objects. It implements IScorer<MyContext> and
/// assigns the same score to every action, which yields a uniform
/// distribution when the scores are normalized by an explorer.
class MyScorer : public IScorer<MyContext>
{
public:
    /// @param num_actions Number of entries produced by Score_Actions.
    MyScorer(u32 num_actions) : m_num_actions(num_actions)
    {
    }

    /// @param context Ignored by this sample scorer.
    /// @returns A vector holding m_num_actions copies of the score 0.1.
    vector<float> Score_Actions(MyContext& context)
    {
        // Fill-construct the result instead of appending one score at a time.
        return vector<float>(m_num_actions, .1f);
    }

private:
    u32 m_num_actions;
};
///
/// Represents a tuple <context, action, probability, key>.
/// This is a plain aggregate. The member order is significant: MyRecorder
/// fills it with a brace-initializer in the order
/// { context, action, probability, unique_key }.
///
template <class Ctx>
struct MyInteraction
{
// Copy of the context the decision was made for.
Ctx Context;
// The action that was chosen.
u32 Action;
// The probability reported for the chosen action.
float Probability;
// Caller-supplied key identifying the decision.
string Unique_Key;
};
/// Sample recorder for MyContext objects. It implements IRecorder<MyContext>
/// by keeping every recorded interaction in an in-memory vector.
class MyRecorder : public IRecorder<MyContext>
{
public:
    /// Appends one <context, action, probability, key> tuple to the internal
    /// list so that it could be used later for other purposes.
    virtual void Record(MyContext& context, u32 action, float probability, string unique_key)
    {
        MyInteraction<MyContext> entry = { context, action, probability, unique_key };
        m_interactions.push_back(entry);
    }

private:
    vector<MyInteraction<MyContext>> m_interactions;
};
int main(int argc, char* argv[])
{
if (argc < 2)
{
cerr << "arguments: {greedy,tau-first,bootstrap,softmax,generic}" << endl;
exit(1);
}
// Arguments for individual explorers
if (strcmp(argv[1], "greedy") == 0)
{
// Initialize Epsilon-Greedy explore algorithm using MyPolicy
// Creates a recorder of built-in StringRecorder type for string serialization
StringRecorder<SimpleContext> recorder;
// Creates a policy that interacts with SimpleContext type
MySimplePolicy default_policy;
// Creates an MwtExplorer instance using the recorder above
MwtExplorer<SimpleContext> mwt("appid", recorder);
u32 num_actions = 10;
float epsilon = .2f;
// Creates an Epsilon-Greedy explorer using the specified settings
EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);
// Creates a context of built-in SimpleContext type
vector<Feature> features;
features.push_back({ 0.5f, 1 });
features.push_back({ 1.3f, 11 });
features.push_back({ -.95f, 413 });
SimpleContext context(features);
// Performs exploration by passing an instance of the Epsilon-Greedy exploration algorithm into MwtExplorer
// using a sample string to uniquely identify this event
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, context);
cout << "Chosen action = " << action << endl;
cout << "Exploration record = " << recorder.Get_Recording();
}
else if (strcmp(argv[1], "tau-first") == 0)
{
// Initialize Tau-First explore algorithm using MyPolicy
MyRecorder recorder;
MwtExplorer<MyContext> mwt("appid", recorder);
int num_actions = 10;
u32 tau = 5;
MyPolicy default_policy;
TauFirstExplorer<MyContext> explorer(default_policy, tau, num_actions);
MyContext ctx;
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
cout << "action = " << action << endl;
}
else if (strcmp(argv[1], "bootstrap") == 0)
{
// Initialize Bootstrap explore algorithm using MyPolicy
MyRecorder recorder;
MwtExplorer<MyContext> mwt("appid", recorder);
u32 num_bags = 2;
// Create a vector of smart pointers to default policies using the built-in type PolicyPtr<Context>
vector<unique_ptr<IPolicy<MyContext>>> policy_functions;
for (size_t i = 0; i < num_bags; i++)
{
policy_functions.push_back(unique_ptr<IPolicy<MyContext>>(new MyPolicy()));
}
int num_actions = 10;
BootstrapExplorer<MyContext> explorer(policy_functions, num_actions);
MyContext ctx;
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
cout << "action = " << action << endl;
}
else if (strcmp(argv[1], "softmax") == 0)
{
// Initialize Softmax explore algorithm using MyScorer
MyRecorder recorder;
MwtExplorer<MyContext> mwt("salt", recorder);
u32 num_actions = 10;
MyScorer scorer(num_actions);
float lambda = 0.5f;
SoftmaxExplorer<MyContext> explorer(scorer, lambda, num_actions);
MyContext ctx;
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
cout << "action = " << action << endl;
}
else if (strcmp(argv[1], "generic") == 0)
{
// Initialize Generic explore algorithm using MyScorer
MyRecorder recorder;
MwtExplorer<MyContext> mwt("appid", recorder);
int num_actions = 10;
MyScorer scorer(num_actions);
GenericExplorer<MyContext> explorer(scorer, num_actions);
MyContext ctx;
string unique_key = "eventid";
u32 action = mwt.Choose_Action(explorer, unique_key, ctx);
cout << "action = " << action << endl;
}
else
{
cerr << "unknown exploration type: " << argv[1] << endl;
exit(1);
}
return 0;
}

Двоичные данные
mwt.chm

Двоичный файл не отображается.

Просмотреть файл

@ -1,884 +0,0 @@
//
// Main interface for clients of the Multiworld testing (MWT) service.
//
#pragma once
#include <stdexcept>
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <vector>
#include <utility>
#include <memory>
#include <limits.h>
#include <tuple>
#ifdef MANAGED_CODE
#define PORTING_INTERFACE public
#define MWT_NAMESPACE namespace NativeMultiWorldTesting
#else
#define PORTING_INTERFACE private
#define MWT_NAMESPACE namespace MultiWorldTesting
#endif
using namespace std;
#include "utility.h"
/** \defgroup MultiWorldTestingCpp
\brief C++ implementation, for sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
*/
/*!
* \addtogroup MultiWorldTestingCpp
* @{
*/
//! Interface for C++ version of Multiworld Testing library.
//! For sample usage see: https://github.com/sidsen/vowpal_wabbit/blob/v0/explore/explore_sample.cpp
MWT_NAMESPACE {
// Forward declarations
template <class Ctx>
class IRecorder;
template <class Ctx>
class IExplorer;
///
/// The top-level MwtExplorer class. Using this enables principled and efficient exploration
/// over a set of possible actions, and ensures that the right bits are recorded.
///
template <class Ctx>
class MwtExplorer
{
public:
///
/// Constructor
///
/// @param appid This should be unique to your experiment or you risk nasty correlation bugs.
/// @param recorder A user-specified class for recording the appropriate bits for use in evaluation and learning.
///
MwtExplorer(std::string app_id, IRecorder<Ctx>& recorder) : m_recorder(recorder)
{
m_app_id = HashUtils::Compute_Id_Hash(app_id);
}
///
/// Chooses an action by invoking an underlying exploration algorithm. This should be a
/// drop-in replacement for any existing policy function.
///
/// @param explorer An existing exploration algorithm (one of the below) which uses the default policy as a callback.
/// @param unique_key A unique identifier for the experimental unit. This could be a user id, a session id, etc..
/// @param context The context upon which a decision is made. See SimpleContext below for an example.
///
u32 Choose_Action(IExplorer<Ctx>& explorer, string unique_key, Ctx& context)
{
u64 seed = HashUtils::Compute_Id_Hash(unique_key);
std::tuple<u32, float, bool> action_probability_log_tuple = explorer.Choose_Action(seed + m_app_id, context);
u32 action = std::get<0>(action_probability_log_tuple);
float prob = std::get<1>(action_probability_log_tuple);
if (std::get<2>(action_probability_log_tuple))
{
m_recorder.Record(context, action, prob, unique_key);
}
return action;
}
private:
u64 m_app_id;
IRecorder<Ctx>& m_recorder;
};
///
/// Exposes a method to record exploration data based on generic contexts. Exploration data
/// is specified as a set of tuples <context, action, probability, key> as described below. An
/// application passes an IRecorder object to the @MwtExplorer constructor, and MwtExplorer
/// calls Record for every decision the explorer flags for logging. See
/// @StringRecorder for a sample IRecorder object.
///
template <class Ctx>
class IRecorder
{
public:
///
/// Records the exploration data associated with a given decision.
/// This implementation should be thread-safe if multithreading is needed.
///
/// @param context A user-defined context for the decision
/// @param action The action chosen by an exploration algorithm given context
/// @param probability The probability the exploration algorithm chose said action
/// @param unique_key A user-defined unique identifier for the decision
///
virtual void Record(Ctx& context, u32 action, float probability, string unique_key) = 0;
// Virtual destructor so that implementations can be destroyed through an IRecorder pointer.
virtual ~IRecorder() { }
};
///
/// Exposes a method to choose an action given a generic context, and obtain the relevant
/// exploration bits. The exploration algorithms below implement this interface on top of a
/// user-supplied IPolicy or IScorer. Do not implement this interface yourself: instead, use
/// the various exploration algorithms below, which implement it for you.
///
template <class Ctx>
class IExplorer
{
public:
///
/// Determines the action to take and the probability with which it was chosen, for a
/// given context.
///
/// @param salted_seed A PRG seed based on a unique id information provided by the user
/// @param context A user-defined context for the decision
/// @returns The action to take, the probability it was chosen, and a flag indicating
/// whether to record this decision
///
virtual std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context) = 0;
///
/// Turns exploration on or off. The explorers in this file store the flag and consult
/// it on later Choose_Action calls.
///
/// @param explore true to enable exploration, false to disable it
///
virtual void Enable_Explore(bool explore) = 0;
// Virtual destructor so that implementations can be destroyed through an IExplorer pointer.
virtual ~IExplorer() { }
};
///
/// Exposes a method to choose an action given a generic context. IPolicy objects are
/// passed to (and invoked by) exploration algorithms to specify the default policy behavior.
///
template <class Ctx>
class IPolicy
{
public:
///
/// Determines the action to take for a given context.
/// This implementation should be thread-safe if multithreading is needed.
///
/// @param context A user-defined context for the decision
/// @returns The action to take (1-based index). The explorers in this file reject
/// 0 and any value above the number of actions with std::invalid_argument.
///
virtual u32 Choose_Action(Ctx& context) = 0;
// Virtual destructor so that implementations can be destroyed through an IPolicy pointer.
virtual ~IPolicy() { }
};
///
/// Exposes a method for specifying a score (weight) for each action given a generic context.
///
template <class Ctx>
class IScorer
{
public:
///
/// Determines the score of each action for a given context.
/// This implementation should be thread-safe if multithreading is needed.
///
/// @param context A user-defined context for the decision
/// @returns A vector of scores, one per action. Element i of the vector belongs to
/// action i + 1 (action ids are 1-based). The explorers in this file require the
/// vector size to equal the number of actions.
///
virtual vector<float> Score_Actions(Ctx& context) = 0;
// Virtual destructor so that implementations can be destroyed through an IScorer pointer.
virtual ~IScorer() { }
};
///
/// Represents a context interface with variable number of actions which is
/// enforced if exploration algorithm is initialized in variable number of actions mode.
/// In that mode the explorers query the context for its action count on every decision.
///
class IVariableActionContext
{
public:
///
/// Gets the number of actions for the current context.
///
/// @returns The number of actions available for the current context. A value below 1
/// makes Get_Variable_Number_Of_Actions throw std::invalid_argument.
///
virtual u32 Get_Number_Of_Actions() = 0;
// Virtual destructor so that implementations can be destroyed through this interface.
virtual ~IVariableActionContext() { }
};
///
/// Implemented by exploration algorithms that are driven by a single default policy
/// (EpsilonGreedyExplorer and TauFirstExplorer in this file) so that the policy can be
/// replaced after construction.
///
template <class Ctx>
class IConsumePolicy
{
public:
///
/// Replaces the default policy used by the explorer.
///
/// @param new_policy The policy to use from now on
///
virtual void Update_Policy(IPolicy<Ctx>& new_policy) = 0;
// Virtual destructor so that implementations can be destroyed through this interface.
virtual ~IConsumePolicy() { }
};
///
/// Implemented by exploration algorithms that are driven by a set of default policies
/// (BootstrapExplorer in this file) so that the set can be replaced after construction.
///
template <class Ctx>
class IConsumePolicies
{
public:
///
/// Replaces the set of default policies used by the explorer.
///
/// @param new_policy_functions The policies to use from now on. The vector is taken by
/// non-const reference; BootstrapExplorer moves its contents out.
///
virtual void Update_Policy(vector<unique_ptr<IPolicy<Ctx>>>& new_policy_functions) = 0;
// Virtual destructor so that implementations can be destroyed through this interface.
virtual ~IConsumePolicies() { }
};
///
/// Implemented by exploration algorithms that are driven by a scorer
/// (SoftmaxExplorer and GenericExplorer in this file) so that the scorer can be
/// replaced after construction.
///
template <class Ctx>
class IConsumeScorer
{
public:
///
/// Replaces the scorer used by the explorer.
///
/// @param new_policy The scorer to use from now on (the parameter is a scorer despite its name)
///
virtual void Update_Scorer(IScorer<Ctx>& new_policy) = 0;
// Virtual destructor so that implementations can be destroyed through this interface.
virtual ~IConsumeScorer() { }
};
///
/// Sample recorder that serializes every exploration tuple into text. Each call to Record
/// appends one line of the form "<action> <unique_key> <probability> | <context>\n" to an
/// internal string, which Get_Recording hands back to the caller.
///
template <class Ctx>
struct StringRecorder : public IRecorder<Ctx>
{
    ///
    /// Appends one line describing the decision to the recording.
    /// The context type is implicitly required to provide a To_String() method.
    ///
    /// @param context     The context of the decision; serialized through To_String()
    /// @param action      The chosen action
    /// @param probability The probability of the chosen action; written with five fractional digits
    /// @param unique_key  The key identifying the decision
    ///
    void Record(Ctx& context, u32 action, float probability, string unique_key)
    {
        m_recording += to_string((unsigned long long)action);
        m_recording += ' ';
        m_recording += unique_key;
        m_recording += ' ';

        // Fixed-point rendering: integer part, a dot, then the first five fractional digits.
        char prob_text[10] = { 0 };
        int whole_part = (int)probability;
        int fraction_part = (int)(fabs(probability - whole_part) * 100000);
        sprintf_s(prob_text, sizeof(prob_text), "%d.%05d", whole_part, fraction_part);
        m_recording += prob_text;

        m_recording += " | ";
        m_recording += context.To_String();
        m_recording += "\n";
    }

    ///
    /// Gets the content of the recording so far as a string.
    ///
    /// @param flush When true (the default) the internal content is cleared after it is copied
    /// @returns Everything recorded since the last flush
    ///
    string Get_Recording(bool flush = true)
    {
        string snapshot = m_recording;
        if (flush)
        {
            m_recording.clear();
        }
        return snapshot;
    }

private:
    string m_recording;
};
///
/// Represents a feature in a sparse array.
///
struct Feature
{
float Value;
u32 Id;
bool operator==(Feature other_feature)
{
return Id == other_feature.Id;
}
};
///
/// A sample context class that exposes a vector of Features.
/// The context does not copy the features: it keeps a reference to the vector passed
/// to the constructor, so that vector has to outlive the context.
///
class SimpleContext
{
public:
    ///
    /// @param features The feature vector to expose; only a reference is stored
    ///
    SimpleContext(vector<Feature>& features) :
        m_features(features)
    { }

    ///
    /// @returns The referenced feature vector
    ///
    vector<Feature>& Get_Features()
    {
        return m_features;
    }

    ///
    /// Serializes the features as space-separated "<id>:<value>" entries.
    ///
    /// @returns The serialized features; an empty string when there are no features
    ///
    string To_String()
    {
        string out_string;
        // Scratch buffer for one "<id>:<value>" entry. Not named strlen to avoid
        // hiding the C library function of that name.
        const size_t buffer_size = 35;
        char feature_str[buffer_size] = { 0 };
        for (size_t i = 0; i < m_features.size(); i++)
        {
            // The id is unsigned, so it is printed with %u; %d would render ids
            // above INT_MAX as negative numbers.
            const unsigned int feature_id = (unsigned int)m_features[i].Id;
            int chars;
            if (i == 0)
            {
                chars = sprintf_s(feature_str, buffer_size, "%u:", feature_id);
            }
            else
            {
                // Entries after the first are separated by a single space.
                chars = sprintf_s(feature_str, buffer_size, " %u:", feature_id);
            }
            // The value is written right after the "<id>:" prefix in the same buffer.
            NumberUtils::print_float(feature_str + chars, buffer_size - chars, m_features[i].Value);
            out_string.append(feature_str);
        }
        return out_string;
    }

private:
    vector<Feature>& m_features;
};
///
/// Resolves the number of actions to use for one decision.
/// UINT_MAX is the marker the explorers store when they were constructed in
/// variable-number-of-actions mode; in that case the count is read from the context.
///
/// @param context             The context of the decision; treated as an IVariableActionContext
///                            when default_num_actions is UINT_MAX
/// @param default_num_actions The fixed action count, or UINT_MAX for "ask the context"
/// @returns The number of actions for this decision
/// @throws std::invalid_argument when the context reports fewer than 1 action
///
template <class Ctx>
static u32 Get_Variable_Number_Of_Actions(Ctx& context, u32 default_num_actions)
{
    // Fixed mode: the configured count is used as is.
    if (default_num_actions != UINT_MAX)
    {
        return default_num_actions;
    }

    const u32 context_num_actions = ((IVariableActionContext*)(&context))->Get_Number_Of_Actions();
    if (context_num_actions == 0)
    {
        throw std::invalid_argument("Number of actions must be at least 1.");
    }
    return context_num_actions;
}
///
/// The epsilon greedy exploration algorithm. This is a good choice if you have no idea
/// which actions should be preferred. Epsilon greedy is also computationally cheap.
/// With probability 1 - epsilon the action of the default policy is used; otherwise an
/// action is drawn uniformly at random.
///
template <class Ctx>
class EpsilonGreedyExplorer : public IExplorer<Ctx>, public IConsumePolicy<Ctx>
{
public:
    ///
    /// The constructor is the only public member, because this should be used with the MwtExplorer.
    ///
    /// @param default_policy A default function which outputs an action given a context.
    ///                       Only its address is kept, so it has to outlive this explorer.
    /// @param epsilon        The probability of a random exploration.
    /// @param num_actions    The number of actions to randomize over.
    /// @throws std::invalid_argument when num_actions is below 1 or epsilon is not within [0, 1]
    ///
    EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon, u32 num_actions) :
        m_default_policy(&default_policy), m_epsilon(epsilon), m_explore(true), m_num_actions(num_actions)
    {
        if (m_num_actions < 1)
        {
            throw std::invalid_argument("Number of actions must be at least 1.");
        }
        Validate_Epsilon();
    }

    ///
    /// Initializes an epsilon greedy explorer with variable number of actions.
    ///
    /// @param default_policy A default function which outputs an action given a context.
    ///                       Only its address is kept, so it has to outlive this explorer.
    /// @param epsilon        The probability of a random exploration.
    /// @throws std::invalid_argument when epsilon is not within [0, 1]
    ///
    EpsilonGreedyExplorer(IPolicy<Ctx>& default_policy, float epsilon) :
        m_default_policy(&default_policy), m_epsilon(epsilon), m_explore(true), m_num_actions(UINT_MAX)
    {
        Validate_Epsilon();
        static_assert(std::is_base_of<IVariableActionContext, Ctx>::value, "The provided context does not implement variable-action interface.");
    }

    ///
    /// Replaces the default policy. The policy is held through a pointer because a
    /// reference member cannot be re-bound: assigning to it would assign to the old
    /// policy object and leave the explorer on that old policy.
    ///
    /// @param new_policy The policy to use from now on; it has to outlive this explorer
    ///
    void Update_Policy(IPolicy<Ctx>& new_policy)
    {
        m_default_policy = &new_policy;
    }

    ///
    /// Enables or disables exploration. When disabled, epsilon is treated as 0, so the
    /// action of the default policy is always used.
    ///
    void Enable_Explore(bool explore)
    {
        m_explore = explore;
    }

private:
    /// Rejects epsilon values outside [0, 1]. Written as a negated range test so that
    /// NaN is rejected as well.
    void Validate_Epsilon() const
    {
        if (!(m_epsilon >= 0.f && m_epsilon <= 1.f))
        {
            throw std::invalid_argument("Epsilon must be between 0 and 1.");
        }
    }

    ///
    /// Picks an action for the given context.
    ///
    /// @param salted_seed Seed of the random generator used for this decision
    /// @param context     The context handed to the default policy
    /// @returns The chosen action (1-based), its probability, and true (always record)
    /// @throws std::invalid_argument when the default policy returns 0 or an action
    ///         above the number of actions
    ///
    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
    {
        // Unqualified call: the helper lives in the same namespace as this class.
        u32 num_actions = Get_Variable_Number_Of_Actions(context, m_num_actions);
        PRG::prg random_generator(salted_seed);

        // Invoke the default policy function to get the action
        u32 chosen_action = m_default_policy->Choose_Action(context);
        if (chosen_action == 0 || chosen_action > num_actions)
        {
            throw std::invalid_argument("Action chosen by default policy is not within valid range.");
        }

        const float epsilon = m_explore ? m_epsilon : 0.f;
        // Share of the probability mass every action receives from uniform exploration.
        const float base_probability = epsilon / num_actions;
        // Probability of the default action: the exploit mass plus its uniform share.
        const float default_action_probability = 1.f - epsilon + base_probability;

        float action_probability = default_action_probability;
        // TODO: check this random generation
        if (!(random_generator.Uniform_Unit_Interval() < 1.f - epsilon))
        {
            // Explore: draw an action id (presumably uniform over 1..num_actions inclusive).
            const u32 random_action = random_generator.Uniform_Int(1, num_actions);
            // The random draw may coincide with the default action, in which case the
            // default action's probability applies; otherwise only the uniform share.
            if (random_action != chosen_action)
            {
                action_probability = base_probability;
            }
            chosen_action = random_action;
        }
        return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
    }

private:
    IPolicy<Ctx>* m_default_policy;
    const float m_epsilon;
    bool m_explore;
    const u32 m_num_actions;
};
///
/// In some cases, different actions have a different scores, and you would prefer to
/// choose actions with large scores. Softmax allows you to do that: actions are drawn
/// with probability proportional to exp(lambda * score).
///
template <class Ctx>
class SoftmaxExplorer : public IExplorer<Ctx>, public IConsumeScorer<Ctx>
{
public:
    ///
    /// The constructor is the only public member, because this should be used with the MwtExplorer.
    ///
    /// @param default_scorer A function which outputs a score for each action.
    ///                       Only its address is kept, so it has to outlive this explorer.
    /// @param lambda         lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
    /// @param num_actions    The number of actions to randomize over.
    /// @throws std::invalid_argument when num_actions is below 1
    ///
    SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda, u32 num_actions) :
        m_default_scorer(&default_scorer), m_explore(true), m_lambda(lambda), m_num_actions(num_actions)
    {
        if (m_num_actions < 1)
        {
            throw std::invalid_argument("Number of actions must be at least 1.");
        }
    }

    ///
    /// Initializes a softmax explorer with variable number of actions.
    ///
    /// @param default_scorer A function which outputs a score for each action.
    ///                       Only its address is kept, so it has to outlive this explorer.
    /// @param lambda         lambda = 0 implies uniform distribution. Large lambda is equivalent to a max.
    ///
    SoftmaxExplorer(IScorer<Ctx>& default_scorer, float lambda) :
        m_default_scorer(&default_scorer), m_explore(true), m_lambda(lambda), m_num_actions(UINT_MAX)
    {
        static_assert(std::is_base_of<IVariableActionContext, Ctx>::value, "The provided context does not implement variable-action interface.");
    }

    ///
    /// Replaces the scorer. The scorer is held through a pointer because a reference
    /// member cannot be re-bound: assigning to it would assign to the old scorer object
    /// and leave the explorer on that old scorer.
    ///
    /// @param new_scorer The scorer to use from now on; it has to outlive this explorer
    ///
    void Update_Scorer(IScorer<Ctx>& new_scorer)
    {
        m_default_scorer = &new_scorer;
    }

    ///
    /// Enables or disables exploration. When disabled, the action with the highest score
    /// is always chosen.
    ///
    void Enable_Explore(bool explore)
    {
        m_explore = explore;
    }

private:
    ///
    /// Picks an action for the given context.
    ///
    /// @param salted_seed Seed of the random generator used for this decision
    /// @param context     The context handed to the scorer
    /// @returns The chosen action (1-based), its probability, and true (always record)
    /// @throws std::invalid_argument when the scorer does not return exactly one score per action
    ///
    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
    {
        // Unqualified call: the helper lives in the same namespace as this class.
        u32 num_actions = Get_Variable_Number_Of_Actions(context, m_num_actions);
        PRG::prg random_generator(salted_seed);

        // Invoke the default scorer function
        vector<float> scores = m_default_scorer->Score_Actions(context);
        u32 num_scores = (u32)scores.size();
        if (num_scores != num_actions)
        {
            throw std::invalid_argument("The number of scores returned by the scorer must equal number of actions");
        }

        // Locate the highest score; on ties the first occurrence wins. The search starts
        // from -FLT_MAX so that score vectors without any positive entry are handled too.
        u32 best_index = 0;
        float max_score = -FLT_MAX;
        for (u32 i = 0; i < num_scores; i++)
        {
            if (max_score < scores[i])
            {
                max_score = scores[i];
                best_index = i;
            }
        }

        float action_probability = 0.f;
        u32 action_index = 0;
        if (m_explore)
        {
            // Turn the scores into unnormalized softmax weights. Subtracting the maximum
            // keeps every exponent at or below zero, so exp cannot overflow.
            float total = 0.f;
            for (u32 i = 0; i < num_scores; i++)
            {
                scores[i] = exp(m_lambda * (scores[i] - max_score));
                total += scores[i];
            }

            // Walk the cumulative distribution until it passes the random draw.
            float draw = random_generator.Uniform_Unit_Interval();
            float sum = 0.f;
            // Fallback when rounding keeps the cumulative sum from exceeding the draw.
            action_index = num_scores - 1;
            for (u32 i = 0; i < num_scores; i++)
            {
                scores[i] = scores[i] / total;
                sum += scores[i];
                if (sum > draw)
                {
                    action_index = i;
                    break;
                }
            }
            // scores[action_index] is normalized on both paths: on a hit it was just
            // divided by the total, on the fallback the loop normalized every entry.
            action_probability = scores[action_index];
        }
        else
        {
            // No exploration: always pick the action with the highest score.
            action_index = best_index;
            action_probability = 1.f; // Set to 1 since we always pick the highest one.
        }

        // action id is one-based
        return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
    }

private:
    IScorer<Ctx>* m_default_scorer;
    bool m_explore;
    const float m_lambda;
    const u32 m_num_actions;
};
///
/// GenericExplorer provides complete flexibility. You can create any
/// distribution over actions desired, and it will draw from that.
/// The scorer returns one non-negative weight per action; the weights are normalized
/// by their sum before an action is drawn.
///
template <class Ctx>
class GenericExplorer : public IExplorer<Ctx>, public IConsumeScorer<Ctx>
{
public:
    ///
    /// The constructor is the only public member, because this should be used with the MwtExplorer.
    ///
    /// @param default_scorer A function which outputs the probability of each action.
    ///                       Only its address is kept, so it has to outlive this explorer.
    /// @param num_actions    The number of actions to randomize over.
    /// @throws std::invalid_argument when num_actions is below 1
    ///
    GenericExplorer(IScorer<Ctx>& default_scorer, u32 num_actions) :
        m_default_scorer(&default_scorer), m_explore(true), m_num_actions(num_actions)
    {
        if (m_num_actions < 1)
        {
            throw std::invalid_argument("Number of actions must be at least 1.");
        }
    }

    ///
    /// Initializes a generic explorer with variable number of actions.
    ///
    /// @param default_scorer A function which outputs the probability of each action.
    ///                       Only its address is kept, so it has to outlive this explorer.
    ///
    GenericExplorer(IScorer<Ctx>& default_scorer) :
        m_default_scorer(&default_scorer), m_explore(true), m_num_actions(UINT_MAX)
    {
        static_assert(std::is_base_of<IVariableActionContext, Ctx>::value, "The provided context does not implement variable-action interface.");
    }

    ///
    /// Replaces the scorer. The scorer is held through a pointer because a reference
    /// member cannot be re-bound: assigning to it would assign to the old scorer object
    /// and leave the explorer on that old scorer.
    ///
    /// @param new_scorer The scorer to use from now on; it has to outlive this explorer
    ///
    void Update_Scorer(IScorer<Ctx>& new_scorer)
    {
        m_default_scorer = &new_scorer;
    }

    ///
    /// Stores the exploration flag. Note that Choose_Action of this explorer does not
    /// consult the flag: the action is always drawn from the scorer's distribution.
    ///
    void Enable_Explore(bool explore)
    {
        m_explore = explore;
    }

private:
    ///
    /// Draws an action from the distribution described by the scorer.
    ///
    /// @param salted_seed Seed of the random generator used for this decision
    /// @param context     The context handed to the scorer
    /// @returns The chosen action (1-based), its probability, and true (always record)
    /// @throws std::invalid_argument when the scorer does not return exactly one weight per
    ///         action, when a weight is negative, or when all weights are zero
    ///
    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
    {
        // Unqualified call: the helper lives in the same namespace as this class.
        u32 num_actions = Get_Variable_Number_Of_Actions(context, m_num_actions);
        PRG::prg random_generator(salted_seed);

        // Invoke the default scorer function
        vector<float> weights = m_default_scorer->Score_Actions(context);
        u32 num_weights = (u32)weights.size();
        if (num_weights != num_actions)
        {
            throw std::invalid_argument("The number of weights returned by the scorer must equal number of actions");
        }

        // Validate the weights and sum them up. The sum does not have to be 1: the
        // weights are normalized against it below.
        float total = 0.f;
        u32 last_positive = 0;
        for (u32 i = 0; i < num_weights; i++)
        {
            if (weights[i] < 0)
            {
                throw std::invalid_argument("Scores must be non-negative.");
            }
            if (weights[i] > 0)
            {
                last_positive = i;
            }
            total += weights[i];
        }
        if (total == 0)
        {
            throw std::invalid_argument("At least one score must be positive.");
        }

        // Walk the cumulative distribution until it passes the random draw.
        float draw = random_generator.Uniform_Unit_Interval();
        float sum = 0.f;
        // Fallback when rounding keeps the cumulative sum from exceeding the draw: use
        // the last action with a positive weight so a zero-weight action is never chosen.
        u32 action_index = last_positive;
        for (u32 i = 0; i < num_weights; i++)
        {
            weights[i] = weights[i] / total;
            sum += weights[i];
            if (sum > draw)
            {
                action_index = i;
                break;
            }
        }
        // weights[action_index] is normalized on both paths: on a hit it was just
        // divided by the total, on the fallback the loop normalized every entry.
        float action_probability = weights[action_index];

        // action id is one-based
        return std::tuple<u32, float, bool>(action_index + 1, action_probability, true);
    }

private:
    IScorer<Ctx>* m_default_scorer;
    bool m_explore;
    const u32 m_num_actions;
};
///
/// The tau-first explorer collects exactly tau uniform random exploration events, and then
/// uses the default policy thereafter.
/// The remaining count is a plain member that is decremented without any synchronization.
///
template <class Ctx>
class TauFirstExplorer : public IExplorer<Ctx>, public IConsumePolicy<Ctx>
{
public:
    ///
    /// The constructor is the only public member, because this should be used with the MwtExplorer.
    ///
    /// @param default_policy A default policy after randomization finishes.
    ///                       Only its address is kept, so it has to outlive this explorer.
    /// @param tau            The number of events to be uniform over.
    /// @param num_actions    The number of actions to randomize over.
    /// @throws std::invalid_argument when num_actions is below 1
    ///
    TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau, u32 num_actions) :
        m_default_policy(&default_policy), m_explore(true), m_tau(tau), m_num_actions(num_actions)
    {
        if (m_num_actions < 1)
        {
            throw std::invalid_argument("Number of actions must be at least 1.");
        }
    }

    ///
    /// Initializes a tau-first explorer with variable number of actions.
    ///
    /// @param default_policy A default policy after randomization finishes.
    ///                       Only its address is kept, so it has to outlive this explorer.
    /// @param tau            The number of events to be uniform over.
    ///
    TauFirstExplorer(IPolicy<Ctx>& default_policy, u32 tau) :
        m_default_policy(&default_policy), m_explore(true), m_tau(tau), m_num_actions(UINT_MAX)
    {
        static_assert(std::is_base_of<IVariableActionContext, Ctx>::value, "The provided context does not implement variable-action interface.");
    }

    ///
    /// Replaces the default policy. The policy is held through a pointer because a
    /// reference member cannot be re-bound: assigning to it would assign to the old
    /// policy object and leave the explorer on that old policy.
    ///
    /// @param new_policy The policy to use from now on; it has to outlive this explorer
    ///
    void Update_Policy(IPolicy<Ctx>& new_policy)
    {
        m_default_policy = &new_policy;
    }

    ///
    /// Enables or disables exploration. While disabled, the default policy is used and
    /// the remaining exploration budget is left untouched.
    ///
    void Enable_Explore(bool explore)
    {
        m_explore = explore;
    }

private:
    ///
    /// Picks an action for the given context.
    ///
    /// @param salted_seed Seed of the random generator used for this decision
    /// @param context     The context handed to the default policy
    /// @returns The chosen action (1-based), its probability, and a flag that is true
    ///          only for the random exploration events
    /// @throws std::invalid_argument when the default policy returns 0 or an action
    ///         above the number of actions
    ///
    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
    {
        // Unqualified call: the helper lives in the same namespace as this class.
        u32 num_actions = Get_Variable_Number_Of_Actions(context, m_num_actions);
        PRG::prg random_generator(salted_seed);

        u32 chosen_action = 0;
        float action_probability = 0.f;
        bool log_action = false;
        if (m_tau && m_explore)
        {
            // One of the first tau events: spend one unit of the budget and draw an
            // action id (presumably uniform over 1..num_actions inclusive).
            m_tau--;
            chosen_action = random_generator.Uniform_Int(1, num_actions);
            action_probability = 1.f / num_actions;
            log_action = true;
        }
        else
        {
            // Invoke the default policy function to get the action
            chosen_action = m_default_policy->Choose_Action(context);
            if (chosen_action == 0 || chosen_action > num_actions)
            {
                throw std::invalid_argument("Action chosen by default policy is not within valid range.");
            }
            // Deterministic decisions carry probability 1 and are not recorded.
            action_probability = 1.f;
            log_action = false;
        }
        return std::tuple<u32, float, bool>(chosen_action, action_probability, log_action);
    }

private:
    IPolicy<Ctx>* m_default_policy;
    bool m_explore;
    u32 m_tau;
    const u32 m_num_actions;
};
///
/// The Bootstrap explorer randomizes over the actions chosen by a set of default policies.
/// This performs well statistically but can be computationally expensive, because every
/// policy of the set is evaluated for each decision.
///
template <class Ctx>
class BootstrapExplorer : public IExplorer<Ctx>, public IConsumePolicies<Ctx>
{
public:
    ///
    /// The constructor is the only public member, because this should be used with the MwtExplorer.
    ///
    /// @param default_policy_functions A set of default policies to be uniform random over.
    /// The explorer keeps a reference to this vector, so the vector and the policy pointers
    /// must be valid throughout the lifetime of this explorer.
    /// @param num_actions The number of actions to randomize over.
    /// @throws std::invalid_argument when num_actions is below 1 or the policy set is empty
    ///
    BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions, u32 num_actions) :
        m_default_policy_functions(default_policy_functions),
        m_explore(true), m_bags((u32)default_policy_functions.size()), m_num_actions(num_actions)
    {
        if (m_num_actions < 1)
        {
            throw std::invalid_argument("Number of actions must be at least 1.");
        }
        if (m_bags < 1)
        {
            throw std::invalid_argument("Number of bags must be at least 1.");
        }
    }

    ///
    /// Initializes a bootstrap explorer with variable number of actions.
    ///
    /// @param default_policy_functions A set of default policies to be uniform random over.
    /// The explorer keeps a reference to this vector, so the vector and the policy pointers
    /// must be valid throughout the lifetime of this explorer.
    /// @throws std::invalid_argument when the policy set is empty
    ///
    BootstrapExplorer(vector<unique_ptr<IPolicy<Ctx>>>& default_policy_functions) :
        m_default_policy_functions(default_policy_functions),
        m_explore(true), m_bags((u32)default_policy_functions.size()), m_num_actions(UINT_MAX)
    {
        if (m_bags < 1)
        {
            throw std::invalid_argument("Number of bags must be at least 1.");
        }
        static_assert(std::is_base_of<IVariableActionContext, Ctx>::value, "The provided context does not implement variable-action interface.");
    }

    ///
    /// Replaces the set of default policies. The new policies are moved into the vector
    /// this explorer was constructed with, which leaves new_policy_functions empty, and
    /// the bag count is refreshed so that it always matches the number of policies.
    ///
    /// @param new_policy_functions The policies to use from now on
    /// @throws std::invalid_argument when new_policy_functions is empty
    ///
    void Update_Policy(vector<unique_ptr<IPolicy<Ctx>>>& new_policy_functions)
    {
        if (new_policy_functions.empty())
        {
            throw std::invalid_argument("Number of bags must be at least 1.");
        }
        // Moving a vector onto itself is not meaningful; skip it when the caller hands
        // back the very vector this explorer already refers to.
        if (&new_policy_functions != &m_default_policy_functions)
        {
            m_default_policy_functions = std::move(new_policy_functions);
        }
        m_bags = (u32)m_default_policy_functions.size();
    }

    ///
    /// Enables or disables exploration. When disabled, only the first policy of the set
    /// is used.
    ///
    void Enable_Explore(bool explore)
    {
        m_explore = explore;
    }

private:
    ///
    /// Picks an action for the given context.
    ///
    /// @param salted_seed Seed of the random generator used for this decision
    /// @param context     The context handed to the default policies
    /// @returns The chosen action (1-based), its probability, and true (always record)
    /// @throws std::invalid_argument when a policy returns 0 or an action above the
    ///         number of actions
    ///
    std::tuple<u32, float, bool> Choose_Action(u64 salted_seed, Ctx& context)
    {
        // Unqualified call: the helper lives in the same namespace as this class.
        u32 num_actions = Get_Variable_Number_Of_Actions(context, m_num_actions);
        PRG::prg random_generator(salted_seed);

        // Select bag (the bounds are presumably inclusive, given the m_bags - 1 upper bound)
        u32 chosen_bag = random_generator.Uniform_Int(0, m_bags - 1);

        u32 chosen_action = 0;
        float action_probability = 0.f;
        if (m_explore)
        {
            // Number of bags that voted for each action; index is action id - 1.
            vector<u32> actions_selected(num_actions);

            // Invoke the default policy function of every bag to get its action
            for (u32 current_bag = 0; current_bag < m_bags; current_bag++)
            {
                // TODO: can VW predict for all bags on one call? (returning all actions at once)
                // if we trigger into VW passing an index to invoke bootstrap scoring, and if VW model changes while we are doing so,
                // we could end up calling the wrong bag
                u32 action_from_bag = m_default_policy_functions[current_bag]->Choose_Action(context);
                Validate_Action(action_from_bag, num_actions);

                if (current_bag == chosen_bag)
                {
                    chosen_action = action_from_bag;
                }
                //this won't work if actions aren't 0 to Count
                actions_selected[action_from_bag - 1]++; // action id is one-based
            }
            // The probability of an action is the share of bags that chose it.
            action_probability = (float)actions_selected[chosen_action - 1] / m_bags; // action id is one-based
        }
        else
        {
            // No exploration: use the first policy only. Its action is range-checked
            // like in the exploring branch.
            chosen_action = m_default_policy_functions[0]->Choose_Action(context);
            Validate_Action(chosen_action, num_actions);
            action_probability = 1.f;
        }
        return std::tuple<u32, float, bool>(chosen_action, action_probability, true);
    }

    /// Throws std::invalid_argument unless action lies within 1..num_actions.
    static void Validate_Action(u32 action, u32 num_actions)
    {
        if (action == 0 || action > num_actions)
        {
            throw std::invalid_argument("Action chosen by default policy is not within valid range.");
        }
    }

private:
    vector<unique_ptr<IPolicy<Ctx>>>& m_default_policy_functions;
    bool m_explore;
    u32 m_bags;
    const u32 m_num_actions;
};
} // End namespace MultiWorldTestingCpp
/*! @} End of Doxygen Groups*/

Просмотреть файл

@ -1,2 +0,0 @@
all:
cd ..; $(MAKE)

Просмотреть файл

@ -1,4 +0,0 @@
// vw_explore.cpp : Defines the entry point for the console application.
//
#include "MwtExplorer.h"

Просмотреть файл

@ -1,155 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{ACE47E98-488C-4CDF-B9F1-36337B2855AD}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>vw_explore_static</RootNamespace>
<ProjectName>explore_static</ProjectName>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>StaticLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Platform)'=='x64'">
<BoostIncludeDir>c:\boost\x64\include\boost-1_56</BoostIncludeDir>
<ZlibIncludeDir>..\..\..\zlib-1.2.8</ZlibIncludeDir>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<PrecompiledHeader>
</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<DebugInformationFormat>ProgramDatabase</DebugInformationFormat>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>
</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="MWTExplorer.h" />
<ClInclude Include="utility.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="explore.cpp" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

Просмотреть файл

@ -1,282 +0,0 @@
/*******************************************************************/
// Classes declared in this file are intended for internal use only.
/*******************************************************************/
#pragma once
#include <stdint.h>
#include <sys/types.h> /* defines size_t */
// Fixed-width integer aliases used throughout the library.
// The branch is keyed on _MSC_VER, which the Microsoft compiler always
// predefines, instead of WIN32, which is only present when the build system
// passes it. This matches the compiler check used further down in this file.
#if defined(_MSC_VER)
typedef unsigned __int64 u64;
typedef unsigned __int32 u32;
typedef unsigned __int16 u16;
typedef unsigned __int8 u8;
typedef signed __int64 i64;
typedef signed __int32 i32;
typedef signed __int16 i16;
typedef signed __int8 i8;
// sprintf_s is provided by the Microsoft C runtime; no shim is needed here.
#else
typedef uint64_t u64;
typedef uint32_t u32;
typedef uint16_t u16;
typedef uint8_t u8;
typedef int64_t i64;
typedef int32_t i32;
typedef int16_t i16;
typedef int8_t i8;
// Other toolchains: map sprintf_s(buffer, size, format, ...) onto snprintf.
#define sprintf_s snprintf
#endif
typedef unsigned char byte;
#include <string>
#include <stdint.h>
#include <math.h>
/*!
* \addtogroup MultiWorldTestingCpp
* @{
*/
MWT_NAMESPACE {
//
// MurmurHash3, by Austin Appleby
//
// Originals at:
// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp
// http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.h
//
// Notes:
// 1) this code assumes we can read a 4-byte value from any address
// without crashing (i.e non aligned access is supported). This is
// not a problem on Intel/x86/AMD64 machines (including new Macs)
// 2) It produces different results on little-endian and big-endian machines.
//
//-----------------------------------------------------------------------------
// MurmurHash3 was written by Austin Appleby, and is placed in the public
// domain. The author hereby disclaims copyright to this source code.
// Note - The x86 and x64 versions do _not_ produce the same results, as the
// algorithms are optimized for their respective platforms. You can still
// compile and run any of them on any platform, but your performance with the
// non-native version will be less than optimal.
//-----------------------------------------------------------------------------
// Platform-specific functions and macros
#if defined(_MSC_VER) // Microsoft Visual Studio
# include <stdint.h>
# include <stdlib.h>
# define ROTL32(x,y) _rotl(x,y)
# define BIG_CONSTANT(x) (x)
#else // Other compilers
# include <stdint.h> /* defines uint32_t etc */
// Rotates the 32-bit value x left by r bits.
// Both shift counts are masked to 0..31 so that r == 0 (or any multiple of 32)
// no longer shifts by the full width of the type, which is undefined behaviour.
inline uint32_t rotl32(uint32_t x, int8_t r)
{
    return (x << (r & 31)) | (x >> ((32 - r) & 31));
}
# define ROTL32(x,y) rotl32(x,y)
# define BIG_CONSTANT(x) (x##LLU)
#endif // !defined(_MSC_VER)
// 32-bit MurmurHash3 (see the attribution and notes above).
struct murmur_hash {
    //-----------------------------------------------------------------------------
    // Block read - if your platform needs to do endian-swapping or can only
    // handle aligned reads, do the conversion here
private:
    // Returns the 32-bit word at index i relative to p. uniform_hash passes a
    // pointer to the END of the block region together with negative indices.
    static inline uint32_t getblock(const uint32_t * p, int i)
    {
        return p[i];
    }
    //-----------------------------------------------------------------------------
    // Finalization mix - force all bits of a hash block to avalanche
    static inline uint32_t fmix(uint32_t h)
    {
        h ^= h >> 16;
        h *= 0x85ebca6b;
        h ^= h >> 13;
        h *= 0xc2b2ae35;
        h ^= h >> 16;
        return h;
    }
    //-----------------------------------------------------------------------------
public:
    // Hashes len bytes starting at key, mixing in seed, and returns a 32-bit hash.
    //
    // key   start of the data; read as raw bytes and, for whole 4-byte blocks,
    //       as 32-bit words (see note 1 above about unaligned reads)
    // len   number of bytes to hash
    // seed  initial hash state
    uint32_t uniform_hash(const void * key, size_t len, uint32_t seed)
    {
        const uint8_t * data = (const uint8_t*)key;
        const int nblocks = (int)len / 4;
        uint32_t h1 = seed;
        const uint32_t c1 = 0xcc9e2d51;
        const uint32_t c2 = 0x1b873593;
        // --- body: consume the whole 4-byte blocks, indexing backwards from
        // the end of the block region (i runs from -nblocks up to -1)
        const uint32_t * blocks = (const uint32_t *)(data + nblocks * 4);
        for (int i = -nblocks; i; i++) {
            uint32_t k1 = getblock(blocks, i);
            k1 *= c1;
            k1 = ROTL32(k1, 15);
            k1 *= c2;
            h1 ^= k1;
            h1 = ROTL32(h1, 13);
            h1 = h1 * 5 + 0xe6546b64;
        }
        // --- tail: fold in the 1-3 bytes that do not fill a block.
        // The cases fall through on purpose so each one adds its byte and
        // then continues with the lower ones.
        const uint8_t * tail = (const uint8_t*)(data + nblocks * 4);
        uint32_t k1 = 0;
        switch (len & 3) {
        case 3: k1 ^= tail[2] << 16; // fall through
        case 2: k1 ^= tail[1] << 8;  // fall through
        case 1: k1 ^= tail[0];
            k1 *= c1; k1 = ROTL32(k1, 15); k1 *= c2; h1 ^= k1;
        }
        // --- finalization: mix in the length, then avalanche
        h1 ^= len;
        return fmix(h1);
    }
};
// Helpers for turning caller-supplied identifiers into numeric seeds.
class HashUtils
{
public:
    // Converts unique_id into a number.
    // If every character up to the first NUL is a decimal digit, the string is
    // read as a base-10 number (an empty string gives 0). As soon as any other
    // character is found, the MurmurHash of the whole string with seed 0 is
    // returned instead.
    static u64 Compute_Id_Hash(const std::string& unique_id)
    {
        size_t numeric_id = 0;
        for (const char *cursor = unique_id.c_str(); *cursor != '\0'; ++cursor)
        {
            const bool is_digit = (*cursor >= '0' && *cursor <= '9');
            if (!is_digit)
            {
                // Not purely numeric: hash the complete string
                murmur_hash hasher;
                return hasher.uniform_hash(unique_id.c_str(), unique_id.size(), 0);
            }
            numeric_id = 10 * numeric_id + *cursor - '0';
        }
        return numeric_id;
    }
};
// Limits used by NumberUtils when formatting floats by hand.
// Scale factor applied to the fractional part in print_float before its digits are printed.
const size_t max_int = 100000;
// print_float only formats magnitudes strictly below this value by hand.
const float max_float = max_int;
// print_float only formats magnitudes strictly above this value by hand.
const float min_float = 0.00001f;
// Number of decimal digits implied by min_float: round(-log10(min_float)).
// print_mantissa<false> left-pads its output with zeros up to this width.
const size_t max_digits = (size_t) roundf((float) (-log(min_float) / log(10.)));
// Float-to-text helpers that avoid printf for values in a common range.
class NumberUtils
{
public:
    // Writes the decimal digits of the integer part of f at begin and advances
    // begin past what was written. No terminating NUL is appended.
    //
    // trailing_zeros == true : every digit is written (used for the integer
    //                          part of a number); nothing is written for 0.
    // trailing_zeros == false: the digits are left-padded with '0' up to
    //                          max_digits and zeros at the low end are dropped
    //                          (used for the scaled fractional part).
    template<bool trailing_zeros>
    static void print_mantissa(char*& begin, float f)
    { // helper for print_float
        char values[10];
        size_t v = (size_t)f;
        size_t digit = 0;
        // Index of the lowest non-zero digit. A separate flag records whether it
        // has been found, because index 0 is itself a valid answer; using 0 as
        // the "not found" marker made the next non-zero digit overwrite it and
        // dropped low-order digits from the output.
        size_t first_nonzero = 0;
        bool found_nonzero = false;
        // Digits are collected least-significant first. The second condition
        // keeps the loop inside values[] for inputs with more than 10 digits.
        for (size_t max = 1; max <= v && digit < sizeof(values); ++digit)
        {
            size_t max_next = max * 10;
            char v_mod = (char) (v % max_next / max);
            if (!trailing_zeros && v_mod != '\0' && !found_nonzero)
            {
                first_nonzero = digit;
                found_nonzero = true;
            }
            values[digit] = '0' + v_mod;
            max = max_next;
        }
        if (!trailing_zeros)
            for (size_t i = max_digits; i > digit; i--)
                *begin++ = '0';
        // Emit most-significant first, stopping at the lowest non-zero digit
        while (digit > first_nonzero)
            *begin++ = values[--digit];
    }
    // Writes a NUL-terminated textual form of f into the buffer at begin.
    //
    // begin  destination buffer
    // size   capacity of the buffer in bytes
    // f      value to format
    //
    // Magnitudes strictly between min_float and max_float are formatted by
    // hand as [-]integer[.fraction]; a zero integer part is omitted (0.5 is
    // written as ".5"). Zero is written as "0". Everything else, and any call
    // whose buffer is too small for the hand-formatted form, goes through
    // sprintf_s with "%g", which is bounded by size.
    static void print_float(char* begin, size_t size, float f)
    {
        const bool sign = (f < 0.f);
        float unsigned_f = fabsf(f);
        // Longest hand-formatted output: sign, integer digits, '.', max_digits
        // fraction digits and the terminating NUL. The integer part is below
        // max_float, which has max_digits digits with the current constants.
        const size_t hand_format_capacity = 2 * max_digits + 3;
        if (unsigned_f < max_float && unsigned_f > min_float && size >= hand_format_capacity)
        {
            if (sign)
                *begin++ = '-';
            print_mantissa<true>(begin, unsigned_f);
            // Keep only the fraction and scale it into an integer
            unsigned_f -= (size_t)unsigned_f;
            unsigned_f *= max_int;
            if (unsigned_f >= 1.f)
            {
                *begin++ = '.';
                print_mantissa<false>(begin, unsigned_f);
            }
        }
        else if (unsigned_f == 0. && size >= 2)
            *begin++ = '0';
        else
        {
            sprintf_s(begin, size, "%g", f);
            return;
        }
        *begin = '\0';
        return;
    }
};
//A quick implementation similar to drand48 for cross-platform compatibility
namespace PRG {
    // Multiplier and increment of the update "state = a * state + c".
    // c is also the state used by the default constructor.
    const uint64_t a = 0xeece66d5deece66dULL;
    const uint64_t c = 2147483647;
    // Bit pattern OR-ed on top of the 23 random mantissa bits in merand48
    const int bias = 127 << 23;
    // Reinterprets a 32-bit pattern as a float
    union int_float {
        int32_t i;
        float f;
    };
    // Small deterministic pseudo-random generator with a 64-bit state.
    struct prg {
    private:
        uint64_t v;
    public:
        // Starts from the fixed default state c.
        prg() { v = c; }
        // Starts from the caller-supplied state.
        prg(uint64_t initial) { v = initial; }
        // Advances the state passed in (in place) and returns a float built
        // from bits 25..47 of the new state, in [0, 1).
        float merand48(uint64_t& initial)
        {
            initial = a * initial + c;
            int_float temp;
            temp.i = ((initial >> 25) & 0x7FFFFF) | bias;
            return temp.f - 1;
        }
        // Advances the generator and returns a float in [0, 1).
        float Uniform_Unit_Interval()
        {
            return merand48(v);
        }
        // Advances the generator and returns an integer in [low, high].
        uint32_t Uniform_Int(uint32_t low, uint32_t high)
        {
            merand48(v);
            const uint64_t bits = v >> 25;
            // Number of values in [low, high]; wraps to 0 when the range covers
            // every 32-bit value, in which case no reduction is needed (and
            // taking the remainder would divide by zero).
            const uint32_t span = high - low + 1;
            const uint32_t offset = (span == 0) ? (uint32_t)bits : (uint32_t)(bits % span);
            return low + offset;
        }
    };
}
}
/*! @} End of Doxygen Groups*/

Просмотреть файл

@ -1,51 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hh;hpp;hxx;hm;inl;inc;xsd</Extensions>
</Filter>
<Filter Include="Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<Text Include="ReadMe.txt" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="targetver.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="Interaction.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="utility.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="MWT.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="Logger.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="OfflineEvaluator.h">
<Filter>Source Files</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="stdafx.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="vw_explore.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
</Project>

Просмотреть файл

@ -1,934 +0,0 @@
#include "CppUnitTest.h"
#include "MWTExploreTests.h"
using namespace Microsoft::VisualStudio::CppUnitTestFramework;
// Run `block` and increment the enclosing scope's `num_ex` counter when it
// throws std::invalid_argument. The exception is caught by const reference so
// it is neither copied nor sliced.
#define COUNT_INVALID(block) try { block } catch (const std::invalid_argument&) { num_ex++; }
#define COUNT_BAD_CALL(block) try { block } catch (const std::invalid_argument&) { num_ex++; }
namespace vw_explore_tests
{
TEST_CLASS(VWExploreUnitTests)
{
public:
TEST_METHOD(Epsilon_Greedy)
{
    // With epsilon set to zero there is no randomization, so every decision is
    // expected to equal the default policy's action.
    float epsilon = 0.f;
    int num_actions = 10;
    int policy_params = 101;
    string event_key = "1001";

    TestPolicy<TestContext> default_policy(policy_params, num_actions);
    TestContext context;
    TestRecorder<TestContext> recorder;
    MwtExplorer<TestContext> mwt("salt", recorder);
    EpsilonGreedyExplorer<TestContext> explorer(default_policy, epsilon, num_actions);

    u32 policy_action = default_policy.Choose_Action(context);

    // Decide twice with the same key; both results must follow the policy
    for (int call = 0; call < 2; call++)
    {
        u32 explored_action = mwt.Choose_Action(explorer, event_key, context);
        Assert::AreEqual(policy_action, explored_action);
    }

    // Two recorded interactions are expected, each with probability 1
    vector<TestInteraction<TestContext>> logged = recorder.Get_All_Interactions();
    float expected_probs[2] = { 1.f, 1.f };
    this->Test_Interactions(logged, 2, expected_probs);
}
TEST_METHOD(Epsilon_Greedy_Random)
{
    // Epsilon of 0.5: verify that about half the time the default policy is chosen.
    // The checks themselves live in Epsilon_Greedy_Random_Context.
    float epsilon = 0.5f;
    int num_actions = 10;
    int policy_params = 101;

    TestPolicy<TestContext> default_policy(policy_params, num_actions);
    TestContext context;
    // Passing the action count puts the explorer in fixed # action mode
    EpsilonGreedyExplorer<TestContext> explorer(default_policy, epsilon, num_actions);

    this->Epsilon_Greedy_Random_Context(num_actions, context, explorer, default_policy);
}
TEST_METHOD(Epsilon_Greedy_Random_Var_Context)
{
    // Same scenario as Epsilon_Greedy_Random, but the explorer is created
    // without an action count (variable # action mode) and the context
    // supplies it. The context used here always reports the same count.
    float epsilon = 0.5f; // about half the time the default policy should be chosen
    int num_actions = 10;
    int policy_params = 101;

    TestPolicy<TestVarContext> default_policy(policy_params, num_actions);
    TestVarContext context(num_actions);
    EpsilonGreedyExplorer<TestVarContext> explorer(default_policy, epsilon);

    this->Epsilon_Greedy_Random_Context(num_actions, context, explorer, default_policy);
}
TEST_METHOD(Epsilon_Greedy_Toggle_Exploration)
{
    // Checks that Enable_Explore switches the explorer between always
    // following the policy and exploring with the configured epsilon.
    float epsilon = 0.5f;
    int num_actions = 10;
    int policy_params = 101;

    TestPolicy<TestContext> default_policy(policy_params, num_actions);
    TestContext context;
    TestRecorder<TestContext> recorder;
    MwtExplorer<TestContext> mwt("salt", recorder);
    EpsilonGreedyExplorer<TestContext> explorer(default_policy, epsilon, num_actions);

    u32 policy_action = default_policy.Choose_Action(context);
    int times_choose = 10000;

    // Exploration off: the default policy must be chosen every single time
    explorer.Enable_Explore(false);
    int policy_hits = 0;
    for (int i = 0; i < times_choose; i++)
    {
        if (mwt.Choose_Action(explorer, this->Get_Unique_Key(i), context) == policy_action)
        {
            policy_hits++;
        }
    }
    Assert::AreEqual(times_choose, policy_hits);

    // Exploration on: the default policy should be chosen about half the time
    explorer.Enable_Explore(true);
    policy_hits = 0;
    for (int i = 0; i < times_choose; i++)
    {
        if (mwt.Choose_Action(explorer, this->Get_Unique_Key(i), context) == policy_action)
        {
            policy_hits++;
        }
    }
    double policy_fraction = (double)policy_hits / times_choose;
    Assert::IsTrue(abs(policy_fraction - 0.5) < 0.1);
}
TEST_METHOD(Tau_First)
{
    // tau = 0 means no randomization and no logging: the decision must match
    // the default policy and the recorder is expected to stay empty.
    u32 tau = 0;
    int num_actions = 10;
    int policy_params = 101;

    TestPolicy<TestContext> default_policy(policy_params, num_actions);
    TestRecorder<TestContext> recorder;
    TestContext context;
    MwtExplorer<TestContext> mwt("salt", recorder);
    TauFirstExplorer<TestContext> explorer(default_policy, tau, num_actions);

    u32 policy_action = default_policy.Choose_Action(context);
    u32 explored_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), context);
    Assert::AreEqual(policy_action, explored_action);

    vector<TestInteraction<TestContext>> logged = recorder.Get_All_Interactions();
    this->Test_Interactions(logged, 0, nullptr);
}
TEST_METHOD(Tau_First_Random)
{
    // Tau-first explorer in fixed # action mode with tau = 2; the
    // checks are performed by Tau_First_Random_Context.
    u32 tau = 2;
    int num_actions = 10;

    TestContext context;
    TestPolicy<TestContext> default_policy(99, num_actions);
    TauFirstExplorer<TestContext> explorer(default_policy, tau, num_actions);

    this->Tau_First_Random_Context(num_actions, context, explorer);
}
TEST_METHOD(Tau_First_Random_Var_Context)
{
    // Same as Tau_First_Random, but the explorer is built without an action
    // count and the context is constructed with it instead.
    u32 tau = 2;
    int num_actions = 10;

    TestVarContext context(num_actions);
    TestPolicy<TestVarContext> default_policy(99, num_actions);
    TauFirstExplorer<TestVarContext> explorer(default_policy, tau);

    this->Tau_First_Random_Context(num_actions, context, explorer);
}
TEST_METHOD(Tau_First_Toggle_Exploration)
{
    // Checks that disabling exploration neither randomizes nor logs, and that
    // re-enabling it explores only for the first tau decisions.
    u32 tau = 2;
    int num_actions = 10;

    TestPolicy<TestContext> default_policy(99, num_actions);
    TestRecorder<TestContext> recorder;
    TestContext context;
    MwtExplorer<TestContext> mwt("salt", recorder);
    TauFirstExplorer<TestContext> explorer(default_policy, tau, num_actions);

    u32 policy_action = default_policy.Choose_Action(context);
    int times_choose = 10000;

    // Exploration off: the default policy must be chosen every single time
    explorer.Enable_Explore(false);
    int policy_hits = 0;
    for (int i = 0; i < times_choose; i++)
    {
        if (mwt.Choose_Action(explorer, this->Get_Unique_Key(i), context) == policy_action)
        {
            policy_hits++;
        }
    }
    Assert::AreEqual(times_choose, policy_hits);

    // Exploration on: the first two decisions use up tau
    explorer.Enable_Explore(true);
    mwt.Choose_Action(explorer, this->Get_Unique_Key(1), context);
    mwt.Choose_Action(explorer, this->Get_Unique_Key(2), context);

    // Tau expired, did not explore
    u32 action_after_tau = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), context);
    Assert::AreEqual((u32)10, action_after_tau);

    // Only 2 interactions logged, 3rd one should not be stored
    float expected_probs[2] = { .1f, .1f };
    vector<TestInteraction<TestContext>> logged = recorder.Get_All_Interactions();
    this->Test_Interactions(logged, 2, expected_probs);
}
TEST_METHOD(Bootstrap)
{
    // Two bags built from policies with different parameters. The test pins
    // which bag is picked for keys 1 and 2, and expects both decisions to be
    // recorded with probability 1/2.
    int num_actions = 10;
    int policy_params = 101;

    TestRecorder<TestContext> recorder;
    vector<unique_ptr<IPolicy<TestContext>>> bag_policies;
    bag_policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy<TestContext>(policy_params, num_actions)));
    bag_policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy<TestContext>(policy_params + 1, num_actions)));
    TestContext context;
    MwtExplorer<TestContext> mwt("c++-test", recorder);
    BootstrapExplorer<TestContext> explorer(bag_policies, num_actions);

    u32 first_bag_action = bag_policies[0]->Choose_Action(context);
    u32 second_bag_action = bag_policies[1]->Choose_Action(context);

    // Key 1 is expected to land on the second bag, key 2 on the first
    u32 explored_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), context);
    Assert::AreEqual(second_bag_action, explored_action);
    explored_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), context);
    Assert::AreEqual(first_bag_action, explored_action);

    float expected_probs[2] = { .5f, .5f };
    vector<TestInteraction<TestContext>> logged = recorder.Get_All_Interactions();
    this->Test_Interactions(logged, 2, expected_probs);
}
TEST_METHOD(Bootstrap_Random)
{
    // Bootstrap explorer over two policies in fixed # action mode; the checks
    // are performed by Bootstrap_Random_Context.
    int num_actions = 10;
    int policy_params = 101;

    TestContext context;
    vector<unique_ptr<IPolicy<TestContext>>> bag_policies;
    bag_policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy<TestContext>(policy_params, num_actions)));
    bag_policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy<TestContext>(policy_params + 1, num_actions)));
    BootstrapExplorer<TestContext> explorer(bag_policies, num_actions);

    this->Bootstrap_Random_Context(num_actions, context, explorer);
}
TEST_METHOD(Bootstrap_Random_Var_Context)
{
    // Same as Bootstrap_Random, but the explorer is built without an action
    // count and the context is constructed with it instead.
    int num_actions = 10;
    int policy_params = 101;

    TestVarContext context(num_actions);
    vector<unique_ptr<IPolicy<TestVarContext>>> bag_policies;
    bag_policies.push_back(unique_ptr<IPolicy<TestVarContext>>(new TestPolicy<TestVarContext>(policy_params, num_actions)));
    bag_policies.push_back(unique_ptr<IPolicy<TestVarContext>>(new TestPolicy<TestVarContext>(policy_params + 1, num_actions)));
    BootstrapExplorer<TestVarContext> explorer(bag_policies);

    this->Bootstrap_Random_Context(num_actions, context, explorer);
}
TEST_METHOD(Bootstrap_Toggle_Exploration)
{
    // Checks that with exploration disabled the first policy always decides,
    // and that after re-enabling it the decisions are spread over the bags.
    int num_actions = 10;
    int params = 101;
    TestRecorder<TestContext> my_recorder;
    vector<unique_ptr<IPolicy<TestContext>>> policies;
    policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy<TestContext>(params, num_actions)));
    policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestPolicy<TestContext>(params + 1, num_actions)));
    TestContext my_context;
    MwtExplorer<TestContext> mwt("c++-test", my_recorder);
    BootstrapExplorer<TestContext> explorer(policies, num_actions);
    u32 policy_action = policies[0]->Choose_Action(my_context);
    int times_choose = 10000;
    int times_policy_action_chosen = 0;
    explorer.Enable_Explore(false);
    // Verify that all the time the first policy is chosen
    for (int i = 0; i < times_choose; i++)
    {
        u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i), my_context);
        if (chosen_action == policy_action)
        {
            times_policy_action_chosen++;
        }
    }
    Assert::AreEqual(times_choose, times_policy_action_chosen);
    explorer.Enable_Explore(true);
    u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
    chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
    // Two bags choosing different actions so prob of each is 1/2
    vector<TestInteraction<TestContext>> interactions = my_recorder.Get_All_Interactions();
    // Expected probabilities: 1 for every non-exploring decision, then 1/2 for
    // the two exploring ones. A vector owns the buffer so it is released even
    // when Test_Interactions fails and throws.
    vector<float> expected_probs(times_choose + 2, 1.f);
    expected_probs[times_choose] = .5f;
    expected_probs[times_choose + 1] = .5f;
    this->Test_Interactions(interactions, times_choose + 2, expected_probs.data());
}
TEST_METHOD(Softmax)
{
    // Softmax explorer with lambda = 0 in fixed # action mode; the checks are
    // performed by Softmax_Context.
    float lambda = 0.f;
    int num_actions = 10;
    int scorer_arg = 7;

    TestContext context;
    TestScorer<TestContext> scorer(scorer_arg, num_actions);
    SoftmaxExplorer<TestContext> explorer(scorer, lambda, num_actions);

    this->Softmax_Context(num_actions, context, explorer);
}
TEST_METHOD(Softmax_Var_Context)
{
    // Same as Softmax, but the explorer is built without an action count and
    // the context is constructed with it instead.
    float lambda = 0.f;
    int num_actions = 10;
    int scorer_arg = 7;

    TestVarContext context(num_actions);
    TestScorer<TestVarContext> scorer(scorer_arg, num_actions);
    SoftmaxExplorer<TestVarContext> explorer(scorer, lambda);

    this->Softmax_Context(num_actions, context, explorer);
}
TEST_METHOD(Softmax_Scores)
{
    // With a non-uniform scorer, none of the recorded probabilities should
    // equal the uniform value 1 / num_actions.
    float lambda = 0.5f;
    int num_actions = 10;
    int scorer_arg = 7;

    TestScorer<TestContext> scorer(scorer_arg, num_actions, /* uniform = */ false);
    TestRecorder<TestContext> recorder;
    TestContext context;
    MwtExplorer<TestContext> mwt("salt", recorder);
    SoftmaxExplorer<TestContext> explorer(scorer, lambda, num_actions);

    // Make three decisions with keys 1..3; only the recorded data is inspected
    for (int key = 1; key <= 3; key++)
    {
        mwt.Choose_Action(explorer, this->Get_Unique_Key(key), context);
    }

    vector<TestInteraction<TestContext>> logged = recorder.Get_All_Interactions();
    size_t logged_count = logged.size();
    Assert::AreEqual(3, (int)logged_count);
    for (size_t i = 0; i < logged_count; i++)
    {
        Assert::AreNotEqual(1.f / num_actions, logged[i].Probability);
    }
}
TEST_METHOD(Softmax_Toggle_Exploration)
{
    // Checks that with exploration disabled the highest-scoring action is
    // always chosen with probability 1, and that re-enabling exploration
    // produces non-uniform probabilities again.
    float lambda = 0.5f;
    int num_actions = 10;
    int scorer_arg = 7;

    TestScorer<TestContext> scorer(scorer_arg, num_actions, /* uniform = */ false);
    TestRecorder<TestContext> recorder;
    TestContext context;
    MwtExplorer<TestContext> mwt("salt", recorder);
    SoftmaxExplorer<TestContext> explorer(scorer, lambda, num_actions);

    // Work out which action has the highest score (action ids are one-based)
    vector<float> scores = scorer.Score_Actions(context);
    float best_score = 0.f;
    u32 best_action = 0;
    for (size_t i = 0; i < scores.size(); i++)
    {
        if (scores[i] > best_score)
        {
            best_score = scores[i];
            best_action = (u32)i + 1;
        }
    }

    int times_choose = 10000;

    // Exploration off: the highest score action must be chosen every time
    explorer.Enable_Explore(false);
    int best_action_hits = 0;
    for (int i = 0; i < times_choose; i++)
    {
        if (mwt.Choose_Action(explorer, this->Get_Unique_Key(i), context) == best_action)
        {
            best_action_hits++;
        }
    }
    Assert::AreEqual(times_choose, best_action_hits);

    // Exploration on: three more decisions with keys 1..3
    explorer.Enable_Explore(true);
    for (int key = 1; key <= 3; key++)
    {
        mwt.Choose_Action(explorer, this->Get_Unique_Key(key), context);
    }

    vector<TestInteraction<TestContext>> logged = recorder.Get_All_Interactions();
    size_t logged_count = logged.size();
    Assert::AreEqual(times_choose + 3, (int)logged_count);

    // The non-exploring decisions come first and carry probability 1
    for (size_t i = 0; i < (size_t)times_choose; i++)
    {
        Assert::AreEqual(1.f, logged[i].Probability);
    }
    // The exploring decisions must not carry the uniform probability
    for (size_t i = times_choose; i < logged_count; i++)
    {
        Assert::AreNotEqual(1.f / num_actions, logged[i].Probability);
    }
}
TEST_METHOD(Generic)
{
    // Generic explorer in fixed # action mode; the checks are performed by
    // Generic_Context.
    int scorer_arg = 7;
    int num_actions = 10;

    TestContext context;
    TestScorer<TestContext> scorer(scorer_arg, num_actions);
    GenericExplorer<TestContext> explorer(scorer, num_actions);

    this->Generic_Context(num_actions, context, explorer);
}
TEST_METHOD(Generic_Var_Context)
{
    // Same as Generic, but the explorer is built without an action count and
    // the context is constructed with it instead.
    int scorer_arg = 7;
    int num_actions = 10;

    TestVarContext context(num_actions);
    TestScorer<TestVarContext> scorer(scorer_arg, num_actions);
    GenericExplorer<TestVarContext> explorer(scorer);

    this->Generic_Context(num_actions, context, explorer);
}
TEST_METHOD(End_To_End_Epsilon_Greedy)
{
    // Runs the shared End_To_End scenario with an epsilon-greedy explorer and
    // a string recorder.
    float epsilon = 0.5f;
    int num_actions = 10;
    int policy_params = 101;

    TestSimplePolicy default_policy(policy_params, num_actions);
    StringRecorder<SimpleContext> recorder;
    MwtExplorer<SimpleContext> mwt("salt", recorder);
    EpsilonGreedyExplorer<SimpleContext> explorer(default_policy, epsilon, num_actions);

    this->End_To_End(mwt, explorer, recorder);
}
TEST_METHOD(End_To_End_Tau_First)
{
    // Runs the shared End_To_End scenario with a tau-first explorer and a
    // string recorder.
    u32 tau = 5;
    int num_actions = 10;
    int policy_params = 101;

    TestSimplePolicy default_policy(policy_params, num_actions);
    StringRecorder<SimpleContext> recorder;
    MwtExplorer<SimpleContext> mwt("salt", recorder);
    TauFirstExplorer<SimpleContext> explorer(default_policy, tau, num_actions);

    this->End_To_End(mwt, explorer, recorder);
}
TEST_METHOD(End_To_End_Bootstrap)
{
    // Runs the shared End_To_End scenario with a bootstrap explorer built
    // from two identically configured policies.
    int num_actions = 10;
    u32 bags = 2; // matches the number of policies added below; not read otherwise
    int policy_params = 101;

    StringRecorder<SimpleContext> recorder;
    vector<unique_ptr<IPolicy<SimpleContext>>> bag_policies;
    bag_policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(policy_params, num_actions)));
    bag_policies.push_back(unique_ptr<IPolicy<SimpleContext>>(new TestSimplePolicy(policy_params, num_actions)));
    MwtExplorer<SimpleContext> mwt("salt", recorder);
    BootstrapExplorer<SimpleContext> explorer(bag_policies, num_actions);

    this->End_To_End(mwt, explorer, recorder);
}
TEST_METHOD(End_To_End_Softmax)
{
    // Runs the shared End_To_End scenario with a softmax explorer and a
    // string recorder.
    float lambda = 0.5f;
    int num_actions = 10;
    int scorer_arg = 7;

    TestSimpleScorer scorer(scorer_arg, num_actions);
    StringRecorder<SimpleContext> recorder;
    MwtExplorer<SimpleContext> mwt("salt", recorder);
    SoftmaxExplorer<SimpleContext> explorer(scorer, lambda, num_actions);

    this->End_To_End(mwt, explorer, recorder);
}
TEST_METHOD(End_To_End_Generic)
{
    // Runs the shared End_To_End scenario with a generic explorer and a
    // string recorder.
    int scorer_arg = 7;
    int num_actions = 10;

    TestSimpleScorer scorer(scorer_arg, num_actions);
    StringRecorder<SimpleContext> recorder;
    MwtExplorer<SimpleContext> mwt("salt", recorder);
    GenericExplorer<SimpleContext> explorer(scorer, num_actions);

    this->End_To_End(mwt, explorer, recorder);
}
TEST_METHOD(PRG_Coverage)
{
    // Throws n*log(n) + C*n balls uniformly into n bins using the default
    // seeded generator and checks that no bin stays empty.
    const u32 NUM_ACTIONS_COVER = 100;
    float C = 5.0f;

    // We could use many fewer bits (e.g. u8) per bin since we're throwing uniformly at
    // random, but this is safer in case we change things
    u32 bins[NUM_ACTIONS_COVER] = { 0 };
    u32 num_balls = (u32)(NUM_ACTIONS_COVER * log(NUM_ACTIONS_COVER) + C * NUM_ACTIONS_COVER);

    PRG::prg rg;
    for (u32 ball = 0; ball < num_balls; ball++)
    {
        u32 bin = rg.Uniform_Int(0, NUM_ACTIONS_COVER - 1);
        bins[bin]++;
    }

    // Ensure all actions are covered
    for (u32 bin = 0; bin < NUM_ACTIONS_COVER; bin++)
    {
        Assert::IsTrue(bins[bin] > 0);
    }
}
TEST_METHOD(Serialized_String)
{
    // Records two decisions through a StringRecorder and compares the
    // complete recording with a fixed expected string.
    int num_actions = 10;
    float epsilon = 0.5f;
    int params = 101;
    TestSimplePolicy my_policy(params, num_actions);
    StringRecorder<SimpleContext> my_recorder;
    MwtExplorer<SimpleContext> mwt("c++-test", my_recorder);
    EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
    // First context: one feature with a fractional value below 1
    vector<Feature> features1;
    features1.push_back({ 0.5f, 1 });
    SimpleContext context1(features1);
    u32 expected_action = my_policy.Choose_Action(context1);
    string unique_key1 = "key1";
    u32 chosen_action1 = mwt.Choose_Action(explorer, unique_key1, context1);
    // Second context: a large negative value with a large id, plus a small value
    vector<Feature> features2;
    features2.push_back({ -99999.5f, 123456789 });
    features2.push_back({ 1.5f, 39 });
    SimpleContext context2(features2);
    string unique_key2 = "key2";
    u32 chosen_action2 = mwt.Choose_Action(explorer, unique_key2, context2);
    string actual_log = my_recorder.Get_Recording();
    // Use hard-coded string to be independent of sprintf.
    // Declared const char*: a string literal must not be bound to a
    // non-const char pointer.
    const char* expected_log = "2 key1 0.55000 | 1:.5\n2 key2 0.55000 | 123456789:-99999.5 39:1.5\n";
    Assert::AreEqual(expected_log, actual_log.c_str());
}
TEST_METHOD(Serialized_String_Random)
{
    // For 10000 random single-feature contexts, rebuilds the expected log
    // text independently and checks that it matches the recorded text over
    // the first (recorded length - 1) characters.
    PRG::prg rng;
    int action_count = 10;
    int policy_param = 101;
    TestSimplePolicy fixed_policy(policy_param, action_count);
    // Scratch buffer reused across iterations for the expected text.
    char expected_buffer[100] = { 0 };
    int iteration = 0;
    while (iteration < 10000)
    {
        StringRecorder<SimpleContext> line_recorder;
        MwtExplorer<SimpleContext> mwt_instance("c++-test", line_recorder);
        EpsilonGreedyExplorer<SimpleContext> greedy(fixed_policy, 0.f, action_count);

        Feature random_feature;
        random_feature.Value = (rng.Uniform_Unit_Interval() - 0.5f) * rng.Uniform_Int(0, 100000);
        random_feature.Id = iteration;
        vector<Feature> feature_list;
        feature_list.push_back(random_feature);
        SimpleContext single_context(feature_list);

        u32 action = mwt_instance.Choose_Action(greedy, "", single_context);
        string recorded_line = line_recorder.Get_Recording();

        // Format the value with fixed notation, then drop a leading '0' if
        // the text starts with one.
        ostringstream value_formatter;
        value_formatter << std::fixed << std::setprecision(10) << random_feature.Value;
        string value_text = value_formatter.str();
        if (value_text[0] == '0')
        {
            value_text.erase(0, 1);
        }
        sprintf_s(expected_buffer, "%d %s %.5f | %d:%s",
            action, "", 1.f, iteration, value_text.c_str());

        // Only the leading part of the recording is compared.
        size_t prefix_length = recorded_line.length() - 1;
        string expected_line(expected_buffer);
        int compare_result = expected_line.compare(0, prefix_length, recorded_line, 0, prefix_length);
        Assert::AreEqual(0, compare_result);
        iteration++;
    }
}
TEST_METHOD(Usage_Bad_Arguments)
{
// Attempts to construct each explorer type with invalid arguments and checks
// that all eight attempts were counted in num_ex.
// NOTE(review): COUNT_INVALID is defined outside this chunk; presumably it runs
// the statement and increments num_ex when the constructor rejects its
// arguments - confirm against the macro definition.
int num_ex = 0;
int params = 101;
TestPolicy<TestContext> my_policy(params, 0);
TestScorer<TestContext> my_scorer(params, 0);
// Intentionally left empty: used below as a policy set with no bags.
vector<unique_ptr<IPolicy<TestContext>>> policies;
COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, .5f, 0);) // Invalid # actions, must be > 0
COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, 1.5f, 10);) // Invalid epsilon, must be in [0,1]
COUNT_INVALID(EpsilonGreedyExplorer<TestContext> explorer(my_policy, -.5f, 10);) // Invalid epsilon, must be in [0,1]
COUNT_INVALID(BootstrapExplorer<TestContext> explorer(policies, 0);) // Invalid # actions, must be > 0
COUNT_INVALID(BootstrapExplorer<TestContext> explorer(policies, 1);) // Invalid # bags (policies is empty), must be > 0
COUNT_INVALID(TauFirstExplorer<TestContext> explorer(my_policy, 1, 0);) // Invalid # actions, must be > 0
COUNT_INVALID(SoftmaxExplorer<TestContext> explorer(my_scorer, .5f, 0);) // Invalid # actions, must be > 0
COUNT_INVALID(GenericExplorer<TestContext> explorer(my_scorer, 0);) // Invalid # actions, must be > 0
// One count is expected per COUNT_INVALID statement above.
Assert::AreEqual(8, num_ex);
}
TEST_METHOD(Usage_Bad_Policy)
{
// Runs three explorers (epsilon-greedy, tau-first, bootstrap) configured with
// a single valid action against TestBadPolicy, and checks that all three
// attempts were counted in num_ex.
// NOTE(review): COUNT_BAD_CALL is defined outside this chunk; presumably it
// increments num_ex when the enclosed statements throw - confirm.
int num_ex = 0;
// Default policy returns action outside valid range
COUNT_BAD_CALL
(
TestRecorder<TestContext> recorder;
TestBadPolicy policy;
TestContext context;
MwtExplorer<TestContext> mwt("salt", recorder);
EpsilonGreedyExplorer<TestContext> explorer(policy, 0.f, (u32)1);
u32 expected_action = mwt.Choose_Action(explorer, "1001", context);
)
// Same bad policy, this time through the tau-first explorer with tau = 0.
COUNT_BAD_CALL
(
TestRecorder<TestContext> recorder;
TestBadPolicy policy;
TestContext context;
MwtExplorer<TestContext> mwt("salt", recorder);
TauFirstExplorer<TestContext> explorer(policy, (u32)0, (u32)1);
mwt.Choose_Action(explorer, "test", context);
)
// Same bad policy as the only bag of a bootstrap explorer.
COUNT_BAD_CALL
(
TestRecorder<TestContext> recorder;
TestContext context;
vector<unique_ptr<IPolicy<TestContext>>> policies;
policies.push_back(unique_ptr<IPolicy<TestContext>>(new TestBadPolicy()));
MwtExplorer<TestContext> mwt("salt", recorder);
BootstrapExplorer<TestContext> explorer(policies, (u32)1);
mwt.Choose_Action(explorer, "test", context);
)
// One count is expected per COUNT_BAD_CALL block above.
Assert::AreEqual(3, num_ex);
}
TEST_METHOD(Usage_Bad_Scorer)
{
// Runs the generic explorer against scorers whose scores are unusable and
// checks that both attempts were counted in num_ex.
// NOTE(review): COUNT_BAD_CALL is defined outside this chunk; presumably it
// increments num_ex when the enclosed statements throw - confirm.
// NOTE(review): both blocks pass temporaries (TestRecorder<TestContext>() and
// TestContext()) where other tests in this file pass named objects - confirm
// the receiving parameters do not keep references past the call.
int num_ex = 0;
// Scorer returns a score of -1 for the single action
COUNT_BAD_CALL
(
u32 num_actions = 1;
FixedScorer scorer(num_actions, -1);
MwtExplorer<TestContext> mwt("salt", TestRecorder<TestContext>());
GenericExplorer<TestContext> explorer(scorer, num_actions);
mwt.Choose_Action(explorer, "test", TestContext());
)
// Scorer returns a score of 0 for the single action
COUNT_BAD_CALL
(
u32 num_actions = 1;
FixedScorer scorer(num_actions, 0);
MwtExplorer<TestContext> mwt("salt", TestRecorder<TestContext>());
GenericExplorer<TestContext> explorer(scorer, num_actions);
mwt.Choose_Action(explorer, "test", TestContext());
)
// One count is expected per COUNT_BAD_CALL block above.
Assert::AreEqual(2, num_ex);
}
TEST_METHOD(Custom_Context)
{
    // Sends a SimpleContext with three features through an epsilon-greedy
    // explorer with epsilon = 0, then checks the chosen action and that the
    // recorded context exposes the same features that were passed in.
    int num_actions = 10;
    float epsilon = 0.f; // No randomization
    string unique_key = "1001";
    TestSimplePolicy my_policy(0, num_actions);
    TestSimpleRecorder my_recorder;
    MwtExplorer<SimpleContext> mwt("salt", my_recorder);

    vector<Feature> features;
    features.push_back({ 0.5f, 1 });
    features.push_back({ 1.5f, 6 });
    features.push_back({ -5.3f, 13 });
    SimpleContext custom_context(features);

    EpsilonGreedyExplorer<SimpleContext> explorer(my_policy, epsilon, num_actions);
    u32 chosen_action = mwt.Choose_Action(explorer, unique_key, custom_context);
    // TestSimplePolicy(0, 10) yields 0 % 10 + 1.
    Assert::AreEqual((u32)1, chosen_action);

    vector<TestInteraction<SimpleContext>> interactions = my_recorder.Get_All_Interactions();
    Assert::AreEqual(1, (int)interactions.size());

    vector<Feature>& returned_features = interactions[0].Context.Get_Features();
    // Expected value goes first so a failure message reports the sizes correctly.
    Assert::AreEqual(features.size(), returned_features.size());
    // Index the vectors directly; no element address is taken before the
    // sizes are known to match.
    for (size_t i = 0; i < features.size(); i++)
    {
        Assert::AreEqual(features[i].Id, returned_features[i].Id);
        Assert::AreEqual(features[i].Value, returned_features[i].Value);
    }
}
TEST_METHOD_INITIALIZE(Test_Initialize)
{
// Intentionally empty: no setup is performed here.
}
TEST_METHOD_CLEANUP(Test_Cleanup)
{
// Intentionally empty: no cleanup is performed here.
}
private:
template <class TContext>
void Epsilon_Greedy_Random_Context(int num_actions, TContext& my_context, EpsilonGreedyExplorer<TContext>& explorer, TestPolicy<TContext>& my_policy)
{
TestRecorder<TContext> my_recorder;
MwtExplorer<TContext> mwt("salt", my_recorder);
u32 policy_action = my_policy.Choose_Action(my_context);
int times_choose = 10000;
int times_policy_action_chosen = 0;
for (int i = 0; i < times_choose; i++)
{
u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i), my_context);
if (chosen_action == policy_action)
{
times_policy_action_chosen++;
}
}
Assert::IsTrue(abs((double)times_policy_action_chosen / times_choose - 0.5) < 0.1);
}
template <class TContext>
void Tau_First_Random_Context(int num_actions, TContext& my_context, TauFirstExplorer<TContext>& explorer)
{
TestRecorder<TContext> my_recorder;
MwtExplorer<TContext> mwt("salt", my_recorder);
u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
// Tau expired, did not explore
chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
Assert::AreEqual((u32)10, chosen_action);
// Only 2 interactions logged, 3rd one should not be stored
vector<TestInteraction<TContext>> interactions = my_recorder.Get_All_Interactions();
float expected_probs[2] = { .1f, .1f };
this->Test_Interactions(interactions, 2, expected_probs);
}
template <class TContext>
void Bootstrap_Random_Context(int num_actions, TContext& my_context, BootstrapExplorer<TContext>& explorer)
{
TestRecorder<TContext> my_recorder;
MwtExplorer<TContext> mwt("c++-test", my_recorder);
u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
// Two bags choosing different actions so prob of each is 1/2
vector<TestInteraction<TContext>> interactions = my_recorder.Get_All_Interactions();
float expected_probs[2] = { .5f, .5f };
this->Test_Interactions(interactions, 2, expected_probs);
}
template <class TContext>
void Softmax_Context(int num_actions, TContext& my_context, SoftmaxExplorer<TContext>& explorer)
{
u32 NUM_ACTIONS_COVER = 100;
float C = 5.0f;
TestRecorder<TContext> my_recorder;
MwtExplorer<TContext> mwt("salt", my_recorder);
// Scale C up since we have fewer interactions
u32 num_decisions = (u32)(num_actions * log(num_actions * 1.0) + log(NUM_ACTIONS_COVER * 1.0 / num_actions) * C * num_actions);
// The () following the array should ensure zero-initialization
u32* actions = new u32[num_actions]();
u32 i;
for (i = 0; i < num_decisions; i++)
{
u32 action = mwt.Choose_Action(explorer, this->Get_Unique_Key(i + 1), my_context);
// Action IDs are 1-based
actions[action - 1]++;
}
// Ensure all actions are covered
for (i = 0; i < (u32)num_actions; i++)
{
Assert::IsTrue(actions[i] > 0);
}
float* expected_probs = new float[num_decisions];
for (i = 0; i < num_decisions; i++)
{
// Our default scorer currently assigns equal weight to each action
expected_probs[i] = 1.0f / num_actions;
}
vector<TestInteraction<TContext>> interactions = my_recorder.Get_All_Interactions();
this->Test_Interactions(interactions, num_decisions, expected_probs);
delete actions;
delete expected_probs;
}
template <class TContext>
void Generic_Context(int num_actions, TContext& my_context, GenericExplorer<TContext>& explorer)
{
TestRecorder<TContext> my_recorder;
MwtExplorer<TContext> mwt("salt", my_recorder);
u32 chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(1), my_context);
chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(2), my_context);
chosen_action = mwt.Choose_Action(explorer, this->Get_Unique_Key(3), my_context);
vector<TestInteraction<TContext>> interactions = my_recorder.Get_All_Interactions();
float expected_probs[3] = { .1f, .1f, .1f };
this->Test_Interactions(interactions, 3, expected_probs);
}
// Smoke test shared by the End_To_End_* cases: makes ten decisions over
// contexts of 1000 random features each, then reads the recording back.
// Nothing is asserted; the test passes as long as nothing crashes.
template <class Exp>
void End_To_End(MwtExplorer<SimpleContext>& mwt, Exp& explorer, StringRecorder<SimpleContext>& recorder)
{
    PRG::prg rng;
    float rewards[10];
    for (int round = 0; round < 10; round++)
    {
        // Feature ids run from 1 to 1000, each with a random value.
        vector<Feature> dense_features;
        int feature_id = 1;
        while (feature_id <= 1000)
        {
            dense_features.push_back({ rng.Uniform_Unit_Interval(), feature_id });
            feature_id++;
        }
        SimpleContext round_context(dense_features);
        mwt.Choose_Action(explorer, to_string(round), round_context);
        rewards[round] = rng.Uniform_Unit_Interval();
    }
    recorder.Get_Recording();
}
// Checks that exactly num_interactions_expected interactions were recorded
// and that the i-th one carries probability probs_expected[i].
template <class Ctx>
inline void Test_Interactions(vector<TestInteraction<Ctx>> interactions, int num_interactions_expected, float* probs_expected)
{
    Assert::AreEqual(num_interactions_expected, (int)interactions.size());
    size_t position = 0;
    for (auto it = interactions.begin(); it != interactions.end(); ++it, ++position)
    {
        Assert::AreEqual(probs_expected[position], it->Probability);
    }
}
// Builds a key string from a seed: seeds a generator with it and returns the
// text of the first unit-interval value drawn.
string Get_Unique_Key(u32 seed)
{
    PRG::prg seeded_generator(seed);
    auto first_draw = seeded_generator.Uniform_Unit_Interval();
    std::ostringstream key_text;
    key_text << first_draw;
    return key_text.str();
}
};
}

Просмотреть файл

@ -1,184 +0,0 @@
#pragma once
#include "MWTExplorer.h"
#include "utility.h"
#include <iomanip>
#include <iostream>
#include <sstream>
using namespace MultiWorldTesting;
// Empty context type used as the TContext argument throughout the tests.
class TestContext
{
};
// Test context that reports the number of actions it was constructed with.
class TestVarContext : public TestContext, public IVariableActionContext
{
public:
    TestVarContext(u32 num_actions) : m_num_actions(num_actions) { }

    // Returns the action count supplied at construction.
    u32 Get_Number_Of_Actions()
    {
        return m_num_actions;
    }

private:
    u32 m_num_actions;
};
// One recorded decision, as stored by the test recorders in this file.
template <class Ctx>
struct TestInteraction
{
// Reference to the context handed to Record(); it is not a copy.
// NOTE(review): the referenced context must outlive this struct. Callers that
// pass a temporary context would leave this reference dangling - confirm.
Ctx& Context;
// Action passed to Record().
u32 Action;
// Probability passed to Record().
float Probability;
// Key passed to Record().
string Unique_Key;
};
// Deterministic test policy: always picks (params mod num_actions) + 1,
// whatever the context is.
template <class TContext>
class TestPolicy : public IPolicy<TContext>
{
public:
    TestPolicy(int params, int num_actions)
    {
        m_params = params;
        m_num_actions = num_actions;
    }

    // NOTE(review): the modulo below divides by m_num_actions, so this must
    // not be called on an instance built with num_actions == 0.
    u32 Choose_Action(TContext& context)
    {
        int zero_based_action = m_params % m_num_actions;
        return zero_based_action + 1; // action id is one-based
    }

private:
    int m_params;
    int m_num_actions;
};
// Test scorer producing one score per action. In uniform mode every score
// equals params; otherwise the score for action index i is params + i.
template <class TContext>
class TestScorer : public IScorer<TContext>
{
public:
    TestScorer(int params, int num_actions, bool uniform = true)
    {
        m_params = params;
        m_num_actions = num_actions;
        m_uniform = uniform;
    }

    // Returns m_num_actions scores; the context is ignored.
    vector<float> Score_Actions(TContext& context)
    {
        vector<float> scores;
        int action_index = 0;
        while (action_index < m_num_actions)
        {
            scores.push_back(m_uniform ? (float)m_params : (float)m_params + action_index);
            ++action_index;
        }
        return scores;
    }

private:
    int m_params;
    int m_num_actions;
    bool m_uniform;
};
// Test scorer that gives every action the same fixed score.
class FixedScorer : public IScorer<TestContext>
{
public:
    FixedScorer(int num_actions, int value)
    {
        m_num_actions = num_actions;
        m_value = value;
    }

    // Returns m_num_actions copies of the fixed value; the context is ignored.
    vector<float> Score_Actions(TestContext& context)
    {
        vector<float> scores;
        for (int remaining = m_num_actions; remaining > 0; --remaining)
        {
            scores.push_back((float)m_value);
        }
        return scores;
    }

private:
    int m_num_actions;
    int m_value;
};
// SimpleContext scorer that gives every action the score params.
class TestSimpleScorer : public IScorer<SimpleContext>
{
public:
    TestSimpleScorer(int params, int num_actions)
    {
        m_params = params;
        m_num_actions = num_actions;
    }

    // Returns m_num_actions copies of params; the context is ignored.
    vector<float> Score_Actions(SimpleContext& context)
    {
        vector<float> scores;
        for (int remaining = m_num_actions; remaining > 0; --remaining)
        {
            scores.push_back((float)m_params);
        }
        return scores;
    }

private:
    int m_params;
    int m_num_actions;
};
// Deterministic SimpleContext policy: always picks (params mod num_actions) + 1.
class TestSimplePolicy : public IPolicy<SimpleContext>
{
public:
    TestSimplePolicy(int params, int num_actions)
    {
        m_params = params;
        m_num_actions = num_actions;
    }

    // The context is ignored; the result depends only on the constructor arguments.
    u32 Choose_Action(SimpleContext& context)
    {
        int zero_based_action = m_params % m_num_actions;
        return zero_based_action + 1; // action id is one-based
    }

private:
    int m_params;
    int m_num_actions;
};
// In-memory recorder for SimpleContext interactions.
class TestSimpleRecorder : public IRecorder<SimpleContext>
{
public:
    // Appends one interaction; the context is stored by reference.
    virtual void Record(SimpleContext& context, u32 action, float probability, string unique_key)
    {
        TestInteraction<SimpleContext> entry = { context, action, probability, unique_key };
        m_interactions.push_back(entry);
    }

    // Returns a copy of everything recorded so far.
    vector<TestInteraction<SimpleContext>> Get_All_Interactions()
    {
        return m_interactions;
    }

private:
    vector<TestInteraction<SimpleContext>> m_interactions;
};
// Policy that always answers with action 100, which is outside the valid
// range for the small action counts the tests pair it with.
class TestBadPolicy : public IPolicy<TestContext>
{
public:
    u32 Choose_Action(TestContext& context)
    {
        u32 out_of_range_action = 100;
        return out_of_range_action;
    }
};
// In-memory recorder for interactions over any context type.
template <class TContext>
class TestRecorder : public IRecorder<TContext>
{
public:
    // Appends one interaction; the context is stored by reference.
    virtual void Record(TContext& context, u32 action, float probability, string unique_key)
    {
        TestInteraction<TContext> entry = { context, action, probability, unique_key };
        m_interactions.push_back(entry);
    }

    // Returns a copy of everything recorded so far.
    vector<TestInteraction<TContext>> Get_All_Interactions()
    {
        return m_interactions;
    }

private:
    vector<TestInteraction<TContext>> m_interactions;
};

Просмотреть файл

@ -1,169 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="12.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{5AE3AA40-BEB0-4979-8166-3B885172C430}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>vw_explore_tests</RootNamespace>
<ProjectName>explore_tests</ProjectName>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v120</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
<UseOfMfc>false</UseOfMfc>
</PropertyGroup>
<!-- Boost and zlib include/library locations, defined for x64 builds only.
     The Boost paths are absolute; the zlib paths are relative to this project.
     NOTE(review): no compile or link setting in this project file references
     these four properties - confirm whether they are still needed. -->
<PropertyGroup Condition="'$(Platform)'=='x64'">
<BoostIncludeDir>c:\boost\x64\include\boost-1_56</BoostIncludeDir>
<BoostLibDir>c:\boost\x64\lib</BoostLibDir>
<ZlibIncludeDir>..\..\..\zlib-1.2.8</ZlibIncludeDir>
<ZlibLibDir>$(ZlibIncludeDir)\contrib\vstudio\vc10\x64\ZlibStat$(Configuration)</ZlibLibDir>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" />
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<AdditionalIncludeDirectories>..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<PreprocessorDefinitions>WIN32;_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<UseFullPaths>true</UseFullPaths>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalLibraryDirectories>$(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
</Link>
<PreBuildEvent />
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<AdditionalIncludeDirectories>..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<PreprocessorDefinitions>WIN32;_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<UseFullPaths>true</UseFullPaths>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalLibraryDirectories>$(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
</Link>
<PreBuildEvent />
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<AdditionalIncludeDirectories>..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<PreprocessorDefinitions>WIN32;NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<UseFullPaths>true</UseFullPaths>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalLibraryDirectories>$(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
</Link>
<PreBuildEvent />
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<PrecompiledHeader>NotUsing</PrecompiledHeader>
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<AdditionalIncludeDirectories>..\static;$(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
<PreprocessorDefinitions>WIN32;NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<UseFullPaths>true</UseFullPaths>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalLibraryDirectories>$(VCInstallDir)UnitTest\lib;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
</Link>
<PreBuildEvent />
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="MWTExploreTests.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="MWTExploreTests.cpp" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

Просмотреть файл

@ -1,27 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hpp;hxx;hm;inl;inc;xsd</Extensions>
</Filter>
<Filter Include="Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<ClInclude Include="MWTExploreTests.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="MWTExploreTests.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
</Project>

Просмотреть файл

@ -1,8 +0,0 @@
// stdafx.cpp : source file that includes just the standard includes
// vw_explore_tests.pch will be the pre-compiled header
// stdafx.obj will contain the pre-compiled type information
#include "stdafx.h"
// TODO: reference any additional headers you need in STDAFX.H
// and not in this file

Просмотреть файл

@ -1,16 +0,0 @@
// stdafx.h : include file for standard system include files,
// or project specific include files that are used frequently, but
// are changed infrequently
//
#pragma once
#include "targetver.h"
#include <math.h>
#include <fstream>
// Headers for CppUnitTest
#include "CppUnitTest.h"
// TODO: reference additional headers your program requires here
#define TEST_CPP

Просмотреть файл

@ -1,8 +0,0 @@
#pragma once
// Including SDKDDKVer.h defines the highest available Windows platform.
// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and
// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h.
#include <SDKDDKVer.h>