diff --git a/source/NSubstitute.Elevated.sln b/source/NSubstitute.Elevated.sln index 7e68ca2..357c361 100644 --- a/source/NSubstitute.Elevated.sln +++ b/source/NSubstitute.Elevated.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.27009.1 +VisualStudioVersion = 15.0.27019.1 MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NSubstitute.Elevated", "NSubstitute.Elevated\NSubstitute.Elevated.csproj", "{771C49B1-4768-45FA-97BA-37B56268C534}" EndProject @@ -11,6 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NSubstitute.Elevated.Tests" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{36AEE027-40B5-4ECF-8A1B-4FCAE63C73B3}" ProjectSection(SolutionItems) = preProject + common.targets = common.targets ..\README.md = ..\README.md EndProjectSection EndProject @@ -20,6 +21,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SystemUnderTest", "..\tests EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DependentAssembly", "..\tests\Support\DependentAssembly\DependentAssembly.csproj", "{5F9C587F-9F8A-40E5-87CA-62C55481851C}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Unity.Core", "Unity.Core\Unity.Core.csproj", "{4483F618-9ADB-4A0B-A0D4-37EDB2593F06}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Unity.Core.Tests", "Unity.Core.Tests\Unity.Core.Tests.csproj", "{6B83FD54-CA1D-4E0A-A700-71AAD1992EF1}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -42,6 +47,14 @@ Global {5F9C587F-9F8A-40E5-87CA-62C55481851C}.Debug|Any CPU.Build.0 = Debug|Any CPU {5F9C587F-9F8A-40E5-87CA-62C55481851C}.Release|Any CPU.ActiveCfg = Release|Any CPU {5F9C587F-9F8A-40E5-87CA-62C55481851C}.Release|Any CPU.Build.0 = Release|Any CPU + {4483F618-9ADB-4A0B-A0D4-37EDB2593F06}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4483F618-9ADB-4A0B-A0D4-37EDB2593F06}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4483F618-9ADB-4A0B-A0D4-37EDB2593F06}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4483F618-9ADB-4A0B-A0D4-37EDB2593F06}.Release|Any CPU.Build.0 = Release|Any CPU + {6B83FD54-CA1D-4E0A-A700-71AAD1992EF1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6B83FD54-CA1D-4E0A-A700-71AAD1992EF1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6B83FD54-CA1D-4E0A-A700-71AAD1992EF1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6B83FD54-CA1D-4E0A-A700-71AAD1992EF1}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -51,6 +64,7 @@ Global {F19A2CFF-A4DE-4D84-8220-662EBA16283A} = {0ADFBFB2-69DD-42A9-959C-4B476863594D} {4C7110B5-C596-4AE0-A67F-0AEF0E3D016D} = {F19A2CFF-A4DE-4D84-8220-662EBA16283A} {5F9C587F-9F8A-40E5-87CA-62C55481851C} = {F19A2CFF-A4DE-4D84-8220-662EBA16283A} + {6B83FD54-CA1D-4E0A-A700-71AAD1992EF1} = {0ADFBFB2-69DD-42A9-959C-4B476863594D} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FD843EB6-8549-43F4-A30F-53A79FF71FA7} diff --git a/source/NSubstitute.Elevated/AssemblyInfo.cs b/source/NSubstitute.Elevated/AssemblyInfo.cs new file mode 100644 index 0000000..dc662b0 --- /dev/null +++ b/source/NSubstitute.Elevated/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("NSubstitute.Elevated.Tests")] diff --git a/source/NSubstitute.Elevated/ElevatedSubstituteManager.cs b/source/NSubstitute.Elevated/ElevatedSubstituteManager.cs index 306dfd0..2352088 100644 --- a/source/NSubstitute.Elevated/ElevatedSubstituteManager.cs +++ b/source/NSubstitute.Elevated/ElevatedSubstituteManager.cs @@ -3,11 +3,12 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; using NSubstitute.Core; -using NSubstitute.Elevated.Utilities; +using NSubstitute.Elevated.WeaverInternals; using NSubstitute.Exceptions; using NSubstitute.Proxies; using NSubstitute.Proxies.CastleDynamicProxy; using NSubstitute.Proxies.DelegateProxy; +using Unity.Core; namespace NSubstitute.Elevated { @@ -15,6 +16,7 @@ namespace NSubstitute.Elevated { readonly CallFactory m_CallFactory; readonly IProxyFactory m_DefaultProxyFactory = new ProxyFactory(new DelegateProxyFactory(), new CastleDynamicProxyFactory()); + readonly object[] k_MockedCtorParams = { new MockPlaceholderType() }; public ElevatedSubstituteManager(ISubstitutionContext substitutionContext) { @@ -23,14 +25,17 @@ namespace NSubstitute.Elevated object IProxyFactory.GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[] additionalInterfaces, object[] constructorArguments) { + // TODO: + // * new type MockCtorPlaceholder in elevated assy + // * generate new empty ctor that takes MockCtorPlaceholder in all mocked types + // * support ctor params. throw if foudn and not ForPartsOf. then ForPartsOf determines which ctor we use. + // * have a note about static ctors. because they are special, and do not support disposal, can't really mock them right. + // best for user to do mock/unmock of static ctors manually (i.e. move into StaticInit/StaticDispose and call directly from test code) + object proxy; + var substituteConfig = ElevatedSubstitutionContext.TryGetSubstituteConfig(callRouter); - var shouldForward = typeToProxy.IsInterface; - - // TEMP - shouldForward |= typeToProxy.FullName != "SystemUnderTest.SimpleClass"; - - if (shouldForward) + if (typeToProxy.IsInterface || substituteConfig == null) { proxy = m_DefaultProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments); } @@ -53,18 +58,18 @@ namespace NSubstitute.Elevated if (additionalInterfaces.Any()) throw new SubstituteException("Cannot add interfaces at runtime to patched types"); - // nsubstitute's dynamic proxy works on concrete classes by inheriting via a runtime-generated type, overriding - // virtuals with interceptor behavior. because the base is unmodified, it needs ctor params to be passed in - // for the proxy to pass to base. this in turn likely runs code, and we're not really working with an actual - // mock. - // - // elevated mocking via assembly patching, by contrast, lets us a) insert a new default ctor where missing, and - // b) bypass any existing default ctor code from executing. we end up with a true mock. therefore, it makes no - // sense to ever pass in ctor args, so this case becomes an exception. - if (constructorArguments.Any()) - throw new SubstituteException("Do not pass ctor args when substituting with elevated mocks"); + if (substituteConfig == SubstituteConfig.OverrideAllCalls) + { + // overriding all calls includes the ctor, so it makes no sense for the user to pass in ctor args + if (constructorArguments.Any()) + throw new SubstituteException("Do not pass ctor args when substituting with elevated mocks (or did you mean to use ForPartsOf?)"); - proxy = CreateProxy(typeToProxy, callRouter); + // but we use a ctor arg to select the special empty ctor that we patched in + constructorArguments = k_MockedCtorParams; + } + + proxy = Activator.CreateInstance(typeToProxy, constructorArguments); + GetRouterField(typeToProxy).SetValue(proxy, callRouter); } return proxy; @@ -90,16 +95,6 @@ namespace NSubstitute.Elevated })); } - object CreateProxy(Type typeToProxy, ICallRouter callRouter) - { - var field = GetRouterField(typeToProxy); - - var newInstance = Activator.CreateInstance(typeToProxy); - field.SetValue(newInstance, callRouter); - - return newInstance; - } - // called from patched assembly code via the PatchedAssemblyBridge. return true if the mock is handling the behavior. // false means that the original implementation should run. public bool TryMock(Type actualType, object instance, Type mockedReturnType, out object mockedReturnValue, MethodInfo method, Type[] methodGenericTypes, object[] args) @@ -109,7 +104,7 @@ namespace NSubstitute.Elevated if (callRouter != null) { - bool shouldCallOriginalMethod = false; + var shouldCallOriginalMethod = false; var call = m_CallFactory.Create(method, args, instance, () => shouldCallOriginalMethod = true); mockedReturnValue = callRouter.Route(call); @@ -126,10 +121,10 @@ namespace NSubstitute.Elevated // 2. support for struct instances (only possible to associate call routers with individual structs from the inside) // 3. is a simple way to check that a type has been patched // - FieldInfo GetStaticRouterField(Type type) => m_RouterStaticFieldCache.GetOrAdd(type, t => GetRouterField(t, "__mock__staticData", BindingFlags.Static)); - FieldInfo GetRouterField(Type type) => m_RouterFieldCache.GetOrAdd(type, t => GetRouterField(t, "__mock__data", BindingFlags.Instance)); + FieldInfo GetStaticRouterField(Type type) => m_RouterStaticFieldCache.GetOrAdd(type, t => GetRouterField(t, Weaver.MockInjector.InjectedMockStaticDataName, BindingFlags.Static)); + FieldInfo GetRouterField(Type type) => m_RouterFieldCache.GetOrAdd(type, t => GetRouterField(t, Weaver.MockInjector.InjectedMockDataName, BindingFlags.Instance)); - static FieldInfo GetRouterField(Type type, string fieldName, BindingFlags bindingFlags) + static FieldInfo GetRouterField(IReflect type, string fieldName, BindingFlags bindingFlags) { var field = type.GetField(fieldName, bindingFlags | BindingFlags.NonPublic); if (field == null) diff --git a/source/NSubstitute.Elevated/ElevatedSubstitutionContext.cs b/source/NSubstitute.Elevated/ElevatedSubstitutionContext.cs index 180b3f2..9d9e57a 100644 --- a/source/NSubstitute.Elevated/ElevatedSubstitutionContext.cs +++ b/source/NSubstitute.Elevated/ElevatedSubstitutionContext.cs @@ -3,13 +3,13 @@ using System.Collections.Generic; using JetBrains.Annotations; using NSubstitute.Core; using NSubstitute.Core.Arguments; -using NSubstitute.Elevated.Utilities; using NSubstitute.Exceptions; using NSubstitute.Routing; +using Unity.Core; namespace NSubstitute.Elevated { - // class motivation: + // motivation: // // 1. it's the clean way to hook in our own proxy factory to the nsub machinery // 2. provide access to the sub manager so patched assemblies can route hooked calls through nsub (the so-called 'elevated' mock part) @@ -19,12 +19,13 @@ namespace NSubstitute.Elevated readonly ISubstitutionContext m_Forwarder; readonly ISubstituteFactory m_ElevatedSubstituteFactory; + // ReSharper disable once MemberCanBePrivate.Global public ElevatedSubstitutionContext([NotNull] ISubstitutionContext forwarder) { m_Forwarder = forwarder; ElevatedSubstituteManager = new ElevatedSubstituteManager(this); m_ElevatedSubstituteFactory = new SubstituteFactory(this, - new CallRouterFactory(), ElevatedSubstituteManager, new CallRouterResolver()); + new ElevatedCallRouterFactory(), ElevatedSubstituteManager, new CallRouterResolver()); } public static IDisposable AutoHook() @@ -43,6 +44,23 @@ namespace NSubstitute.Elevated internal ElevatedSubstituteManager ElevatedSubstituteManager { get; } + class ElevatedCallRouterFactory : ICallRouterFactory + { + public ICallRouter Create(ISubstitutionContext substitutionContext, SubstituteConfig config) + => new ElevatedCallRouter(new SubstituteState(substitutionContext, config), substitutionContext, new RouteFactory()); + } + + class ElevatedCallRouter : CallRouter + { + public ElevatedCallRouter(ISubstituteState substituteState, ISubstitutionContext context, IRouteFactory routeFactory) + : base(substituteState, context, routeFactory) => SubstituteConfig = substituteState.SubstituteConfig; + + public SubstituteConfig SubstituteConfig { get; } + } + + internal static SubstituteConfig? TryGetSubstituteConfig(ICallRouter callRouter) + => (callRouter as ElevatedCallRouter)?.SubstituteConfig; + // this is the only one we're overriding for now, so we can hook our own factory in there. ISubstituteFactory ISubstitutionContext.SubstituteFactory => m_ElevatedSubstituteFactory; diff --git a/source/NSubstitute.Elevated/Extensions.cs b/source/NSubstitute.Elevated/Extensions.cs deleted file mode 100644 index 1e1990d..0000000 --- a/source/NSubstitute.Elevated/Extensions.cs +++ /dev/null @@ -1,32 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace NSubstitute.Elevated -{ - public static class Extensions - { - public static TValue GetOrAdd(this IDictionary @this, TKey key, Func createFunc) - { - if (@this.TryGetValue(key, out var found)) - return found; - - found = createFunc(key); - @this.Add(key, found); - return found; - } - - public static object GetDefaultValue(this Type @this) - { - object defaultValue = null; - if (@this.IsValueType && @this != typeof(void)) - defaultValue = Activator.CreateInstance(@this); - return defaultValue; - } - - public static bool NullOrEmpty(this IEnumerable @this) - { - return @this == null || !@this.Any(); - } - } -} diff --git a/source/NSubstitute.Elevated/NSubstitute.Elevated.csproj b/source/NSubstitute.Elevated/NSubstitute.Elevated.csproj index 7190b5e..321912d 100644 --- a/source/NSubstitute.Elevated/NSubstitute.Elevated.csproj +++ b/source/NSubstitute.Elevated/NSubstitute.Elevated.csproj @@ -8,11 +8,18 @@ - 11.0.0 + 11.1.0 2.0.3 + + 0.10.0-beta6 + + + + + diff --git a/source/NSubstitute.Elevated/PatchedAssemblyBridge.cs b/source/NSubstitute.Elevated/PatchedAssemblyBridge.cs index 5f6d949..6d0d379 100644 --- a/source/NSubstitute.Elevated/PatchedAssemblyBridge.cs +++ b/source/NSubstitute.Elevated/PatchedAssemblyBridge.cs @@ -3,11 +3,15 @@ using System.Diagnostics; using System.Reflection; using System.Runtime.CompilerServices; using NSubstitute.Core; +using Unity.Core; // this namespace contains types that must be public in order to be usable from patched assemblies, yet // we do not want used from normal client api namespace NSubstitute.Elevated.WeaverInternals { + // used when generating mocked default ctors + public class MockPlaceholderType {} + // important: keep all non-mscorlib types out of the public surface area of this class, so as to // avoid needing to add more references than NSubstitute.Elevated to the assembly during patching. diff --git a/source/NSubstitute.Elevated/Utilities.cs b/source/NSubstitute.Elevated/Utilities.cs index a14b806..ded80b6 100644 --- a/source/NSubstitute.Elevated/Utilities.cs +++ b/source/NSubstitute.Elevated/Utilities.cs @@ -1,17 +1,8 @@ using System; +using System.Collections.Generic; +using System.Linq; using JetBrains.Annotations; -namespace NSubstitute.Elevated.Utilities +namespace NSubstitute.Elevated { - public class DelegateDisposable : IDisposable - { - readonly Action m_DisposeAction; - - public DelegateDisposable([NotNull] Action disposeAction) => m_DisposeAction = disposeAction; - - public void Dispose() - { - m_DisposeAction(); - } - } } diff --git a/source/NSubstitute.Elevated/Weaver/CecilExtensions.cs b/source/NSubstitute.Elevated/Weaver/CecilExtensions.cs new file mode 100644 index 0000000..671b67e --- /dev/null +++ b/source/NSubstitute.Elevated/Weaver/CecilExtensions.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using JetBrains.Annotations; +using Mono.Cecil; +using Unity.Core; + +namespace NSubstitute.Elevated.Weaver +{ + public enum IncludeNested { No, Yes } + + public static class CecilExtensions + { + [NotNull] + public static IEnumerable SelectTypes([NotNull] this AssemblyDefinition @this, IncludeNested includeNested) + { + var types = @this.Modules.SelectMany(m => m.Types); + if (includeNested == IncludeNested.Yes) + types = types.SelectMany(t => t.NestedTypes.Append(t)); + return types; + } + + public static int InheritanceChainLength([NotNull] this TypeReference @this) + { + if (@this.DeclaringType == null) + return 0; + + var baseType = @this.Resolve().BaseType; + if (baseType == null) + return 1; + + return 1 + InheritanceChainLength(baseType); + } + } +} diff --git a/source/NSubstitute.Elevated/Weaver/ElevatedWeaver.cs b/source/NSubstitute.Elevated/Weaver/ElevatedWeaver.cs new file mode 100644 index 0000000..a221703 --- /dev/null +++ b/source/NSubstitute.Elevated/Weaver/ElevatedWeaver.cs @@ -0,0 +1,127 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Reflection; +using JetBrains.Annotations; +using Mono.Cecil; +using Unity.Core; + +namespace NSubstitute.Elevated.Weaver +{ + public enum PatchTestAssembly { No, Yes } + + public static class ElevatedWeaver + { + const string k_PatchBackupExtension = ".orig"; + + public static string GetPatchBackupPathFor(string path) + => path + k_PatchBackupExtension; + + public static IReadOnlyCollection PatchAllDependentAssemblies( + [NotNull] string testAssemblyPath, + PatchTestAssembly patchTestAssembly = PatchTestAssembly.No) // typically we don't want to patch the test assembly itself, only the systems under test + { + var testAssemblyFolder = Path.GetDirectoryName(testAssemblyPath); + if (testAssemblyFolder.IsNullOrEmpty()) + throw new Exception("Unable to find folder for test assembly"); + testAssemblyFolder = Path.GetFullPath(testAssemblyFolder); + + // scope + { + var thisAssemblyFolder = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); + if (thisAssemblyFolder.IsNullOrEmpty()) + throw new Exception("Can only patch assemblies on disk"); + thisAssemblyFolder = Path.GetFullPath(thisAssemblyFolder); + + // keep things really simple, at least for now + if (string.Compare(testAssemblyFolder, thisAssemblyFolder, StringComparison.OrdinalIgnoreCase) != 0) + throw new Exception("All assemblies must be in the same folder"); + } + + var nsubElevatedPath = Path.Combine(testAssemblyFolder, "NSubstitute.Elevated.dll"); + using (var nsubElevatedAssembly = AssemblyDefinition.ReadAssembly(nsubElevatedPath)) + { + var mockInjector = new MockInjector(nsubElevatedAssembly); + var toProcess = new List { Path.GetFullPath(testAssemblyPath) }; + var patchResults = new Dictionary(StringComparer.OrdinalIgnoreCase); + + for (var toProcessIndex = 0; toProcessIndex < toProcess.Count; ++toProcessIndex) + { + var assemblyToPatchPath = toProcess[toProcessIndex]; + if (patchResults.ContainsKey(assemblyToPatchPath)) + continue; + + if (!Path.IsPathRooted(assemblyToPatchPath)) + throw new Exception($"Unexpected non-rooted assembly path '{assemblyToPatchPath}'"); + + using (var assemblyToPatch = AssemblyDefinition.ReadAssembly(assemblyToPatchPath)) + { + foreach (var referencedAssembly in assemblyToPatch.Modules.SelectMany(m => m.AssemblyReferences)) + { + // only patch dll's we "own", that are in the same folder as the test assembly + var foundPath = Path.Combine(testAssemblyFolder, referencedAssembly.Name + ".dll"); + + if (File.Exists(foundPath)) + toProcess.Add(foundPath); + else if (!patchResults.ContainsKey(referencedAssembly.Name)) + patchResults.Add(referencedAssembly.Name, new PatchResult(referencedAssembly.Name, null, PatchState.IgnoredOutsideAllowedPaths)); + } + + PatchResult patchResult; + + if (toProcessIndex == 0 && patchTestAssembly == PatchTestAssembly.No) + patchResult = new PatchResult(assemblyToPatchPath, null, PatchState.IgnoredTestAssembly); + else if (MockInjector.IsPatched(assemblyToPatch)) + patchResult = new PatchResult(assemblyToPatchPath, null, PatchState.AlreadyPatched); + else + { + mockInjector.Patch(assemblyToPatch); + + // atomic write of file with backup + var tmpPath = assemblyToPatchPath + ".tmp"; + File.Delete(tmpPath); + assemblyToPatch.Write(tmpPath);//$$$$, new WriterParameters { WriteSymbols = true }); // getting exception, haven't looked into it yet + assemblyToPatch.Dispose(); + var originalPath = GetPatchBackupPathFor(assemblyToPatchPath); + File.Replace(tmpPath, assemblyToPatchPath, originalPath); + // $$$ TODO: move pdb file too + + patchResult = new PatchResult(assemblyToPatchPath, originalPath, PatchState.Patched); + } + + patchResults.Add(assemblyToPatchPath, patchResult); + } + } + + return patchResults.Values; + } + } + } + + public enum PatchState + { + GeneralFailure, // something else went wrong + IgnoredTestAssembly, // don't patch the test assembly itself, as we're requiring that to always be separate from the systems under test + IgnoredOutsideAllowedPaths, // don't want to patch things that are not "ours" + //AlreadyPatchedOld, // assy already patched against an older set of tooling TODO: implement + AlreadyPatched, // assy already patched against current tooling + Patched, // assy patched and old one backed up + } + + public struct PatchResult + { + public string Path; + public string OriginalPath; + public PatchState PatchState; + + [DebuggerStepThrough] + public PatchResult(string path, string originalPath, PatchState patchState) + { + Path = path; + OriginalPath = originalPath; + PatchState = patchState; + } + } +} diff --git a/source/NSubstitute.Elevated/Weaver/MockInjector.cs b/source/NSubstitute.Elevated/Weaver/MockInjector.cs new file mode 100644 index 0000000..24060d1 --- /dev/null +++ b/source/NSubstitute.Elevated/Weaver/MockInjector.cs @@ -0,0 +1,192 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Security.Policy; +using Mono.Cecil; +using Mono.Cecil.Cil; +using Mono.Cecil.Rocks; +using NSubstitute.Elevated.WeaverInternals; +using Unity.Core; +using Assembly = System.Reflection.Assembly; +using AssemblyMetadataAttribute = System.Reflection.AssemblyMetadataAttribute; + +namespace NSubstitute.Elevated.Weaver +{ + class MockInjector + { + static readonly string k_MarkAsPatchedKey, k_MarkAsPatchedValue; + + readonly TypeDefinition m_MockPlaceholderType; + readonly MethodDefinition m_PatchedAssemblyBridgeTryMock; + + public const string InjectedMockStaticDataName = "__mock__staticData", InjectedMockDataName = "__mock__data"; + + static MockInjector() + { + k_MarkAsPatchedKey = Assembly.GetExecutingAssembly().GetName().Name; + + var assemblyHash = Assembly.GetExecutingAssembly().Evidence.GetHostEvidence(); + if (assemblyHash == null) + throw new Exception("Assembly not stamped with a hash"); + + k_MarkAsPatchedValue = assemblyHash.SHA1.ToHexString(); + } + + public MockInjector(AssemblyDefinition nsubElevatedAssembly) + { + m_MockPlaceholderType = nsubElevatedAssembly.MainModule + .GetType(typeof(MockPlaceholderType).FullName); + + m_PatchedAssemblyBridgeTryMock = nsubElevatedAssembly.MainModule + .GetType(typeof(PatchedAssemblyBridge).FullName) + .Methods.Single(m => m.Name == nameof(PatchedAssemblyBridge.TryMock)); + } + + public void Patch(AssemblyDefinition assembly) + { + // patch all types + + var typesToProcess = assembly + .SelectTypes(IncludeNested.Yes) + .OrderBy(t => t.InheritanceChainLength()) // process base classes first + .ToList(); // copy to a list in case patch work we do would invalidate the enumerator + + foreach (var type in typesToProcess) + Patch(type); + + // add an attr to mark the assembly as patched + + var mainModule = assembly.MainModule; + var types = mainModule.TypeSystem; + + var metadataAttrName = typeof(AssemblyMetadataAttribute); + var metadataAttrType = new TypeReference(metadataAttrName.Namespace, metadataAttrName.Name, mainModule, types.CoreLibrary); + var metadataAttrCtor = new MethodReference(".ctor", types.Void, metadataAttrType) { HasThis = true }; + metadataAttrCtor.Parameters.Add(new ParameterDefinition(types.String)); + metadataAttrCtor.Parameters.Add(new ParameterDefinition(types.String)); + + var metadataAttr = new CustomAttribute(metadataAttrCtor); + metadataAttr.ConstructorArguments.Add(new CustomAttributeArgument(types.String, k_MarkAsPatchedKey)); + metadataAttr.ConstructorArguments.Add(new CustomAttributeArgument(types.String, k_MarkAsPatchedValue)); + + assembly.CustomAttributes.Add(metadataAttr); + } + + public static bool IsPatched(AssemblyDefinition assembly) + { + return assembly.CustomAttributes.Any(a => + a.AttributeType.FullName == typeof(AssemblyMetadataAttribute).FullName && + a.ConstructorArguments.Count == 2 && + a.ConstructorArguments[0].Value as string == k_MarkAsPatchedKey && + a.ConstructorArguments[1].Value as string == k_MarkAsPatchedValue); + } + + public static bool IsPatched(string assemblyPath) + { + using (var assembly = AssemblyDefinition.ReadAssembly(assemblyPath)) + return IsPatched(assembly); + } + + void Patch(TypeDefinition type) + { + if (type.IsInterface) + return; + if (type.IsNestedPrivate) + return; + if (type.Name == "") + return; + if (type.CustomAttributes.Any(a => + a.AttributeType.FullName == typeof(CompilerGeneratedAttribute).FullName || + a.AttributeType.FullName == typeof(StructLayoutAttribute).FullName)) + return; + + try + { + var patched = false; + foreach (var method in type.Methods) + { + if (Patch(method)) + patched = true; + } + + if (patched) + { + void AddField(string fieldName, FieldAttributes fieldAttributes) + { + type.Fields.Add(new FieldDefinition(fieldName, + FieldAttributes.Private | FieldAttributes.NotSerialized | fieldAttributes, + type.Module.TypeSystem.Object)); + } + + AddField(InjectedMockStaticDataName, FieldAttributes.Static); + AddField(InjectedMockDataName, 0); + + AddMockCtor(type); + } + } + catch (Exception e) + { + throw new Exception($"Internal error during mock injection into type {type.FullName}", e); + } + } + + public static bool IsPatched(TypeDefinition type) + { + var mockStaticField = type.Fields.SingleOrDefault(f => f.Name == InjectedMockStaticDataName); + var mockField = type.Fields.SingleOrDefault(f => f.Name == InjectedMockDataName); + if ((mockStaticField != null) != (mockField != null)) + throw new Exception("Unexpected mismatch between static and instance mock injected fields"); + + return mockStaticField != null; + } + + void AddMockCtor(TypeDefinition type) + { + var ctor = new MethodDefinition(".ctor", + MethodAttributes.Public | MethodAttributes.RTSpecialName | MethodAttributes.SpecialName | MethodAttributes.HideBySig, + type.Module.TypeSystem.Void) + { + IsManaged = true, + DeclaringType = type, + HasThis = true, + }; + ctor.Parameters.Add(new ParameterDefinition(type.Module.ImportReference(m_MockPlaceholderType))); + + var body = ctor.Body; + body.Instructions.Clear(); + + var il = body.GetILProcessor(); + + var baseCtor = (MethodReference)type + .BaseType.Resolve() + .GetConstructors() + .SingleOrDefault(candidate => candidate.Parameters.SequenceEqual(ctor.Parameters)); + if (baseCtor != null) + { + if (type.BaseType.IsGenericInstance) + baseCtor = new MethodReference(baseCtor.Name, baseCtor.ReturnType, type.BaseType) { HasThis = baseCtor.HasThis }; + else if (baseCtor.Module != type.Module) + baseCtor = type.Module.ImportReference(baseCtor); + + il.Append(il.Create(OpCodes.Ldarg_0)); + il.Append(il.Create(OpCodes.Call, baseCtor)); + } + + il.Append(il.Create(OpCodes.Ret)); + + type.Methods.Add(ctor); + } + + bool Patch(MethodDefinition method) + { + if (method.IsCompilerControlled || method.IsConstructor || method.IsAbstract) + return false; + + // $$$ DOWIT + + return true; + } + } +} diff --git a/source/Unity.Core.Tests/DictionaryExtensionsTests.cs b/source/Unity.Core.Tests/DictionaryExtensionsTests.cs new file mode 100644 index 0000000..a8c4faa --- /dev/null +++ b/source/Unity.Core.Tests/DictionaryExtensionsTests.cs @@ -0,0 +1,45 @@ +using System; +using System.Collections.Generic; +using NUnit.Framework; +using Shouldly; + +namespace Unity.Core.Tests +{ + [TestFixture] + public class DictionaryExtensionsTests + { + [Test] + public void OrEmpty_NonNullInput_ReturnsInput() + { + var dictionary = new Dictionary {[0] = "zero" }; + + dictionary.OrEmpty().ShouldBe(dictionary); + } + + [Test] + public void OrEmpty_NullInput_ReturnsEmpty() + { + IReadOnlyDictionary dictionary = null; + + dictionary.OrEmpty().ShouldBeEmpty(); + } + + [Test] + public void GetValueOr_Found_ReturnsFound() + { + var dictionary = new Dictionary {[1] = "one" }; + + dictionary.GetValueOr(1).ShouldBe("one"); + dictionary.GetValueOr(1, "two").ShouldBe("one"); + } + + [Test] + public void GetValueOr_NotFound_ReturnsDefault() + { + var dictionary = new Dictionary {["one"] = 1 }; + + dictionary.GetValueOr("two").ShouldBe(0); + dictionary.GetValueOr("two", 2).ShouldBe(2); + } + } +} diff --git a/source/Unity.Core.Tests/DiffUtilsTests.cs b/source/Unity.Core.Tests/DiffUtilsTests.cs new file mode 100644 index 0000000..d95549e --- /dev/null +++ b/source/Unity.Core.Tests/DiffUtilsTests.cs @@ -0,0 +1,68 @@ +using System.Linq; +using NUnit.Framework; +using Shouldly; +using Unity.Core; + +namespace Unity.Core.Tests +{ + [TestFixture] + public class DiffUtilsTests + { + [Test] + public void IsDiff_ValidLfDiff_ReturnsTrue() + { + var diffText = new[] + { + "--- a/cppupdatr/Refactor/MoveFile.cs", + "+++ b/cppupdatr/Refactor/MoveFile.cs", + "@@ -1,6 +1,7 @@", + }.StringJoin('\n'); + + DiffUtils.IsDiff(diffText).ShouldBeTrue(); + } + + [Test] + public void IsDiff_ValidCrLfDiff_ReturnsTrue() + { + var diffText = new[] + { + "--- a/cppupdatr/Refactor/MoveFile.cs", + "+++ b/cppupdatr/Refactor/MoveFile.cs", + "@@ -1,6 +1,7 @@", + }.StringJoin("\r\n"); + + DiffUtils.IsDiff(diffText).ShouldBeTrue(); + } + + [Test] + public void IsDiff_EmptyDiff_ReturnsFalse() + { + DiffUtils.IsDiff("").ShouldBeFalse(); + } + + [Test] + public void IsDiff_BrokenDiff_ReturnsFalse() + { + var diffText = new[] + { + "--- a/cppupdatr/Refactor/MoveFile.cs", + " +++ b/cppupdatr/Refactor/MoveFile.cs", + "@@ -1,6 +1,7 @@" + }.StringJoin('\n'); + + DiffUtils.IsDiff(diffText).ShouldBeFalse(); + } + + [Test] + public void IsDiff_IncompleteDiff_ReturnsFalse() + { + var diffText = new[] + { + "--- a/cppupdatr/Refactor/MoveFile.cs", + "+++ b/cppupdatr/Refactor/MoveFile.cs", + }.StringJoin('\n'); + + DiffUtils.IsDiff(diffText).ShouldBeFalse(); + } + } +} diff --git a/source/Unity.Core.Tests/EnumerableExtensionsTests.cs b/source/Unity.Core.Tests/EnumerableExtensionsTests.cs new file mode 100644 index 0000000..2733cbc --- /dev/null +++ b/source/Unity.Core.Tests/EnumerableExtensionsTests.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NUnit.Framework; +using Shouldly; + +namespace Unity.Core.Tests +{ + [TestFixture] + public class EnumerableExtensionsTests + { + [Test] + public void WhereNotNull_ItemsWithNulls_ReturnsFilteredForNull() + { + var dummy1 = Enumerable.Empty(); + var dummy2 = new Exception(); + var enumerable = new object[] { null, "abc", dummy1, dummy2, null, null, "ghi" }; + + enumerable.WhereNotNull().ShouldBe(new object[] { "abc", dummy1, dummy2, "ghi" }); + } + + [Test] + public void WhereNotNull_Empty_ReturnsEmpty() + { + var enumerable = Enumerable.Empty(); + + enumerable.WhereNotNull().ShouldBeEmpty(); + } + + [Test] + public void WhereNotNull_AllNulls_ReturnsEmpty() + { + var enumerable = new object[] { null, null, null }; + + enumerable.WhereNotNull().ShouldBeEmpty(); + } + + [Test] + public void OrEmpty_NonNullInput_ReturnsInput() + { + var enumerable = new string[0]; + + enumerable.OrEmpty().ShouldBe(enumerable); + } + + [Test] + public void OrEmpty_NullInput_ReturnsEmpty() + { + IEnumerable enumerable = null; + + enumerable.OrEmpty().ShouldBeEmpty(); + } + + [Test] + public void ToDictionary_Tuples_ReturnsMappedDictionary() + { + var items = new[] { (1, "one"), (2, "two") }; + var dictionary = items.ToDictionary(); + + dictionary[1].ShouldBe("one"); + dictionary[2].ShouldBe("two"); + } + + [Test] + public void ToDictionary_TuplesWithDups_Throws() + { + var items = new[] { (1, "one"), (1, "two") }; + Should.Throw(() => items.ToDictionary()); + } + } +} diff --git a/source/Unity.Core.Tests/ExtensionsTests.cs b/source/Unity.Core.Tests/ExtensionsTests.cs new file mode 100644 index 0000000..4d99a5e --- /dev/null +++ b/source/Unity.Core.Tests/ExtensionsTests.cs @@ -0,0 +1,85 @@ +using System; +using System.Linq; +using NUnit.Framework; +using Shouldly; + +// ReSharper disable AssignNullToNotNullAttribute +// ReSharper disable ExpressionIsAlwaysNull + +namespace Unity.Core.Tests +{ + [TestFixture] + public class RefTypeExtensionsTests + { + [Test] + public void WrapEnumerables_NonNullInput_ReturnsInputWrappedInEnumerable() + { + const string item = "test"; + + var enumerable = item.WrapInEnumerable(); + enumerable.ShouldBe(new[] { item }); + + enumerable = item.WrapInEnumerableOrEmpty(); + enumerable.ShouldBe(new[] { item }); + } + + [Test] + public void WrapEnumerable_NullInput_ReturnsNullWrappedInEnumerable() + { + string item = null; + var enumerable = item.WrapInEnumerable(); + enumerable.ShouldBe(new[] { item }); + } + + [Test] + public void WrapEnumerableOrEmpty_NullInput_ReturnsEmptyEnumerable() + { + string item = null; + var enumerable = item.WrapInEnumerableOrEmpty(); + enumerable.ShouldBe(Enumerable.Empty()); + } + } + + [TestFixture] + public class ComparableExtensionsTests + { + [Test] + public void Clamp_BadRange_ShouldThrow() + { + Should.Throw(() => 1.Clamp (2, 1)); + Should.Throw(() => 'a'.Clamp('z', 'y')); + } + + [Test] + public void Clamp_InBounds_ReturnsValue() + { + 5.Clamp (2, 10).ShouldBe(5); + 3.14.Clamp(3, 6).ShouldBe(3.14); + 'b'.Clamp('a', 'z').ShouldBe('b'); + "abc".Clamp("a", "b").ShouldBe("abc"); + } + + [Test] + public void Clamp_OutOfBounds_ReturnsClampedValue() + { + 15.Clamp (3, 12).ShouldBe(12); + (-5).Clamp(-2, 4).ShouldBe(-2); + + 3.14.Clamp(3.2, 4.3).ShouldBe(3.2); + (-3.24).Clamp(-2.1, 1.5).ShouldBe(-2.1); + + 'b'.Clamp('d', 'z').ShouldBe('d'); + 'f'.Clamp('a', 'c').ShouldBe('c'); + + "abc".Clamp("bde", "cde").ShouldBe("bde"); + "hi".Clamp("abc", "foo").ShouldBe("foo"); + } + + [Test] + public void Clamp_Integer_ReturnsInclusiveClampedValue() + { + 5.Clamp (0, 5).ShouldBe(5); + 5.Clamp (0, 5).ShouldNotBe(4); + } + } +} diff --git a/source/Unity.Core.Tests/StringExtensionsTests.cs b/source/Unity.Core.Tests/StringExtensionsTests.cs new file mode 100644 index 0000000..f413d3d --- /dev/null +++ b/source/Unity.Core.Tests/StringExtensionsTests.cs @@ -0,0 +1,140 @@ +using System; +using NUnit.Framework; +using Shouldly; + +namespace Unity.Core.Tests +{ + [TestFixture] + public class StringExtensionsTests + { + [Test] + public void Left_InBounds_ReturnsSubstring() + { + "".Left(0).ShouldBe(""); + "abc".Left(2).ShouldBe("ab"); + "abc".Left(0).ShouldBe(""); + } + + [Test] + public void Left_OutOfBounds_ClampsProperly() + { + "".Left(10).ShouldBe(""); + "abc".Left(10).ShouldBe("abc"); + } + + [Test] + public void Left_BadInput_Throws() + { + // ReSharper disable once AssignNullToNotNullAttribute + Should.Throw(() => ((string)null).Left(1)); + Should.Throw(() => "abc".Left(-1)); + } + + [Test] + public void Mid_InBounds_ReturnsSubstring() + { + "".Mid(0, 0).ShouldBe(""); + "abc".Mid(0, 3).ShouldBe("abc"); + "abc".Mid(0).ShouldBe("abc"); + "abc".Mid(0, -2).ShouldBe("abc"); + "abc".Mid(1, 1).ShouldBe("b"); + "abc".Mid(3, 0).ShouldBe(""); + "abc".Mid(0, 0).ShouldBe(""); + } + + [Test] + public void Mid_OutOfBounds_ClampsProperly() + { + "".Mid(10, 5).ShouldBe(""); + "abc".Mid(0, 10).ShouldBe("abc"); + "abc".Mid(1, 10).ShouldBe("bc"); + "abc".Mid(10, 5).ShouldBe(""); + } + + [Test] + public void Mid_BadInput_Throws() + { + // ReSharper disable once AssignNullToNotNullAttribute + Should.Throw(() => ((string)null).Mid(1, 2)); + Should.Throw(() => "abc".Mid(-1)); + } + + [Test] + public void Right_InBounds_ReturnsSubstring() + { + "".Right(0).ShouldBe(""); + "abc".Right(2).ShouldBe("bc"); + "abc".Right(0).ShouldBe(""); + } + + [Test] + public void Right_OutOfBounds_ClampsProperly() + { + "".Right(10).ShouldBe(""); + "abc".Right(10).ShouldBe("abc"); + } + + [Test] + public void Right_BadInput_Throws() + { + // ReSharper disable once AssignNullToNotNullAttribute + Should.Throw(() => ((string)null).Right(1)); + Should.Throw(() => "abc".Right(-1)); + } + + [Test] + public void StringJoin_WithEmpty_ReturnsEmptyString() + { + var enumerable = new object[0]; + + enumerable.StringJoin(", ").ShouldBe(""); + enumerable.StringJoin(';').ShouldBe(""); + enumerable.StringJoin(o => o, ", ").ShouldBe(""); + enumerable.StringJoin(o => o, ';').ShouldBe(""); + } + + [Test] + public void StringJoin_WithSingle_ReturnsNoSeparators() + { + var enumerable = new[] { "abc" }; + + enumerable.StringJoin(", ").ShouldBe("abc"); + enumerable.StringJoin(';').ShouldBe("abc"); + enumerable.StringJoin(o => o, ", ").ShouldBe("abc"); + enumerable.StringJoin(o => o, ';').ShouldBe("abc"); + } + + [Test] + public void StringJoin_WithMultiple_ReturnsJoined() + { + var enumerable = new object[] { "abc", 0b111001, -14, 'z' }; + + enumerable.StringJoin(" ==> ").ShouldBe("abc ==> 57 ==> -14 ==> z"); + enumerable.StringJoin('\n').ShouldBe("abc\n57\n-14\nz"); + enumerable.StringJoin(o => o, " <> ").ShouldBe("abc <> 57 <> -14 <> z"); + enumerable.StringJoin(o => o, ';').ShouldBe("abc;57;-14;z"); + } + + [Test] + public void StringJoin_WithSelectorAndSimpleEnumerable_ReturnsSelectedJoined() + { + var enumerable = new[] { "hi", "there", "this", "", "is", "some", "stuff" }; + + int Selector(string value) { return value.Length; } + + enumerable.StringJoin(Selector, ", ").ShouldBe("2, 5, 4, 0, 2, 4, 5"); + enumerable.StringJoin(Selector, ';').ShouldBe("2;5;4;0;2;4;5"); + } + + [Test] + public void StringJoin_WithSelectorAndComplexEnumerable_ReturnsSelectedJoined() + { + var enumerable = new object[] { "abc", 123, null, ("hi", 1.23) }; + + string Selector(object value) { return value?.GetType().Name ?? "(null)"; } + + enumerable.StringJoin(Selector, " ** ").ShouldBe("String ** Int32 ** (null) ** ValueTuple`2"); + enumerable.StringJoin(Selector, '?').ShouldBe("String?Int32?(null)?ValueTuple`2"); + } + } +} diff --git a/source/Unity.Core.Tests/Unity.Core.Tests.csproj b/source/Unity.Core.Tests/Unity.Core.Tests.csproj new file mode 100644 index 0000000..4b20e57 --- /dev/null +++ b/source/Unity.Core.Tests/Unity.Core.Tests.csproj @@ -0,0 +1,25 @@ + + + + + + net461 + + + + + 11.1.0 + + + 3.8.1 + + + 2.8.3 + + + + + + + + diff --git a/source/Unity.Core/ConsoleUtils.cs b/source/Unity.Core/ConsoleUtils.cs new file mode 100644 index 0000000..b78af19 --- /dev/null +++ b/source/Unity.Core/ConsoleUtils.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; + +namespace Unity.Core +{ + public static class Stdin + { + public static IEnumerable SelectLines() + { + for (;;) + { + var line = Console.ReadLine(); + if (line == null) + yield break; + + yield return line; + } + } + } +} diff --git a/source/Unity.Core/DictionaryExtensions.cs b/source/Unity.Core/DictionaryExtensions.cs new file mode 100644 index 0000000..2881b04 --- /dev/null +++ b/source/Unity.Core/DictionaryExtensions.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using JetBrains.Annotations; + +namespace Unity.Core +{ + public class ReadOnlyDictionary + { + class EmptyDictionary : IReadOnlyDictionary + { + public static readonly IReadOnlyDictionary instance = new EmptyDictionary(); + + public IEnumerator> GetEnumerator() => Enumerable.Empty>().GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public int Count => 0; + public bool ContainsKey(TKey key) => false; + public bool TryGetValue(TKey key, out TValue value) { value = default(TValue); return false; } + public TValue this[TKey key] => throw new KeyNotFoundException(); + public IEnumerable Keys => Enumerable.Empty(); + public IEnumerable Values => Enumerable.Empty(); + } + + public static IReadOnlyDictionary Empty() + => EmptyDictionary.instance; + } + + public static class Dictionary + { + [NotNull] + public static Dictionary Create(params(TKey key, TValue value)[] items) + => items.ToDictionary(); + } + + public static class DictionaryExtensions + { + [NotNull] + public static IReadOnlyDictionary OrEmpty([CanBeNull] this IReadOnlyDictionary @this) + => @this ?? ReadOnlyDictionary.Empty(); + + public static TValue GetValueOr([NotNull] this IReadOnlyDictionary @this, TKey key, TValue defaultValue = default(TValue)) + => @this.TryGetValue(key, out var value) ? value : defaultValue; + + public static TValue GetOrAdd([NotNull] this IDictionary @this, TKey key, [NotNull] Func createFunc) + { + if (@this.TryGetValue(key, out var found)) + return found; + + found = createFunc(key); + @this.Add(key, found); + return found; + } + } +} diff --git a/source/Unity.Core/DiffUtils.cs b/source/Unity.Core/DiffUtils.cs new file mode 100644 index 0000000..4a76f48 --- /dev/null +++ b/source/Unity.Core/DiffUtils.cs @@ -0,0 +1,18 @@ +using System.Text.RegularExpressions; + +namespace Unity.Core +{ + public static class DiffUtils + { + public static bool IsDiff(string candidate) + { + const string detectDiffPattern = @"(?mx) + ^ + ---\ [^\n]+\n + \+\+\+\ [^\n]+\n + @@\ "; + + return Regex.IsMatch(candidate, detectDiffPattern); + } + } +} diff --git a/source/Unity.Core/DisposableUtils.cs b/source/Unity.Core/DisposableUtils.cs new file mode 100644 index 0000000..46bac3b --- /dev/null +++ b/source/Unity.Core/DisposableUtils.cs @@ -0,0 +1,13 @@ +using System; +using JetBrains.Annotations; + +namespace Unity.Core +{ + public class DelegateDisposable : IDisposable + { + readonly Action m_DisposeAction; + + public DelegateDisposable([NotNull] Action disposeAction) => m_DisposeAction = disposeAction; + public void Dispose() => m_DisposeAction(); + } +} diff --git a/source/Unity.Core/EnumerableExtensions.cs b/source/Unity.Core/EnumerableExtensions.cs new file mode 100644 index 0000000..51aebbb --- /dev/null +++ b/source/Unity.Core/EnumerableExtensions.cs @@ -0,0 +1,47 @@ +using System.Collections.Generic; +using System.Linq; +using JetBrains.Annotations; +using Enumerable = System.Linq.Enumerable; + +namespace Unity.Core +{ + public static class EnumerableExtensions + { + [NotNull] + public static IEnumerable WhereNotNull([NotNull] this IEnumerable @this) where T : class + => @this.Where(item => !(item is null)); + + [NotNull] + public static IEnumerable OrEmpty([CanBeNull] this IEnumerable @this) + => @this ?? Enumerable.Empty(); + + [NotNull] + public static HashSet ToHashSet([NotNull] this IEnumerable @this, IEqualityComparer comparer) + => new HashSet(@this, comparer); + + [NotNull] + public static HashSet ToHashSet([NotNull] this IEnumerable @this) + => new HashSet(@this); + + [NotNull] + public static Dictionary ToDictionary([NotNull] this IEnumerable<(TKey key, TValue value)> @this) + => @this.ToDictionary(item => item.key, item => item.value); + + public static IEnumerable Append([NotNull] this IEnumerable @this, T value) + { + foreach (var i in @this) + yield return i; + yield return value; + } + + public static IEnumerable Prepend([NotNull] this IEnumerable @this, T value) + { + yield return value; + foreach (var i in @this) + yield return i; + } + + public static bool IsNullOrEmpty([CanBeNull] this IEnumerable @this) + => @this == null || !@this.Any(); + } +} diff --git a/source/Unity.Core/Extensions.cs b/source/Unity.Core/Extensions.cs new file mode 100644 index 0000000..e287543 --- /dev/null +++ b/source/Unity.Core/Extensions.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using JetBrains.Annotations; + +namespace Unity.Core +{ + public static class LegacyExtensions + { + public static IEnumerable FixLegacy([NotNull] this MatchCollection @this) + => @this.Cast(); + } + + public static class ObjectExtensions + { + // fluent operators - note that we're limiting to ref types where needed to avoid accidental boxing + + public static T ToBase(this T @this) => @this; // sometimes you need an inline upcast + + public static T To(this object @this) where T : class => (T)@this; + public static T As(this object @this) where T : class => @this as T; + public static bool Is(this object @this) where T : class => @this is T; + public static bool IsNot(this object @this) where T : class => !(@this is T); + } + + public static class RefTypeExtensions + { + [NotNull] + public static IEnumerable WrapInEnumerable(this T @this) + { yield return @this; } + + [NotNull] + public static IEnumerable WrapInEnumerableOrEmpty([CanBeNull] this T @this) where T : class + => !(@this is null) ? WrapInEnumerable(@this) : Enumerable.Empty(); + } + + public static class TypeExtensions + { + public static object GetDefaultValue([NotNull] this Type @this) + { + object defaultValue = null; + if (@this.IsValueType && @this != typeof(void)) + defaultValue = Activator.CreateInstance(@this); + return defaultValue; + } + } + + public static class ComparableExtensions + { + public static T Clamp([NotNull] this T @this, T min, T max) where T : IComparable + { + if (min.CompareTo(max) > 0) + throw new ArgumentException("'min' cannot be greater than 'max'", nameof(min)); + + if (@this.CompareTo(min) < 0) return min; + if (@this.CompareTo(max) > 0) return max; + return @this; + } + } + + public static class ByteArrayExtensions + { + // if you want to speed this up, see https://stackoverflow.com/q/311165/14582 + public static string ToHexString([NotNull] this byte[] @this) + => BitConverter.ToString(@this).Replace("-", ""); + } +} diff --git a/source/Unity.Core/HasParent.cs b/source/Unity.Core/HasParent.cs new file mode 100644 index 0000000..e61dc66 --- /dev/null +++ b/source/Unity.Core/HasParent.cs @@ -0,0 +1,36 @@ +using System; + +#if WIP + +// the intention here is to provide something similar to ITreeEnumerable, part of an +// ability to write abstract graph algorithms and then easily map onto types with +// existing parent structures. + +namespace Unity.Core +{ + public interface IHasParent + { + T Parent { get; } + } + + public class HasParentDelegate : IHasParent + { + Func m_ParentGetter; + + public HasParentDelegate(T @this, Func parentGetter) + { + This = @this; + m_ParentGetter = parentGetter; + } + + public T Parent => m_ParentGetter(This); + public T This { get; } + } + + public static class HasParentDelegate + { + public static HasParentDelegate Create(T @this, Func parentGetter) + => new HasParentDelegate(@this, parentGetter); + } +} +#endif diff --git a/source/Unity.Core/SafeFile.cs b/source/Unity.Core/SafeFile.cs new file mode 100644 index 0000000..e042a8e --- /dev/null +++ b/source/Unity.Core/SafeFile.cs @@ -0,0 +1,50 @@ +using System; +using System.IO; + +namespace Unity.Core +{ + public static class SafeFile + { + // TODO: add tests (see https://stackoverflow.com/a/1528151/14582) + public static void AtomicWrite(string path, Action write) + { + // note that File.Delete doesn't throw if file doesn't exist + + // dotnet doesn't have an atomic move operation (have to pinvoke to something in the OS to get that, + // and even then on windows it's not guaranteed). so the "atomic" part of this name is just to ensure + // that partially written file never happens. + + var tmpPath = path + ".tmp"; + + try + { + File.Delete(tmpPath); + write(tmpPath); + + // temporarily keep the old file, until we're sure the new file is moved + var bakPath = path + ".bak"; + File.Delete(bakPath); + File.Move(path, bakPath); + + File.Move(tmpPath, path); + + // now the old one can go away + // FUTURE: based on option to func, keep bak file + File.Delete(bakPath); + } + finally + { + try + { + File.Delete(tmpPath); + } + catch + { + // failure to cleanup a tmp file isn't critical + } + } + + // FUTURE: options to throw on existing/auto-overwrite + } + } +} diff --git a/source/Unity.Core/StringExtensions.cs b/source/Unity.Core/StringExtensions.cs new file mode 100644 index 0000000..cfe525f --- /dev/null +++ b/source/Unity.Core/StringExtensions.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using JetBrains.Annotations; + +namespace Unity.Core +{ + public static class StringExtensions + { + [ContractAnnotation("null=>true", true), Pure] + public static bool IsNullOrEmpty([CanBeNull] this string @this) => string.IsNullOrEmpty(@this); + [ContractAnnotation("null=>true", true), Pure] + public static bool IsNullOrWhiteSpace([CanBeNull] this string @this) => string.IsNullOrWhiteSpace(@this); + + public static bool IsEmpty([NotNull] this string @this) => @this.Length == 0; + public static bool Any([NotNull] this string @this) => @this.Length != 0; + + // left/mid/right are 'basic' inspired names, and never throw + + [NotNull] + public static string Left([NotNull] this string @this, int maxChars) + { + return @this.Substring(0, Math.Min(maxChars, @this.Length)); + } + + [NotNull] + public static string Mid([NotNull] this string @this, int offset, int maxChars = -1) + { + if (offset < 0) + throw new ArgumentException("offset must be >= 0", nameof(offset)); + + var safeOffset = offset.Clamp(0, @this.Length); + var actualMaxChars = @this.Length - safeOffset; + + var safeMaxChars = maxChars < 0 ? actualMaxChars : Math.Min(maxChars, actualMaxChars); + + return @this.Substring(safeOffset, safeMaxChars); + } + + [NotNull] + public static string Right([NotNull] this string @this, int maxChars) + { + var safeMaxChars = Math.Min(maxChars, @this.Length); + return @this.Substring(@this.Length - safeMaxChars, safeMaxChars); + } + + [NotNull] + public static string StringJoin([NotNull] this IEnumerable @this, [NotNull] string separator) + => string.Join(separator, @this.Cast()); + + [NotNull] + public static string StringJoin([NotNull] this IEnumerable @this, char separator) + => string.Join(new string(separator, 1), @this.Cast()); + + [NotNull] + public static string StringJoin([NotNull] this IEnumerable @this, [NotNull] Func selector, [NotNull] string separator) + => string.Join(separator, @this.Select(selector)); + + [NotNull] + public static string StringJoin([NotNull] this IEnumerable @this, [NotNull] Func selector, char separator) + => string.Join(new string(separator, 1), @this.Select(selector)); + } +} diff --git a/source/Unity.Core/Unity.Core.csproj b/source/Unity.Core/Unity.Core.csproj new file mode 100644 index 0000000..436fdcb --- /dev/null +++ b/source/Unity.Core/Unity.Core.csproj @@ -0,0 +1,18 @@ + + + + + + net461 + + + + + 11.1.0 + + + 4.4.0 + + + + diff --git a/source/common.targets b/source/common.targets index cfe7b46..398eeb5 100644 --- a/source/common.targets +++ b/source/common.targets @@ -1,5 +1,5 @@ - + - $(DefaultItemExcludes);**\*.bak + $(DefaultItemExcludes);**\*.bak;**\*.orig diff --git a/tests/NSubstitute.Elevated.Tests/ElevatedWeaverTests.cs b/tests/NSubstitute.Elevated.Tests/ElevatedWeaverTests.cs new file mode 100644 index 0000000..27b0f35 --- /dev/null +++ b/tests/NSubstitute.Elevated.Tests/ElevatedWeaverTests.cs @@ -0,0 +1,117 @@ +using System; +using System.Linq; +using NSubstitute.Elevated.Weaver; +using NUnit.Framework; +using Shouldly; + +namespace NSubstitute.Elevated.Tests +{ + + [TestFixture] + public class ElevatedWeaverTests + { + TestAssembly m_FixtureTestAssembly; + + const string k_FixtureTestCode = @" + + using System; + using System.Collections.Generic; + using System.Runtime.InteropServices; + + namespace ShouldNotPatch + { + interface Interface { void Foo(); } // ordinary proxying works + + [StructLayout(LayoutKind.Explicit)] + struct StructWithLayoutAttr { } // don't want to risk breaking things by changing size + + class ClassWithPrivateNestedType + { + class PrivateNested { } // unavailable externally, no point + } + + class ClassWithGeneratedNestedType + { + public IEnumerable Foo() // this causes a state machine type to be generated which shouldn't be patched + { yield return 1; } + } + + } + + namespace ShouldPatch + { + class ClassWithNestedTypes + { + public class PublicNested { } + internal class InternalNested { } + } + } + + "; + + [OneTimeSetUp] + public void OneTimeSetUp() + { + m_FixtureTestAssembly = new TestAssembly(nameof(ElevatedWeaverTests), k_FixtureTestCode); + } + + [OneTimeTearDown] + public void OneTimeTearDown() + { + m_FixtureTestAssembly?.Dispose(); + } + + [Test] + public void Interfaces_ShouldNotPatch() + { + var type = m_FixtureTestAssembly.GetType("ShouldNotPatch.Interface"); + MockInjector.IsPatched(type).ShouldBeFalse(); + } + + [Test] + public void PotentiallyBlittableStructs_ShouldNotPatch() + { + var type = m_FixtureTestAssembly.GetType("ShouldNotPatch.StructWithLayoutAttr"); + MockInjector.IsPatched(type).ShouldBeFalse(); + } + + [Test] + public void PrivateNestedTypes_ShouldNotPatch() + { + var type = m_FixtureTestAssembly.GetType("ShouldNotPatch.ClassWithPrivateNestedType"); + var nestedType = type.NestedTypes.Single(t => t.Name == "PrivateNested"); + MockInjector.IsPatched(nestedType).ShouldBeFalse(); + } + + [Test] + public void GeneratedTypes_ShouldNotPatch() + { + var type = m_FixtureTestAssembly.GetType("ShouldNotPatch.ClassWithGeneratedNestedType"); + type.NestedTypes.Count.ShouldBe(1); // this is the yield state machine, will be mangled name + MockInjector.IsPatched(type.NestedTypes[0]).ShouldBeFalse(); + } + + [Test] + public void TopLevelClass_ShouldPatch() + { + var type = m_FixtureTestAssembly.GetType("ShouldPatch.ClassWithNestedTypes"); + MockInjector.IsPatched(type).ShouldBeTrue(); + } + + [Test] + public void PublicNestedClasses_ShouldPatch() + { + var type = m_FixtureTestAssembly.GetType("ShouldPatch.ClassWithNestedTypes"); + var nestedType = type.NestedTypes.Single(t => t.Name == "PublicNested"); + MockInjector.IsPatched(nestedType).ShouldBeTrue(); + } + + [Test] + public void InternalNestedClasses_ShouldPatch() + { + var type = m_FixtureTestAssembly.GetType("ShouldPatch.ClassWithNestedTypes"); + var nestedType = type.NestedTypes.Single(t => t.Name == "InternalNested"); + MockInjector.IsPatched(nestedType).ShouldBeTrue(); + } + } +} diff --git a/tests/NSubstitute.Elevated.Tests/MockWeaverTestUtils.cs b/tests/NSubstitute.Elevated.Tests/MockWeaverTestUtils.cs deleted file mode 100644 index 868b1cd..0000000 --- a/tests/NSubstitute.Elevated.Tests/MockWeaverTestUtils.cs +++ /dev/null @@ -1,8 +0,0 @@ -using System; - -namespace NSubstitute.Elevated.Tests -{ - public static class MockWeaverTestUtils - { - } -} diff --git a/tests/NSubstitute.Elevated.Tests/MockWeaverTests.cs b/tests/NSubstitute.Elevated.Tests/MockWeaverTests.cs deleted file mode 100644 index 7945801..0000000 --- a/tests/NSubstitute.Elevated.Tests/MockWeaverTests.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System; -using NUnit.Framework; - -#pragma warning disable 169 -// ReSharper disable InconsistentNaming - -namespace NSubstitute.Elevated.Tests -{ - [TestFixture] - public class MockWeaverTests - { - [NonSerialized] - object __mockContext; - [NonSerialized] - static object __mockStaticContext; - - [Test] - public void NonParamStaticMethod() - { - } - } -} diff --git a/tests/NSubstitute.Elevated.Tests/NSubstitute.Elevated.Tests.csproj b/tests/NSubstitute.Elevated.Tests/NSubstitute.Elevated.Tests.csproj index d67d0ab..7f32e17 100644 --- a/tests/NSubstitute.Elevated.Tests/NSubstitute.Elevated.Tests.csproj +++ b/tests/NSubstitute.Elevated.Tests/NSubstitute.Elevated.Tests.csproj @@ -8,19 +8,16 @@ - 11.0.0 - - - 0.9.6.4 + 11.1.0 2.0.3 - 3.5.0 + 3.8.1 - 2.8.2 + 2.8.3 diff --git a/tests/NSubstitute.Elevated.Tests/TestAssembly.cs b/tests/NSubstitute.Elevated.Tests/TestAssembly.cs new file mode 100644 index 0000000..69cb6f5 --- /dev/null +++ b/tests/NSubstitute.Elevated.Tests/TestAssembly.cs @@ -0,0 +1,67 @@ +using System; +using System.CodeDom.Compiler; +using System.IO; +using System.Linq; +using System.Reflection; +using Mono.Cecil; +using NSubstitute.Elevated.Weaver; +using Shouldly; +using Unity.Core; + +namespace NSubstitute.Elevated.Tests +{ + public class TestAssembly : IDisposable + { + string m_TestAssemblyPath; + AssemblyDefinition m_TestAssembly; + + public TestAssembly(string assemblyName, string testSourceCodeFile) + { + var outputFolder = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); + outputFolder.ShouldNotBeNull(); + + m_TestAssemblyPath = Path.Combine(outputFolder, assemblyName + ".dll"); + + var compiler = new Microsoft.CSharp.CSharpCodeProvider(); + var compilerArgs = new CompilerParameters + { + OutputAssembly = m_TestAssemblyPath, + IncludeDebugInformation = true, + CompilerOptions = "/o- /debug+ /warn:0" + }; + compilerArgs.ReferencedAssemblies.Add(typeof(Enumerable).Assembly.Location); + + var compilerResult = compiler.CompileAssemblyFromSource(compilerArgs, testSourceCodeFile); + if (compilerResult.Errors.Count > 0) + { + var errorText = compilerResult.Errors + .OfType() + .Select(e => $"({e.Line},{e.Column}): error {e.ErrorNumber}: {e.ErrorText}") + .Prepend("Compiler errors:") + .StringJoin("\n"); + throw new Exception(errorText); + } + + m_TestAssemblyPath = compilerResult.PathToAssembly; + + var results = ElevatedWeaver.PatchAllDependentAssemblies(m_TestAssemblyPath, PatchTestAssembly.Yes); + results.Count.ShouldBe(2); + results.ShouldContain(new PatchResult("mscorlib", null, PatchState.IgnoredOutsideAllowedPaths)); + results.ShouldContain(new PatchResult(m_TestAssemblyPath, ElevatedWeaver.GetPatchBackupPathFor(m_TestAssemblyPath), PatchState.Patched)); + + m_TestAssembly = AssemblyDefinition.ReadAssembly(m_TestAssemblyPath); + MockInjector.IsPatched(m_TestAssembly).ShouldBeTrue(); + } + + public void Dispose() + { + m_TestAssembly.Dispose(); + + var dir = new DirectoryInfo(Path.GetDirectoryName(m_TestAssemblyPath)); + foreach (var file in dir.EnumerateFiles(Path.GetFileNameWithoutExtension(m_TestAssemblyPath) + ".*")) + File.Delete(file.FullName); + } + + public TypeDefinition GetType(string typeName) => m_TestAssembly.MainModule.GetType(typeName); + } +}