diff --git a/Source/Tx.Core/Demultiplexor.cs b/Source/Tx.Core/Demultiplexor.cs index 23d002a..cea95b5 100644 --- a/Source/Tx.Core/Demultiplexor.cs +++ b/Source/Tx.Core/Demultiplexor.cs @@ -16,6 +16,8 @@ namespace System.Reactive { private readonly Dictionary> _outputs = new Dictionary>(); + private readonly Dictionary> _knownOutputMappings = new Dictionary>(); + public void OnCompleted() { foreach (var output in _outputs.Values.ToArray()) @@ -32,17 +34,22 @@ namespace System.Reactive } } - public void OnNext(object value) + public void OnNext(object inputObject) { - IObserver output; - if (_outputs.TryGetValue(value.GetType(), out output)) + var inputObjectType = inputObject.GetType(); + + if (!_knownOutputMappings.ContainsKey(inputObjectType)) { - output.OnNext(value); + _knownOutputMappings.Add(inputObjectType, new List()); + foreach (var type in GetTypes(inputObjectType).Where(type => _outputs.ContainsKey(type))) + { + _knownOutputMappings[inputObjectType].Add(type); + } } - if (_outputs.TryGetValue(value.GetType().BaseType, out output)) + foreach (var keyType in _knownOutputMappings[inputObjectType]) { - output.OnNext(value); + _outputs[keyType].OnNext(inputObject); } } @@ -54,16 +61,41 @@ namespace System.Reactive public IObservable GetObservable() { IObserver o; - if (!_outputs.TryGetValue(typeof (TOutput), out o)) + if (!_outputs.TryGetValue(typeof(TOutput), out o)) { o = new OutputSubject(); - _outputs.Add(typeof (TOutput), o); + _outputs.Add(typeof(TOutput), o); + RefreshKnownOutputMappings(typeof(TOutput)); } - var output = (IObservable) o; + var output = (IObservable)o; return output; } + private List GetTypes(Type inputType) + { + var typeList = new List(); + var temp = inputType; + while (temp.FullName != typeof(object).FullName) + { + typeList.Add(temp); + temp = temp.BaseType; + } + typeList.AddRange(inputType.GetInterfaces()); + return typeList; + } + + private void RefreshKnownOutputMappings(Type outputType) + { + foreach (var knownMappings in _knownOutputMappings) + { + if (GetTypes(knownMappings.Key).Contains(outputType) && !knownMappings.Value.Contains(outputType)) + { + knownMappings.Value.Add(outputType); + } + } + } + private class OutputSubject : ISubject, IDisposable { private readonly Subject _subject; @@ -95,7 +127,7 @@ namespace System.Reactive public void OnNext(object value) { - _subject.OnNext((T) value); + _subject.OnNext((T)value); } public IDisposable Subscribe(IObserver observer) diff --git a/Test/UnitTests/DemultiplexorTest.cs b/Test/UnitTests/DemultiplexorTest.cs index 9d14274..36c10b0 100644 --- a/Test/UnitTests/DemultiplexorTest.cs +++ b/Test/UnitTests/DemultiplexorTest.cs @@ -29,6 +29,82 @@ namespace Tests.Tx Assert.AreEqual(2, stringObserver.Count); } + [TestMethod] + public void GetObservableInheritenceTest1() + { + var itemA = new TestClassA(); + var itemB = new TestClassB(); + var itemC = new TestClassC(); + var interfaceObserver = new CountObserver(); + var itemAObserver = new CountObserver(); + var itemBObserver = new CountObserver(); + var itemCObserver = new CountObserver(); + + var demux = new Demultiplexor(); + + demux.GetObservable().Subscribe(interfaceObserver); + demux.GetObservable().Subscribe(itemAObserver); + demux.GetObservable().Subscribe(itemBObserver); + demux.GetObservable().Subscribe(itemCObserver); + + demux.OnNext(itemA); + demux.OnNext(itemB); + demux.OnNext(itemC); + + Assert.AreEqual(3, interfaceObserver.Count); + Assert.AreEqual(3, itemAObserver.Count); + Assert.AreEqual(2, itemBObserver.Count); + Assert.AreEqual(1, itemCObserver.Count); + } + + [TestMethod] + public void TestLateGetObservableRefreshesCache() + { + var itemA = new TestClassA(); + var itemB = new TestClassB(); + var itemC = new TestClassC(); + var interfaceObserver = new CountObserver(); + var itemAObserver = new CountObserver(); + var itemBObserver = new CountObserver(); + var itemCObserver = new CountObserver(); + + var demux = new Demultiplexor(); + + demux.GetObservable().Subscribe(itemAObserver); + demux.GetObservable().Subscribe(itemBObserver); + demux.GetObservable().Subscribe(itemCObserver); + + demux.OnNext(itemA); + demux.GetObservable().Subscribe(interfaceObserver); + demux.OnNext(itemB); + demux.OnNext(itemC); + + Assert.AreEqual(2, interfaceObserver.Count); + Assert.AreEqual(3, itemAObserver.Count); + Assert.AreEqual(2, itemBObserver.Count); + Assert.AreEqual(1, itemCObserver.Count); + } + + interface ITestClassA + { + int ValueA { get; set; } + } + + class TestClassA : ITestClassA + { + public int ValueA { get; set; } + } + + class TestClassB : TestClassA + { + public string ValueB { get; set; } + } + + class TestClassC : TestClassB + { + public double ValueC { get; set; } + } + class CountObserver : IObserver { int _count;