diff --git a/test/EndToEndOnlineTrainerTest.cs b/test/EndToEndOnlineTrainerTest.cs
new file mode 100644
index 00000000..a8cea9f3
--- /dev/null
+++ b/test/EndToEndOnlineTrainerTest.cs
@@ -0,0 +1,231 @@
+using Microsoft.Research.MultiWorldTesting.ClientLibrary;
+using Microsoft.Research.MultiWorldTesting.Contract;
+using Microsoft.Research.MultiWorldTesting.ExploreLibrary;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Net;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Research.DecisionServiceTest
+{
+    /// <summary>
+    /// End-to-end test: sends simulated food-ordering traffic to the decision service and
+    /// checks how often the top-ranked action is the expected one for each location.
+    /// </summary>
+    [TestClass]
+    public class EndToEndTest : ProvisioningBaseTest
+    {
+        private const string HealthyTown = "HealthyTown";
+        private const string LessHealthyTown = "LessHealthyTown";
+
+        // Actions are 1-based: 1 = "Hamburger deal 1", 2 = "Hamburger deal 2" (better), 3 = "Salad deal".
+        private const int NumActions = 3;
+
+        private const int IterationsPerLocation = 10000;
+
+        // Reward reported when the simulated customer buys the offered deal.
+        private const float PurchaseReward = 10f;
+
+        // Fixed seed: string.GetHashCode() is not guaranteed to be stable across runtimes,
+        // so seeding from it does not give a reproducible simulation.
+        private const int RandomSeed = 20160229;
+
+        public EndToEndTest()
+        {
+            // NOTE(review): presumably keeps the provisioned deployment after the run - confirm in ProvisioningBaseTest.
+            this.deleteOnCleanup = false;
+        }
+
+        /// <summary>
+        /// Runs the simulation twice: the first pass is expected to score below the initial
+        /// exploration epsilon, the second pass above 80% correct.
+        /// </summary>
+        [TestMethod]
+        public void E2ERankerStochasticRewards()
+        {
+            // Create configuration for the decision service
+            float initialEpsilon = .5f;
+
+            this.ConfigureDecisionService(trainArguments: "--cb_explore_adf --cb_type dr -q :: --epsilon 0.2", initialExplorationEpsilon: initialEpsilon);
+
+            string settingsBlobUri = this.settingsUrl;
+
+            float percentCorrect = UploadFoodContextData(settingsBlobUri, firstPass: true);
+            Assert.IsTrue(percentCorrect < initialEpsilon, "First pass: fraction correct was {0}, expected below {1}.", percentCorrect, initialEpsilon);
+
+            percentCorrect = UploadFoodContextData(settingsBlobUri, firstPass: false);
+            Assert.IsTrue(percentCorrect > .8f, "Second pass: fraction correct was {0}, expected above 0.8.", percentCorrect);
+        }
+
+        /// <summary>
+        /// Sends simulated decisions and stochastic rewards to the decision service.
+        /// </summary>
+        /// <param name="settingsBlobUri">Settings URL used to configure the decision service client.</param>
+        /// <param name="firstPass">True to reset the online trainer first; false to wait before sending traffic.</param>
+        /// <returns>Fraction of decisions whose top-ranked action was the expected one.</returns>
+        private float UploadFoodContextData(string settingsBlobUri, bool firstPass)
+        {
+            var serviceConfig = new DecisionServiceConfiguration(settingsBlobUri);
+
+            if (firstPass)
+            {
+                // NOTE(review): TimeSpan.MinValue looks like the "no model polling" setting - confirm against the client library.
+                serviceConfig.PollingForModelPeriod = TimeSpan.MinValue;
+                this.OnlineTrainerReset();
+            }
+
+            using (var service = DecisionService.Create<FoodContext>(serviceConfig))
+            {
+                if (!firstPass)
+                {
+                    // NOTE(review): presumably gives the client time to download the trained model - confirm.
+                    Thread.Sleep(TimeSpan.FromSeconds(10));
+                }
+
+                const string uniqueKey = "scratch-key-gal";
+                string[] locations = { HealthyTown, LessHealthyTown };
+
+                var rg = new Random(RandomSeed);
+
+                int counterCorrect = 0;
+                int counterTotal = 0;
+
+                for (int i = 0; i < IterationsPerLocation * locations.Length; i++)
+                {
+                    // randomly select a location
+                    string location = locations[rg.Next(0, locations.Length)];
+                    string key = uniqueKey + Guid.NewGuid().ToString();
+
+                    var currentContext = new FoodContext
+                    {
+                        UserLocation = location,
+                        Actions = Enumerable.Range(1, NumActions).ToArray()
+                    };
+
+                    int topAction = service.ChooseRanking(key, currentContext)[0];
+
+                    counterTotal += 1;
+                    if (topAction == GetExpectedAction(location))
+                    {
+                        counterCorrect += 1;
+                    }
+
+                    // The customer buys with a probability that depends on the location and the offered deal.
+                    float reward = rg.NextDouble() < GetPurchaseProbability(location, topAction) ? PurchaseReward : 0f;
+                    service.ReportReward(reward, key);
+
+                    Thread.Sleep(1);
+                }
+
+                return (float)counterCorrect / counterTotal;
+            }
+        }
+
+        /// <summary>
+        /// Action we expect the trained policy to rank first: salad (3) for healthy town,
+        /// the second burger (2) otherwise.
+        /// </summary>
+        private static int GetExpectedAction(string location)
+        {
+            return location.Equals(HealthyTown) ? 3 : 2;
+        }
+
+        /// <summary>
+        /// Probability that a customer at <paramref name="location"/> buys the deal identified by
+        /// <paramref name="action"/>; 0 for an unknown action.
+        /// </summary>
+        private static double GetPurchaseProbability(string location, int action)
+        {
+            if (location.Equals(HealthyTown))
+            {
+                // for healthy town, buy burger 1 with probability 0.1, burger 2 with probability 0.15, salad with probability 0.6
+                switch (action)
+                {
+                    case 1: return 0.1;
+                    case 2: return 0.15;
+                    case 3: return 0.6;
+                    default: return 0;
+                }
+            }
+
+            // for unhealthy town, buy burger 1 with probability 0.4, burger 2 with probability 0.6, salad with probability 0.2
+            switch (action)
+            {
+                case 1: return 0.4;
+                case 2: return 0.6;
+                case 3: return 0.2;
+                default: return 0;
+            }
+        }
+    }
+
+    /// <summary>
+    /// Context passed to the decision service: the user's location plus the candidate actions.
+    /// </summary>
+    public class FoodContext
+    {
+        public string UserLocation { get; set; }
+
+        // Excluded from JSON; exposed to the serializer through ActionDependentFeatures instead.
+        [JsonIgnore]
+        public int[] Actions { get; set; }
+
+        /// <summary>
+        /// One feature object per action, serialized under the "_multi" property.
+        /// </summary>
+        [JsonProperty(PropertyName = "_multi")]
+        public FoodFeature[] ActionDependentFeatures
+        {
+            get
+            {
+                return this.Actions
+                    .Select((a, i) => new FoodFeature(this.Actions.Length, i))
+                    .ToArray();
+            }
+        }
+
+        public static IReadOnlyCollection<FoodFeature> GetFeaturesFromContext(FoodContext context)
+        {
+            return context.ActionDependentFeatures;
+        }
+    }
+
+    /// <summary>
+    /// Feature vector for a single action: all zeros except position index, which holds index + 1.
+    /// </summary>
+    public class FoodFeature
+    {
+        public float[] Scores { get; set; }
+
+        internal FoodFeature(int numActions, int index)
+        {
+            Scores = new float[numActions];
+            Scores[index] = index + 1;
+        }
+    }
+
+    /// <summary>
+    /// Recorder that stores the Probability of the epsilon-greedy explorer state for each unique key.
+    /// </summary>
+    class FoodRecorder : IRecorder<FoodContext, int[]>
+    {
+        private readonly Dictionary<string, float> keyToProb = new Dictionary<string, float>();
+
+        public float GetProb(string key)
+        {
+            return keyToProb[key];
+        }
+
+        public void Record(FoodContext context, int[] value, object explorerState, object mapperState, string uniqueKey)
+        {
+            keyToProb.Add(uniqueKey, ((EpsilonGreedyState)explorerState).Probability);
+        }
+    }
+}
+
diff --git a/test/ProvisioningBaseTest.cs b/test/ProvisioningBaseTest.cs
index 7d4aa248..3b9779c0 100644
--- a/test/ProvisioningBaseTest.cs
+++ b/test/ProvisioningBaseTest.cs
@@ -16,14 +16,15 @@ using System.Net;
 using System.Text;
 using System.Web;
 using Microsoft.Research.MultiWorldTesting.Contract;
+using System.Threading;
 
-namespace Microsoft.Research.DecisionService.Test
+namespace Microsoft.Research.DecisionServiceTest
 {
     public class ProvisioningBaseTest
     {
-        private bool deleteOnCleanup;
         private JObject deploymentOutput;
 
+        protected bool deleteOnCleanup;
         protected string managementCenterUrl;
         protected string managementPassword;
         protected string onlineTrainerUrl;
@@ -91,6 +92,7 @@ namespace Microsoft.Research.DecisionService.Test
             {
                 wc.Headers.Add($"Authorization: {onlineTrainerToken}");
                 wc.DownloadString($"{onlineTrainerUrl}/reset");
+                Thread.Sleep(TimeSpan.FromSeconds(3));
             }
         }
 
diff --git a/test/SimplePolicyTest.cs b/test/SimplePolicyTest.cs
index 32eef260..112d808e 100644
--- a/test/SimplePolicyTest.cs
+++ b/test/SimplePolicyTest.cs
@@ -13,7 +13,7 @@ using System.Threading;
 using System.Threading.Tasks;
 using System.Web;
 
-namespace Microsoft.Research.DecisionService.Test
+namespace Microsoft.Research.DecisionServiceTest
 {
     [TestClass]
     public class SimplePolicyTestClass : ProvisioningBaseTest
diff --git a/test/ds-provisioning.csproj b/test/ds-provisioning.csproj
index 5c24209f..25372bf5 100644
--- a/test/ds-provisioning.csproj
+++ b/test/ds-provisioning.csproj
@@ -6,8 +6,8 @@
 {26167826-AF8E-47D8-95BC-CA6B6CA9808F}
 Library
 Properties
-Microsoft.Research.DecisionService.Test
-Microsoft.Research.DecisionService.Test
+Microsoft.Research.DecisionServiceTest
+Microsoft.Research.DecisionServiceTest
 v4.5.2
 512
 {3AC096D0-A1C2-E12C-1390-A8335801FDAB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}
@@ -231,6 +231,7 @@
+