Merged PR 2644: add a visitor pattern for code rewriting transforms.

add a visitor pattern for code rewriting transforms.

Related work items: #4442, #4496, #4644
This commit is contained in:
Chris Lovett 2020-07-09 00:57:41 +00:00
Родитель 4505028803
Коммит e9a8a22771
9 изменённых файлов: 1211 добавлений и 558 удалений

Просмотреть файл

@ -24,6 +24,11 @@ namespace Microsoft.Coyote.Tasks
/// </summary>
protected readonly object SyncObject;
/// <summary>
/// Whether the lock was taken.
/// </summary>
private bool LockTaken;
/// <summary>
/// Initializes a new instance of the <see cref="SynchronizedBlock"/> class.
/// </summary>
@ -47,7 +52,7 @@ namespace Microsoft.Coyote.Tasks
/// <returns>The synchronized block.</returns>
protected virtual SynchronizedBlock EnterLock()
{
SystemMonitor.Enter(this.SyncObject);
SystemMonitor.Enter(this.SyncObject, ref this.LockTaken);
return this;
}
@ -97,7 +102,7 @@ namespace Microsoft.Coyote.Tasks
/// </summary>
protected virtual void Dispose(bool disposing)
{
if (disposing)
if (disposing && this.LockTaken)
{
SystemMonitor.Exit(this.SyncObject);
}

Просмотреть файл

@ -8,10 +8,7 @@ using System.Linq;
using Microsoft.Coyote.IO;
using Mono.Cecil;
using Mono.Cecil.Cil;
using ControlledTasks = Microsoft.Coyote.SystematicTesting.Interception;
using CoyoteTasks = Microsoft.Coyote.Tasks;
using SystemCompiler = System.Runtime.CompilerServices;
using SystemTasks = System.Threading.Tasks;
namespace Microsoft.Coyote.Rewriting
{
@ -35,21 +32,9 @@ namespace Microsoft.Coyote.Rewriting
private readonly Configuration Configuration;
/// <summary>
/// Cache from <see cref="SystemTasks"/> type names to types being replaced
/// in the module that is currently being rewritten.
/// List of transforms we are applying while rewriting.
/// </summary>
private readonly Dictionary<string, TypeReference> TaskTypeCache;
/// <summary>
/// Cache from <see cref="SystemCompiler"/> type names to types being replaced
/// in the module that is currently being rewritten.
/// </summary>
private readonly Dictionary<string, TypeReference> CompilerTypeCache;
/// <summary>
/// Cached generic task type name prefix.
/// </summary>
private const string GenericTaskTypeNamePrefix = "Task`";
private readonly List<AssemblyTransform> Transforms = new List<AssemblyTransform>();
/// <summary>
/// Initializes a new instance of the <see cref="AssemblyRewriter"/> class.
@ -58,8 +43,8 @@ namespace Microsoft.Coyote.Rewriting
private AssemblyRewriter(Configuration configuration)
{
this.Configuration = configuration;
this.TaskTypeCache = new Dictionary<string, TypeReference>();
this.CompilerTypeCache = new Dictionary<string, TypeReference>();
this.Transforms.Add(new TaskTransform());
this.Transforms.Add(new LockTransform());
}
/// <summary>
@ -167,8 +152,11 @@ namespace Microsoft.Coyote.Rewriting
/// </summary>
private void RewriteModule(ModuleDefinition module)
{
this.TaskTypeCache.Clear();
this.CompilerTypeCache.Clear();
foreach (var t in this.Transforms)
{
t.VisitModule(module);
}
foreach (var type in module.GetTypes())
{
this.RewriteType(type);
@ -181,13 +169,17 @@ namespace Microsoft.Coyote.Rewriting
private void RewriteType(TypeDefinition type)
{
Debug.WriteLine($"....... Rewriting type '{type.FullName}'");
foreach (var t in this.Transforms)
{
t.VisitType(type);
}
foreach (var field in type.Fields)
{
if (this.TryGetCompilerTypeReplacement(field.FieldType, null, out TypeReference newFieldType))
foreach (var t in this.Transforms)
{
Debug.WriteLine($"......... [-] field '{field}'");
field.FieldType = newFieldType;
Debug.WriteLine($"......... [+] field '{field}'");
t.VisitField(field);
}
}
@ -204,464 +196,49 @@ namespace Microsoft.Coyote.Rewriting
{
Debug.WriteLine($"......... Rewriting method '{method.FullName}'");
if (this.TryGetCompilerTypeReplacement(method.ReturnType, method, out TypeReference newReturnType))
foreach (var t in this.Transforms)
{
Debug.WriteLine($"........... [-] return type '{method.ReturnType}'");
method.ReturnType = newReturnType;
Debug.WriteLine($"........... [+] return type '{method.ReturnType}'");
t.VisitMethod(method);
}
// Only non-abstract method bodies can be rewritten.
if (!method.IsAbstract)
{
// TODO: Check if method.Body.ExceptionHandlers needs to be rewritten.
ILProcessor processor = method.Body.GetILProcessor();
foreach (var variable in method.Body.Variables)
{
if (this.TryGetCompilerTypeReplacement(variable.VariableType, method, out TypeReference newVariableType))
foreach (var t in this.Transforms)
{
Debug.WriteLine($"........... [-] variable '{variable.VariableType}'");
variable.VariableType = newVariableType;
Debug.WriteLine($"........... [+] variable '{variable.VariableType}'");
t.VisitVariable(variable);
}
}
Instruction instruction = method.Body.Instructions.FirstOrDefault();
while (instruction != null)
// Do exception handlers before the method instructions because they are a
// higher level concept and it's handy to pre-process them before seeing the
// raw instructions.
if (method.Body.HasExceptionHandlers)
{
if (instruction.OpCode == OpCodes.Stfld || instruction.OpCode == OpCodes.Ldfld || instruction.OpCode == OpCodes.Ldflda)
foreach (var t in this.Transforms)
{
if (instruction.Operand is FieldDefinition fd &&
this.TryGetCompilerTypeReplacement(fd.FieldType, method, out TypeReference newFieldType))
foreach (var handler in method.Body.ExceptionHandlers)
{
Debug.WriteLine($"........... [-] {instruction}");
fd.FieldType = newFieldType;
Debug.WriteLine($"........... [+] {instruction}");
}
else if (instruction.Operand is FieldReference fr &&
this.TryGetCompilerTypeReplacement(fr.FieldType, method, out newFieldType))
{
Debug.WriteLine($"........... [-] {instruction}");
fr.FieldType = newFieldType;
Debug.WriteLine($"........... [+] {instruction}");
t.VisitExceptionHandler(handler);
}
}
else
}
// in this case run each transform as separate passes over the method body
// so they don't trip over each other's edits.
foreach (var t in this.Transforms)
{
Instruction instruction = method.Body.Instructions.FirstOrDefault();
while (instruction != null)
{
ILProcessingOperation operation = this.GetILProcessingOperation(instruction, method);
if (operation.Type is ILProcessingOperationType.Replace)
{
Debug.WriteLine($"........... [-] {instruction}");
operation.Instructions[0].Offset = instruction.Offset;
processor.Replace(instruction, operation.Instructions[0]);
instruction = operation.Instructions[0];
Debug.WriteLine($"........... [+] {instruction}");
for (int idx = 1; idx < operation.Instructions.Length; idx++)
{
Debug.WriteLine($"........... [+] {operation.Instructions[idx]}");
processor.InsertAfter(instruction, operation.Instructions[idx]);
instruction = instruction.Next;
}
}
}
instruction = instruction.Next;
}
}
}
/// <summary>
/// Returns an <see cref="ILProcessingOperation"/> for the specified instruction.
/// </summary>
/// <remarks>
/// If the returned operation has type <see cref="ILProcessingOperationType.None"/>, then there is nothing to rewrite.
/// </remarks>
private ILProcessingOperation GetILProcessingOperation(Instruction instruction, MethodDefinition method)
{
// TODO: check what we need to deal with `OpCodes.Calli`, and if we need to.
ILProcessingOperation operation = ILProcessingOperation.None;
if (instruction.OpCode == OpCodes.Initobj)
{
operation = this.GetInitobjProcessingOperation(instruction, method);
}
else if ((instruction.OpCode == OpCodes.Call || instruction.OpCode == OpCodes.Callvirt) &&
instruction.Operand is MethodReference methodReference)
{
operation = this.GetCallProcessingOperation(instruction.OpCode, methodReference, method);
}
return operation;
}
/// <summary>
/// Returns an <see cref="ILProcessingOperation"/> for the specified <see cref="OpCodes.Initobj"/> instruction.
/// </summary>
/// <remarks>
/// If the returned operation has type <see cref="ILProcessingOperationType.None"/>, then there is nothing to rewrite.
/// </remarks>
private ILProcessingOperation GetInitobjProcessingOperation(Instruction instruction, MethodDefinition method)
{
ILProcessingOperation operation = ILProcessingOperation.None;
TypeReference type = instruction.Operand as TypeReference;
if (this.TryGetCompilerTypeReplacement(type, method, out TypeReference newType))
{
var newInstruction = Instruction.Create(instruction.OpCode, newType);
operation = new ILProcessingOperation(ILProcessingOperationType.Replace, newInstruction);
}
return operation;
}
/// <summary>
/// Returns an <see cref="ILProcessingOperation"/> for the specified non-generic <see cref="OpCodes.Call"/>
/// or <see cref="OpCodes.Callvirt"/> instruction.
/// </summary>
/// <remarks>
/// If the returned operation has type <see cref="ILProcessingOperationType.None"/>, then there is nothing to rewrite.
/// </remarks>
private ILProcessingOperation GetCallProcessingOperation(OpCode opCode, MethodReference method, IGenericParameterProvider provider)
{
TypeReference newType = null;
if (IsSystemTaskType(method.DeclaringType))
{
// Special rules apply for methods under the Task namespace.
if (method.Name == nameof(SystemTasks.Task.Run) ||
method.Name == nameof(SystemTasks.Task.Delay) ||
method.Name == nameof(SystemTasks.Task.WhenAll) ||
method.Name == nameof(SystemTasks.Task.WhenAny) ||
method.Name == nameof(SystemTasks.Task.Yield) ||
method.Name == nameof(SystemTasks.Task.GetAwaiter))
{
newType = this.GetTaskTypeReplacement(method.DeclaringType);
}
}
else
{
newType = this.GetCompilerTypeReplacement(method.DeclaringType, provider);
if (newType == method.DeclaringType)
{
newType = null;
}
}
// TODO: check if "new type is null" check is required.
if (newType is null || !this.TryGetReplacementMethod(newType, method, out MethodReference newMethod))
{
// There is nothing to rewrite, return with the none operation.
return ILProcessingOperation.None;
}
OpCode newOpCode = opCode;
if (newMethod.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter))
{
// The OpCode must change for the GetAwaiter method.
newOpCode = OpCodes.Call;
}
// Create and return the new instruction.
Instruction newInstruction = Instruction.Create(newOpCode, newMethod);
return new ILProcessingOperation(ILProcessingOperationType.Replace, newInstruction);
}
/// <summary>
/// Returns a method from the specified type that can replace the original method, if any.
/// </summary>
private bool TryGetReplacementMethod(TypeReference replacementType, MethodReference originalMethod, out MethodReference result)
{
result = null;
TypeDefinition replacementTypeDef = replacementType.Resolve();
if (replacementTypeDef == null)
{
throw new Exception(string.Format("Error resolving type: {0}", replacementType.FullName));
}
MethodDefinition original = originalMethod.Resolve();
bool isGetControlledAwaiter = false;
foreach (var replacement in replacementTypeDef.Methods)
{
// TODO: make sure all necessery checks are in place!
if (!(!replacement.IsConstructor &&
replacement.Name == original.Name &&
replacement.ReturnType.IsGenericInstance == original.ReturnType.IsGenericInstance &&
replacement.IsPublic == original.IsPublic &&
replacement.IsPrivate == original.IsPrivate &&
replacement.IsAssembly == original.IsAssembly &&
replacement.IsFamilyAndAssembly == original.IsFamilyAndAssembly))
{
continue;
}
isGetControlledAwaiter = replacement.DeclaringType.Namespace == KnownNamespaces.ControlledTasksName &&
replacement.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter);
if (!isGetControlledAwaiter)
{
// Only check that the parameters match for non-controlled awaiter methods.
// That is because we do special rewriting for this method.
if (!CheckMethodParametersMatch(replacement, original))
{
continue;
instruction = t.VisitInstruction(instruction);
instruction = instruction.Next;
}
}
// Import the method in the module that is being rewritten.
result = originalMethod.Module.ImportReference(replacement);
break;
}
if (result is null)
{
// TODO: raise an error.
return false;
}
result.DeclaringType = replacementType;
if (originalMethod is GenericInstanceMethod genericMethod)
{
var newGenericMethod = new GenericInstanceMethod(result);
// The method is generic, so populate it with generic argument types and parameters.
foreach (var arg in genericMethod.GenericArguments)
{
TypeReference newArgumentType = this.GetCompilerTypeReplacement(arg, newGenericMethod);
newGenericMethod.GenericArguments.Add(newArgumentType);
}
result = newGenericMethod;
}
else if (isGetControlledAwaiter && originalMethod.DeclaringType is GenericInstanceType genericType)
{
// Special processing applies in this case, because we are converting the `Task<T>.GetAwaiter`
// non-generic instance method to the `ControlledTask.GetAwaiter<T>` generic static method.
var newGenericMethod = new GenericInstanceMethod(result);
// There is only a single argument type, which must be added to ma.
newGenericMethod.GenericArguments.Add(this.GetCompilerTypeReplacement(genericType.GenericArguments[0], newGenericMethod));
// The single generic argument type in the task parameter must be rewritten to the same
// generic argument type as the one in the return type of `GetAwaiter<T>`.
var parameterType = newGenericMethod.Parameters[0].ParameterType as GenericInstanceType;
parameterType.GenericArguments[0] = (newGenericMethod.ReturnType as GenericInstanceType).GenericArguments[0];
result = newGenericMethod;
}
// Rewrite the parameters of the method, if any.
for (int idx = 0; idx < originalMethod.Parameters.Count; idx++)
{
ParameterDefinition parameter = originalMethod.Parameters[idx];
TypeReference newParameterType = this.GetCompilerTypeReplacement(parameter.ParameterType, result);
ParameterDefinition newParameter = new ParameterDefinition(parameter.Name, parameter.Attributes, newParameterType);
result.Parameters[idx] = newParameter;
}
if (result.ReturnType.Namespace != KnownNamespaces.ControlledTasksName)
{
result.ReturnType = this.GetCompilerTypeReplacement(originalMethod.ReturnType, result);
}
return originalMethod.FullName != result.FullName;
}
/// <summary>
/// Checks if the the parameters of the two methods match.
/// </summary>
private static bool CheckMethodParametersMatch(MethodDefinition left, MethodDefinition right)
{
if (left.Parameters.Count != right.Parameters.Count)
{
return false;
}
for (int idx = 0; idx < right.Parameters.Count; idx++)
{
var originalParam = right.Parameters[0];
var replacementParam = left.Parameters[0];
// TODO: make sure all necessery checks are in place!
if ((replacementParam.ParameterType.FullName != originalParam.ParameterType.FullName) ||
(replacementParam.Name != originalParam.Name) ||
(replacementParam.IsIn && !originalParam.IsIn) ||
(replacementParam.IsOut && !originalParam.IsOut))
{
return false;
}
}
return true;
}
/// <summary>
/// Returns the replacement type for the specified <see cref="SystemTasks"/> type, else null.
/// </summary>
private TypeReference GetTaskTypeReplacement(TypeReference type)
{
TypeReference result;
string fullName = type.FullName;
if (this.TaskTypeCache.ContainsKey(fullName))
{
result = this.TaskTypeCache[fullName];
if (result.Module != type.Module)
{
result = type.Module.ImportReference(result);
this.TaskTypeCache[fullName] = result;
}
}
else
{
result = type.Module.ImportReference(typeof(ControlledTasks.ControlledTask));
this.TaskTypeCache[fullName] = result;
}
return result;
}
/// <summary>
/// Tries to return the replacement type for the specified <see cref="SystemCompiler"/> type, if such a type exists.
/// </summary>
private bool TryGetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider, out TypeReference result)
{
result = this.GetCompilerTypeReplacement(type, provider);
return result.FullName != type.FullName;
}
/// <summary>
/// Returns the replacement type for the specified <see cref="SystemCompiler"/> type, else null.
/// </summary>
private TypeReference GetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider)
{
TypeReference result = type;
string fullName = type.FullName;
if (this.CompilerTypeCache.ContainsKey(fullName))
{
result = this.CompilerTypeCache[fullName];
if (result.Module != type.Module)
{
result = type.Module.ImportReference(result);
this.CompilerTypeCache[fullName] = result;
}
}
else if (type.IsGenericInstance &&
(type.Name == KnownSystemTypes.GenericAsyncTaskMethodBuilderName ||
type.Name == KnownSystemTypes.GenericTaskAwaiterName))
{
result = this.GetGenericTypeReplacement(type as GenericInstanceType, provider);
if (result.FullName != fullName)
{
result = type.Module.ImportReference(result);
this.CompilerTypeCache[fullName] = result;
}
}
else if (fullName == KnownSystemTypes.AsyncTaskMethodBuilderFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder));
}
else if (fullName == KnownSystemTypes.GenericAsyncTaskMethodBuilderFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder<>));
}
else if (fullName == KnownSystemTypes.TaskAwaiterFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter));
}
else if (fullName == KnownSystemTypes.GenericTaskAwaiterFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter<>));
}
else if (fullName == KnownSystemTypes.YieldAwaitableFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable));
}
else if (fullName == KnownSystemTypes.YieldAwaiterFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable.YieldAwaiter));
}
return result;
}
/// <summary>
/// Import the replacement coyote type and cache it in the CompilerTypeCache and make sure the type can be
/// fully resolved.
/// </summary>
private TypeReference ImportCompilerTypeReplacement(TypeReference originalType, Type coyoteType)
{
var result = originalType.Module.ImportReference(coyoteType);
this.CompilerTypeCache[originalType.FullName] = result;
TypeDefinition coyoteTypeDef = result.Resolve();
if (coyoteTypeDef == null)
{
throw new Exception(string.Format("Unexpected error resolving type: {0}", coyoteType.FullName));
}
return result;
}
/// <summary>
/// Returns the replacement type for the specified generic type, else null.
/// </summary>
private GenericInstanceType GetGenericTypeReplacement(GenericInstanceType type, IGenericParameterProvider provider)
{
GenericInstanceType result = type;
TypeReference genericType = this.GetCompilerTypeReplacement(type.ElementType, null);
if (type.ElementType.FullName != genericType.FullName)
{
// The generic type must be rewritten.
result = new GenericInstanceType(type.Module.ImportReference(genericType));
foreach (var arg in type.GenericArguments)
{
TypeReference newArgumentType;
if (arg.IsGenericParameter)
{
GenericParameter parameter = new GenericParameter(arg.Name, provider ?? result);
result.GenericParameters.Add(parameter);
newArgumentType = parameter;
}
else
{
newArgumentType = this.GetCompilerTypeReplacement(arg, provider);
}
result.GenericArguments.Add(newArgumentType);
}
}
return result;
}
/// <summary>
/// Checks if the specified type is the <see cref="SystemTasks.Task"/> type.
/// </summary>
private static bool IsSystemTaskType(TypeReference type) => type.Namespace == KnownNamespaces.SystemTasksName &&
(type.Name == typeof(SystemTasks.Task).Name || type.Name.StartsWith(GenericTaskTypeNamePrefix));
/// <summary>
/// Cache of known <see cref="SystemCompiler"/> type names.
/// </summary>
private static class KnownSystemTypes
{
internal static string AsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder).FullName;
internal static string GenericAsyncTaskMethodBuilderName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).Name;
internal static string GenericAsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).FullName;
internal static string TaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter).FullName;
internal static string GenericTaskAwaiterName { get; } = typeof(SystemCompiler.TaskAwaiter<>).Name;
internal static string GenericTaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter<>).FullName;
internal static string YieldAwaitableFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName;
internal static string YieldAwaiterFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName + "/YieldAwaiter";
}
/// <summary>
/// Cache of known namespace names.
/// </summary>
private static class KnownNamespaces
{
internal static string ControlledTasksName { get; } = typeof(ControlledTasks.ControlledTask).Namespace;
internal static string SystemTasksName { get; } = typeof(SystemTasks.Task).Namespace;
}
}
}

Просмотреть файл

@ -1,64 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using Mono.Cecil.Cil;
namespace Microsoft.Coyote.Rewriting
{
/// <summary>
/// An IL processing operation to perform.
/// </summary>
internal struct ILProcessingOperation
{
/// <summary>
/// A cached no-op IL processing operation.
/// </summary>
internal static ILProcessingOperation None { get; } = new ILProcessingOperation(ILProcessingOperationType.None);
/// <summary>
/// A cached array containing a single instruction to optimize for memory in the common scenario.
/// </summary>
/// <remarks>
/// This is not thread safe, but we are not running rewriting in parallel.
/// </remarks>
private static Instruction[] SingleInstruction { get; } = new Instruction[1];
/// <summary>
/// The type of IL processing to perform.
/// </summary>
internal readonly ILProcessingOperationType Type;
/// <summary>
/// The new instructions to add.
/// </summary>
internal readonly Instruction[] Instructions;
/// <summary>
/// Initializes a new instance of the <see cref="ILProcessingOperation"/> struct.
/// </summary>
private ILProcessingOperation(ILProcessingOperationType type)
{
this.Type = type;
this.Instructions = null;
}
/// <summary>
/// Initializes a new instance of the <see cref="ILProcessingOperation"/> struct.
/// </summary>
internal ILProcessingOperation(ILProcessingOperationType type, Instruction instruction)
{
this.Type = type;
SingleInstruction[0] = instruction;
this.Instructions = SingleInstruction;
}
/// <summary>
/// Initializes a new instance of the <see cref="ILProcessingOperation"/> struct.
/// </summary>
internal ILProcessingOperation(ILProcessingOperationType type, params Instruction[] instructions)
{
this.Type = type;
this.Instructions = instructions;
}
}
}

Просмотреть файл

@ -1,31 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
namespace Microsoft.Coyote.Rewriting
{
/// <summary>
/// The type of IL processing to perform.
/// </summary>
internal enum ILProcessingOperationType
{
/// <summary>
/// Do not change the current instruction.
/// </summary>
None,
/// <summary>
/// Insert the specified instructions after the current instruction.
/// </summary>
InsertAfter,
/// <summary>
/// Insert the specified instructions before the current instruction.
/// </summary>
InsertBefore,
/// <summary>
/// Replace the current instruction with the specified instructions.
/// </summary>
Replace
}
}

Просмотреть файл

@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using Mono.Cecil;
using Mono.Cecil.Cil;
namespace Microsoft.Coyote.Rewriting
{
/// <summary>
/// An abstract interface for transforming code using a visitor pattern.
/// This is used by the <see cref="AssemblyRewriter"/> to manage multiple different
/// transforms in a single pass over an assembly.
/// </summary>
internal abstract class AssemblyTransform
{
/// <summary>
/// Notify visitor we are starting a new Module.
/// </summary>
internal abstract void VisitModule(ModuleDefinition module);
/// <summary>
/// Notify visitor we are visiting a new TypeDefinition.
/// </summary>
internal abstract void VisitType(TypeDefinition typeDef);
/// <summary>
/// Notify visitor we are visiting a field inside the TypeDefinition just given to VisitType.
/// </summary>
internal abstract void VisitField(FieldDefinition field);
/// <summary>
/// Notify visitor we are visiting a method inside the TypeDefinition just given to VisitType.
/// </summary>
internal abstract void VisitMethod(MethodDefinition method);
/// <summary>
/// Notify visitor we are visiting a variable declaration inside the MethodDefinition just given to VisitMethod.
/// </summary>
internal abstract void VisitVariable(VariableDefinition variable);
/// <summary>
/// Visit an IL instruction inside the MethodDefinition body, and get back a possibly transformed instruction.
/// </summary>
/// <param name="instruction">The instruction to visit.</param>
/// <returns>Return the last modified instruction or the same one if it was not changed.</returns>
internal abstract Instruction VisitInstruction(Instruction instruction);
/// <summary>
/// Visit an <see cref="ExceptionHandler"/> inside the MethodDefinition. In the case of nested try/catch blocks
/// the inner block is visited first before the outer block.
/// </summary>
internal abstract void VisitExceptionHandler(ExceptionHandler handler);
}
}

Просмотреть файл

@ -0,0 +1,484 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Coyote.IO;
using Microsoft.Coyote.Specifications;
using Mono.Cecil;
using Mono.Cecil.Cil;
namespace Microsoft.Coyote.Rewriting
{
internal class LockTransform : AssemblyTransform
{
private ModuleDefinition Module;
private TypeDefinition TypeDef;
private MethodDefinition Method;
private ILProcessor Processor;
private List<SyncObjectMapping> Mapping;
private const string MonitorClassName = "System.Threading.Monitor";
/// <summary>
/// Maintains a mapping from the "syncobject" used in Monitor::Enter to the
/// actual instance of SynchronizedBlock that is now wrapping that sync object.
/// </summary>
private class SyncObjectMapping
{
internal FieldDefinition SyncObjectField;
internal VariableDefinition SyncObjectVariable;
internal VariableDefinition SyncBlockVariable;
}
/// <inheritdoc/>
internal override void VisitModule(ModuleDefinition module)
{
this.Module = module;
}
/// <inheritdoc/>
internal override void VisitType(TypeDefinition typeDef)
{
this.TypeDef = typeDef;
}
/// <inheritdoc/>
internal override void VisitField(FieldDefinition field)
{
}
/// <inheritdoc/>
internal override void VisitMethod(MethodDefinition method)
{
this.Method = method;
this.Processor = method.Body.GetILProcessor();
this.Mapping = new List<SyncObjectMapping>();
}
internal override void VisitVariable(VariableDefinition variable)
{
}
/// <inheritdoc/>
internal override Instruction VisitInstruction(Instruction instruction)
{
if (this.Method == null)
{
return instruction;
}
if (instruction.OpCode == OpCodes.Call && instruction.Operand is MethodReference method)
{
const string PulseMethod = nameof(System.Threading.Monitor.Pulse);
const string PulseAllMethod = nameof(System.Threading.Monitor.PulseAll);
const string WaitMethod = nameof(System.Threading.Monitor.Wait);
if (method.DeclaringType.FullName == MonitorClassName && method.Name == PulseMethod)
{
Debug.WriteLine($"......... [-] call '{method}'");
var pulseMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod(PulseMethod));
this.ReplaceLoadSyncBlock(instruction);
var newInstruction = Instruction.Create(OpCodes.Callvirt, pulseMethod);
Debug.WriteLine($"......... [+] call '{pulseMethod}'");
this.Processor.Replace(instruction, newInstruction);
instruction = newInstruction;
}
else if (method.DeclaringType.FullName == MonitorClassName && method.Name == PulseAllMethod)
{
Debug.WriteLine($"......... [-] call '{method}'");
var pulseAllMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod(PulseAllMethod));
this.ReplaceLoadSyncBlock(instruction);
var newInstruction = Instruction.Create(OpCodes.Callvirt, pulseAllMethod);
Debug.WriteLine($"......... [+] call '{pulseAllMethod}'");
this.Processor.Replace(instruction, newInstruction);
instruction = newInstruction;
}
else if (method.DeclaringType.FullName == MonitorClassName && method.Name == "Wait")
{
// public static bool Wait(object obj);
// public static bool Wait(object obj, int millisecondsTimeout);
// public static bool Wait(object obj, int millisecondsTimeout, bool exitContext);
// public static bool Wait(object obj, TimeSpan timeout);
// public static bool Wait(object obj, TimeSpan timeout, bool exitContext);
if (method.Parameters.Count == 1)
{
Debug.WriteLine($"......... [-] call '{method}'");
var waitMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod(WaitMethod, Array.Empty<Type>()));
this.ReplaceLoadSyncBlock(instruction);
var newInstruction = Instruction.Create(OpCodes.Callvirt, waitMethod);
Debug.WriteLine($"......... [+] call '{waitMethod}'");
this.Processor.Replace(instruction, newInstruction);
instruction = newInstruction;
}
else
{
throw new NotImplementedException();
}
}
}
return instruction;
}
/// <inheritdoc/>
internal override void VisitExceptionHandler(ExceptionHandler handler)
{
// a C# lock statement uses try/finally block, where Monitor.Enter is at the beginning of the Try
// and Monitor.Exit is in the finally block. If the Monitor does not follow this pattern then it
// is probably something else.
// The C# lock statement only has finally block, so there should be no CatchType.
if (handler.CatchType == null && MatchLockEnter(handler.TryStart, handler.TryEnd))
{
if (MatchLockExit(handler.HandlerStart, handler.HandlerEnd))
{
this.RewriteLock(handler);
}
}
}
/// <summary>
/// Resolve the FieldDefinition referenced in the given load instruction if any.
/// </summary>
/// <returns>A FieldDefinition or null.</returns>
private static FieldDefinition GetFieldDefinition(Instruction loadInstruction)
{
FieldDefinition fd = null;
if (loadInstruction.Operand is FieldReference fr)
{
fd = fr.Resolve();
}
else if (loadInstruction.Operand is FieldDefinition fdef)
{
fd = fdef;
}
return fd;
}
/// <summary>
/// Rewrite "ldarg.0, ldfld syncobject" with "ldloc.n" to load the local SynchronizedBlock instead
/// of the sync object, because the SynchronizedBlock methods (Pulse, Wait, etc) are instance methods
/// not static methods and they do not take the syncobject as an argument because the SynchronizedBlock
/// stores the sync object so it doesn't need it here.
/// </summary>
internal void ReplaceLoadSyncBlock(Instruction methodCall)
{
Instruction loadSyncObject = methodCall.Previous;
if (loadSyncObject.OpCode == OpCodes.Ldfld)
{
var fd = GetFieldDefinition(loadSyncObject);
if (fd != null)
{
var syncBlock = this.FindSyncBlockVar(fd);
if (syncBlock != null)
{
Instruction loadThis = loadSyncObject.Previous;
if (loadThis.OpCode == OpCodes.Ldarg_0)
{
this.Processor.Remove(loadThis);
}
else
{
throw new InvalidOperationException("Expecting load 'this' instruction here...");
}
this.Processor.Replace(loadSyncObject, CreateLoadOp(syncBlock));
}
else
{
// TODO: this code only works if the Pulse, PulseAll, or Wait method is called inside
// the same method containing the C# lock statement. This is normally the case but if
// someone decided to get clever and call a helper method and that helper method calls
// Wait then I have a problem because the SynchronizedBlock local variable will not be
// available in that new method...
throw new NotImplementedException(string.Format(
"Cannot find the matching SynchronizedBlock for the synchronizing object '{0}' in {1}",
loadSyncObject, this.Method.FullName));
}
}
}
else if (IsLoadOp(loadSyncObject.OpCode))
{
// TODO: can this happen? I haven't seen it yet...
throw new NotImplementedException("Monitor method called using local variable instead of ldfld");
}
}
internal static bool MatchLockEnter(Instruction start, Instruction end)
{
if (CountInstructions(start, end) >= 3)
{
Instruction a = start;
Instruction b = a.Next;
Instruction c = b.Next;
if (IsLoadOp(a.OpCode) && // C# always creates local variable for the sync object
b.OpCode == OpCodes.Ldloca_S && // the LockTaken boolean flag should be local
c.OpCode == OpCodes.Call && c.Operand is MethodReference method && method.DeclaringType.FullName == MonitorClassName && method.Name == "Enter")
{
// looks like a C# lock then!
return true;
}
}
return false;
}
internal static bool MatchLockExit(Instruction start, Instruction end)
{
if (CountInstructions(start, end) >= 4)
{
Instruction a = start;
Instruction b = a.Next;
Instruction c = b.Next;
Instruction d = c.Next;
if (IsLoadOp(a.OpCode) && // checking the LockTaken boolean flag
b.OpCode == OpCodes.Brfalse_S && // no exit if we didn't get the lock!
IsLoadOp(c.OpCode) && // loading the sync object
d.OpCode == OpCodes.Call && d.Operand is MethodReference method && method.DeclaringType.FullName == MonitorClassName && method.Name == "Exit")
{
// looks like a C# lock release then!
return true;
}
}
return false;
}
/// <summary>
/// Replaces Monitor.Enter with SynchronizedBlock.Lock and Monitor.Exit with SynchronizedBlock.Dispose().
/// </summary>
private void RewriteLock(ExceptionHandler handler)
{
// Then we can replace the LockTaken boolean local variable with a new SynchronizedBlock object and all this
// happens outside the try statement, so InsertBefore the start of the try block.
Instruction a = handler.TryStart; // local variable for the sync object
Instruction b = a.Next; // the LockTaken boolean flag should be local
Instruction c = b.Next; // the Monitor.Enter call.
var syncObjectVar = this.GetLocalVariable(a.OpCode, a.Operand);
var lockIndexVar = this.GetLocalVariable(b.OpCode, b.Operand);
// creates: using(var m = SynchronizedBlock.Lock(syncObject)) { ...
Debug.WriteLine($"......... [-] variable '{lockIndexVar.VariableType}'");
Debug.WriteLine($"......... [-] call '{c}'");
lockIndexVar.VariableType = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock)); // re-purpose this variable.
this.InitializeNull(lockIndexVar);
Debug.WriteLine($"......... [+] variable '{lockIndexVar.VariableType}'");
Instruction load_sync_object = CreateLoadOp(syncObjectVar);
this.Processor.InsertAfter(a.Previous, load_sync_object);
var lockMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod("Lock"));
Instruction create_sync_block = Instruction.Create(OpCodes.Call, lockMethod);
Debug.WriteLine($"......... [+] call '{create_sync_block}'");
this.Processor.InsertAfter(load_sync_object, create_sync_block);
Instruction stfield = CreateStoreOp(lockIndexVar);
this.Processor.InsertAfter(create_sync_block, stfield);
this.AddMapping(syncObjectVar, lockIndexVar);
// and now we can remove the first 3 instructions of the try block that were calling Monitor.Enter.
this.Processor.Remove(a);
this.Processor.Remove(b);
this.Processor.Remove(c);
// fix the finally block.
Instruction d = handler.HandlerStart; // ldloc LockTaken becomes ldloc SynchronizedBlock object
Instruction e = d.Next; // brfalse, this remains the same
Instruction f = e.Next; // ldloc sync object, becomes ldloc syncblock variable.
Instruction g = f.Next; // call Monitor.Exit.
// Create: m.Dispose()
Instruction load_sync_block = CreateLoadOp(lockIndexVar); // remember lockIndexVar has been re-purposed to store the SynchronizedBlock object
this.Processor.Replace(f, load_sync_block); // we need to load the object we are calling Dispose on. not the sync object
var disposeMethod = this.Module.ImportReference(typeof(System.IDisposable).GetMethod("Dispose"));
Instruction dispose = Instruction.Create(OpCodes.Callvirt, disposeMethod);
Debug.WriteLine($"......... [-] call '{g}'");
Debug.WriteLine($"......... [+] call '{dispose}'");
this.Processor.Replace(g, dispose);
}
/// <summary>
/// Make sure a variable is initialized to null (when variable was a bool it was initialized to integer 0 instead).
/// </summary>
private void InitializeNull(VariableDefinition v)
{
Instruction instruction = this.Method.Body.Instructions.FirstOrDefault();
while (instruction != null)
{
var op = instruction.OpCode;
if (IsStoreOp(op))
{
VariableDefinition v2 = this.GetLocalVariable(op, instruction.Operand);
if (v2.Index == v.Index)
{
// found the store operation for this local variable, so make sure it is initialized to null
if (instruction.Previous != null && instruction.Previous.OpCode != OpCodes.Ldnull)
{
this.Processor.Replace(instruction.Previous, Instruction.Create(OpCodes.Ldnull));
break;
}
}
}
instruction = instruction.Next;
}
}
/// <summary>
/// Remember the connection between syncObject and SynchronizedBlock and find which
/// FieldDefinition stores the syncObject.
/// </summary>
private void AddMapping(VariableDefinition syncObjectVar, VariableDefinition syncBlockVar)
{
// find the ldarg, ldfld, stloc.x for the syncObject so we know what it's FieldDefinition is.
var mapping = new SyncObjectMapping()
{
SyncObjectVariable = syncObjectVar,
SyncBlockVariable = syncBlockVar
};
Instruction instruction = this.Method.Body.Instructions.FirstOrDefault();
while (instruction != null)
{
var op = instruction.OpCode;
if (IsStoreOp(op))
{
VariableDefinition v = this.GetLocalVariable(op, instruction.Operand);
if (v.Index == syncObjectVar.Index)
{
// found the store operation for this local variable!
if (instruction.Previous != null && instruction.Previous.OpCode == OpCodes.Ldfld)
{
var fd = GetFieldDefinition(instruction.Previous);
if (fd != null)
{
mapping.SyncObjectField = fd;
break;
}
}
}
}
instruction = instruction.Next;
}
if (mapping.SyncObjectField == null)
{
// TODO: hmmm, perhaps the location of the sync object is more complicated...
}
this.Mapping.Add(mapping);
}
private VariableDefinition FindSyncBlockVar(FieldDefinition fd)
{
foreach (var item in this.Mapping)
{
if (item.SyncObjectField == fd)
{
return item.SyncBlockVariable;
}
}
Debug.WriteLine("### Cannot find SynchronizedBlock created for this Synchronizing Object");
return null;
}
internal static bool IsLoadOp(OpCode op)
{
return op == OpCodes.Ldloc_S || op == OpCodes.Ldloc_0 || op == OpCodes.Ldloc_1 || op == OpCodes.Ldloc_2 || op == OpCodes.Ldloc_3;
}
internal static bool IsStoreOp(OpCode op)
{
return op == OpCodes.Stloc_S || op == OpCodes.Stloc_0 || op == OpCodes.Stloc_1 || op == OpCodes.Stloc_2 || op == OpCodes.Stloc_3;
}
internal VariableDefinition GetLocalVariable(OpCode op, object operand)
{
if ((op == OpCodes.Ldloc || op == OpCodes.Ldloc_S || op == OpCodes.Ldloca_S || op == OpCodes.Stloc_S || op == OpCodes.Stloc) && operand is VariableDefinition vdef)
{
return vdef;
}
if (op == OpCodes.Ldloc_0 || op == OpCodes.Stloc_0)
{
return this.Method.Body.Variables[0];
}
if (op == OpCodes.Ldloc_1 || op == OpCodes.Stloc_1)
{
return this.Method.Body.Variables[1];
}
if (op == OpCodes.Ldloc_2 || op == OpCodes.Stloc_2)
{
return this.Method.Body.Variables[2];
}
if (op == OpCodes.Ldloc_3 || op == OpCodes.Stloc_3)
{
return this.Method.Body.Variables[3];
}
throw new InvalidOperationException();
}
private static Instruction CreateLoadOp(VariableDefinition var)
{
switch (var.Index)
{
case 0:
return Instruction.Create(OpCodes.Ldloc_0);
case 1:
return Instruction.Create(OpCodes.Ldloc_1);
case 2:
return Instruction.Create(OpCodes.Ldloc_2);
case 3:
return Instruction.Create(OpCodes.Ldloc_3);
default:
return Instruction.Create(OpCodes.Ldloc, var);
}
}
internal static Instruction CreateStoreOp(VariableDefinition var)
{
switch (var.Index)
{
case 0:
return Instruction.Create(OpCodes.Stloc_0);
case 1:
return Instruction.Create(OpCodes.Stloc_1);
case 2:
return Instruction.Create(OpCodes.Stloc_2);
case 3:
return Instruction.Create(OpCodes.Stloc_3);
default:
return Instruction.Create(OpCodes.Stloc_S, var);
}
}
internal static int CountInstructions(Instruction start, Instruction end)
{
int count = 0;
while (start != end)
{
count++;
start = start.Next;
}
return count;
}
}
}

Просмотреть файл

@ -0,0 +1,541 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using Microsoft.Coyote.IO;
using Mono.Cecil;
using Mono.Cecil.Cil;
using ControlledTasks = Microsoft.Coyote.SystematicTesting.Interception;
using CoyoteTasks = Microsoft.Coyote.Tasks;
using SystemCompiler = System.Runtime.CompilerServices;
using SystemTasks = System.Threading.Tasks;
namespace Microsoft.Coyote.Rewriting
{
internal class TaskTransform : AssemblyTransform
{
/// <summary>
/// Cache from <see cref="SystemTasks"/> type names to types being replaced
/// in the module that is currently being rewritten.
/// </summary>
private readonly Dictionary<string, TypeReference> TaskTypeCache;
/// <summary>
/// Cache from <see cref="SystemCompiler"/> type names to types being replaced
/// in the module that is currently being rewritten.
/// </summary>
private readonly Dictionary<string, TypeReference> CompilerTypeCache;
/// <summary>
/// The current module being transformed.
/// </summary>
private ModuleDefinition Module;
/// <summary>
/// The current type being transformed.
/// </summary>
private TypeDefinition TypeDef;
/// <summary>
/// The current method being transformed.
/// </summary>
private MethodDefinition Method;
/// <summary>
/// A helper class for editing method body.
/// </summary>
private ILProcessor Processor;
/// <summary>
/// Initializes a new instance of the <see cref="TaskTransform"/> class.
/// </summary>
internal TaskTransform()
{
this.TaskTypeCache = new Dictionary<string, TypeReference>();
this.CompilerTypeCache = new Dictionary<string, TypeReference>();
}
/// <summary>
/// Cached generic task type name prefix.
/// </summary>
private const string GenericTaskTypeNamePrefix = "Task`";
internal override void VisitModule(ModuleDefinition module)
{
this.Module = module;
this.TaskTypeCache.Clear();
this.CompilerTypeCache.Clear();
}
internal override void VisitType(TypeDefinition typeDef)
{
this.TypeDef = typeDef;
this.Method = null;
this.Processor = null;
}
internal override void VisitField(FieldDefinition field)
{
if (this.TryGetCompilerTypeReplacement(field.FieldType, this.TypeDef, out TypeReference newFieldType))
{
Debug.WriteLine($"......... [-] field '{field}'");
field.FieldType = newFieldType;
Debug.WriteLine($"......... [+] field '{field}'");
}
}
internal override void VisitMethod(MethodDefinition method)
{
this.Method = null;
// Only non-abstract method bodies can be rewritten.
if (!method.IsAbstract)
{
this.Method = method;
this.Processor = method.Body.GetILProcessor();
}
// bugbug: what if this is an override of an inherited virtual method? For example, what if there
// is an external base class that is a Task like type that implements a virtual GetAwaiter() that
// is overridden by this method?
if (this.TryGetCompilerTypeReplacement(method.ReturnType, method, out TypeReference newReturnType))
{
Debug.WriteLine($"........... [-] return type '{method.ReturnType}'");
method.ReturnType = newReturnType;
Debug.WriteLine($"........... [+] return type '{method.ReturnType}'");
}
}
internal override void VisitExceptionHandler(ExceptionHandler handler)
{
}
internal override void VisitVariable(VariableDefinition variable)
{
if (this.Method == null)
{
return;
}
if (this.TryGetCompilerTypeReplacement(variable.VariableType, this.Method, out TypeReference newVariableType))
{
Debug.WriteLine($"........... [-] variable '{variable.VariableType}'");
variable.VariableType = newVariableType;
Debug.WriteLine($"........... [+] variable '{variable.VariableType}'");
}
}
internal override Instruction VisitInstruction(Instruction instruction)
{
if (this.Method == null)
{
return instruction;
}
// TODO: what about ldsfld, for static fields?
if (instruction.OpCode == OpCodes.Stfld || instruction.OpCode == OpCodes.Ldfld || instruction.OpCode == OpCodes.Ldflda)
{
if (instruction.Operand is FieldDefinition fd &&
this.TryGetCompilerTypeReplacement(fd.FieldType, this.Method, out TypeReference newFieldType))
{
Debug.WriteLine($"........... [-] {instruction}");
fd.FieldType = newFieldType;
Debug.WriteLine($"........... [+] {instruction}");
}
else if (instruction.Operand is FieldReference fr &&
this.TryGetCompilerTypeReplacement(fr.FieldType, this.Method, out newFieldType))
{
Debug.WriteLine($"........... [-] {instruction}");
fr.FieldType = newFieldType;
Debug.WriteLine($"........... [+] {instruction}");
}
}
else if (instruction.OpCode == OpCodes.Initobj)
{
instruction = this.VisitInitobjProcessingOperation(instruction, this.Method);
}
else if ((instruction.OpCode == OpCodes.Call || instruction.OpCode == OpCodes.Callvirt) &&
instruction.Operand is MethodReference methodReference)
{
instruction = this.VisitCallProcessingOperation(instruction, methodReference, this.Method);
}
// return the last modified instruction or the same one if it was not changed.
return instruction;
}
/// <summary>
/// Transform the specified non-generic <see cref="OpCodes.Call"/> or <see cref="OpCodes.Callvirt"/> instruction.
/// </summary>
/// <returns>The unmodified instruction or the last newly inserted instruction.</returns>
private Instruction VisitCallProcessingOperation(Instruction instruction, MethodReference method, IGenericParameterProvider provider)
{
TypeReference newType = null;
var opCode = instruction.OpCode;
if (IsSystemTaskType(method.DeclaringType))
{
// Special rules apply for methods under the Task namespace.
if (method.Name == nameof(SystemTasks.Task.Run) ||
method.Name == nameof(SystemTasks.Task.Delay) ||
method.Name == nameof(SystemTasks.Task.WhenAll) ||
method.Name == nameof(SystemTasks.Task.WhenAny) ||
method.Name == nameof(SystemTasks.Task.Yield) ||
method.Name == nameof(SystemTasks.Task.GetAwaiter))
{
newType = this.GetTaskTypeReplacement(method.DeclaringType);
}
}
else
{
newType = this.GetCompilerTypeReplacement(method.DeclaringType, provider);
if (newType == method.DeclaringType)
{
newType = null;
}
}
// TODO: check if "new type is null" check is required.
if (newType is null || !this.TryGetReplacementMethod(newType, method, out MethodReference newMethod))
{
// There is nothing to rewrite, return with the none operation.
return instruction;
}
OpCode newOpCode = opCode;
if (newMethod.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter))
{
// The OpCode must change for the GetAwaiter method.
newOpCode = OpCodes.Call;
}
// Create and return the new instruction.
Instruction newInstruction = Instruction.Create(newOpCode, newMethod);
Debug.WriteLine($"........... [-] {instruction}");
this.Processor.Replace(instruction, newInstruction);
Debug.WriteLine($"........... [+] {newInstruction}");
return newInstruction;
}
/// <summary>
/// Transform the specified <see cref="OpCodes.Initobj"/> instruction.
/// </summary>
/// <remarks>
/// Return the unmodified instruction, or the newly replaced instruction.
/// </remarks>
private Instruction VisitInitobjProcessingOperation(Instruction instruction, MethodDefinition method)
{
TypeReference type = instruction.Operand as TypeReference;
if (this.TryGetCompilerTypeReplacement(type, method, out TypeReference newType))
{
var newInstruction = Instruction.Create(instruction.OpCode, newType);
Debug.WriteLine($"........... [-] {instruction}");
this.Processor.Replace(instruction, newInstruction);
Debug.WriteLine($"........... [+] {newInstruction}");
instruction = newInstruction;
}
return instruction;
}
/// <summary>
/// Returns a method from the specified type that can replace the original method, if any.
/// </summary>
private bool TryGetReplacementMethod(TypeReference replacementType, MethodReference originalMethod, out MethodReference result)
{
result = null;
TypeDefinition replacementTypeDef = replacementType.Resolve();
if (replacementTypeDef == null)
{
throw new Exception(string.Format("Error resolving type: {0}", replacementType.FullName));
}
MethodDefinition original = originalMethod.Resolve();
bool isGetControlledAwaiter = false;
foreach (var replacement in replacementTypeDef.Methods)
{
// TODO: make sure all necessary checks are in place!
if (!(!replacement.IsConstructor &&
replacement.Name == original.Name &&
replacement.ReturnType.IsGenericInstance == original.ReturnType.IsGenericInstance &&
replacement.IsPublic == original.IsPublic &&
replacement.IsPrivate == original.IsPrivate &&
replacement.IsAssembly == original.IsAssembly &&
replacement.IsFamilyAndAssembly == original.IsFamilyAndAssembly))
{
continue;
}
isGetControlledAwaiter = replacement.DeclaringType.Namespace == KnownNamespaces.ControlledTasksName &&
replacement.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter);
if (!isGetControlledAwaiter)
{
// Only check that the parameters match for non-controlled awaiter methods.
// That is because we do special rewriting for this method.
if (!CheckMethodParametersMatch(replacement, original))
{
continue;
}
}
// Import the method in the module that is being rewritten.
result = originalMethod.Module.ImportReference(replacement);
break;
}
if (result is null)
{
// TODO: raise an error.
return false;
}
result.DeclaringType = replacementType;
if (originalMethod is GenericInstanceMethod genericMethod)
{
var newGenericMethod = new GenericInstanceMethod(result);
// The method is generic, so populate it with generic argument types and parameters.
foreach (var arg in genericMethod.GenericArguments)
{
TypeReference newArgumentType = this.GetCompilerTypeReplacement(arg, newGenericMethod);
newGenericMethod.GenericArguments.Add(newArgumentType);
}
result = newGenericMethod;
}
else if (isGetControlledAwaiter && originalMethod.DeclaringType is GenericInstanceType genericType)
{
// Special processing applies in this case, because we are converting the `Task<T>.GetAwaiter`
// non-generic instance method to the `ControlledTask.GetAwaiter<T>` generic static method.
var newGenericMethod = new GenericInstanceMethod(result);
// There is only a single argument type, which must be added to ma.
newGenericMethod.GenericArguments.Add(this.GetCompilerTypeReplacement(genericType.GenericArguments[0], newGenericMethod));
// The single generic argument type in the task parameter must be rewritten to the same
// generic argument type as the one in the return type of `GetAwaiter<T>`.
var parameterType = newGenericMethod.Parameters[0].ParameterType as GenericInstanceType;
parameterType.GenericArguments[0] = (newGenericMethod.ReturnType as GenericInstanceType).GenericArguments[0];
result = newGenericMethod;
}
// Rewrite the parameters of the method, if any.
for (int idx = 0; idx < originalMethod.Parameters.Count; idx++)
{
ParameterDefinition parameter = originalMethod.Parameters[idx];
TypeReference newParameterType = this.GetCompilerTypeReplacement(parameter.ParameterType, result);
ParameterDefinition newParameter = new ParameterDefinition(parameter.Name, parameter.Attributes, newParameterType);
result.Parameters[idx] = newParameter;
}
if (result.ReturnType.Namespace != KnownNamespaces.ControlledTasksName)
{
result.ReturnType = this.GetCompilerTypeReplacement(originalMethod.ReturnType, result);
}
return originalMethod.FullName != result.FullName;
}
/// <summary>
/// Checks if the parameters of the two methods match.
/// </summary>
private static bool CheckMethodParametersMatch(MethodDefinition left, MethodDefinition right)
{
if (left.Parameters.Count != right.Parameters.Count)
{
return false;
}
for (int idx = 0; idx < right.Parameters.Count; idx++)
{
var originalParam = right.Parameters[0];
var replacementParam = left.Parameters[0];
// TODO: make sure all necessary checks are in place!
if ((replacementParam.ParameterType.FullName != originalParam.ParameterType.FullName) ||
(replacementParam.Name != originalParam.Name) ||
(replacementParam.IsIn && !originalParam.IsIn) ||
(replacementParam.IsOut && !originalParam.IsOut))
{
return false;
}
}
return true;
}
/// <summary>
/// Returns the replacement type for the specified <see cref="SystemTasks"/> type, else null.
/// </summary>
private TypeReference GetTaskTypeReplacement(TypeReference type)
{
TypeReference result;
string fullName = type.FullName;
if (this.TaskTypeCache.ContainsKey(fullName))
{
result = this.TaskTypeCache[fullName];
if (result.Module != type.Module)
{
result = type.Module.ImportReference(result);
this.TaskTypeCache[fullName] = result;
}
}
else
{
result = type.Module.ImportReference(typeof(ControlledTasks.ControlledTask));
this.TaskTypeCache[fullName] = result;
}
return result;
}
/// <summary>
/// Tries to return the replacement type for the specified <see cref="SystemCompiler"/> type, if such a type exists.
/// </summary>
private bool TryGetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider, out TypeReference result)
{
result = this.GetCompilerTypeReplacement(type, provider);
return result.FullName != type.FullName;
}
/// <summary>
/// Returns the replacement type for the specified <see cref="SystemCompiler"/> type, else null.
/// </summary>
private TypeReference GetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider)
{
TypeReference result = type;
string fullName = type.FullName;
if (this.CompilerTypeCache.ContainsKey(fullName))
{
result = this.CompilerTypeCache[fullName];
if (result.Module != type.Module)
{
result = type.Module.ImportReference(result);
this.CompilerTypeCache[fullName] = result;
}
}
else if (type.IsGenericInstance &&
(type.Name == KnownSystemTypes.GenericAsyncTaskMethodBuilderName ||
type.Name == KnownSystemTypes.GenericTaskAwaiterName))
{
result = this.GetGenericTypeReplacement(type as GenericInstanceType, provider);
if (result.FullName != fullName)
{
result = type.Module.ImportReference(result);
this.CompilerTypeCache[fullName] = result;
}
}
else if (fullName == KnownSystemTypes.AsyncTaskMethodBuilderFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder));
}
else if (fullName == KnownSystemTypes.GenericAsyncTaskMethodBuilderFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder<>));
}
else if (fullName == KnownSystemTypes.TaskAwaiterFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter));
}
else if (fullName == KnownSystemTypes.GenericTaskAwaiterFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter<>));
}
else if (fullName == KnownSystemTypes.YieldAwaitableFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable));
}
else if (fullName == KnownSystemTypes.YieldAwaiterFullName)
{
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable.YieldAwaiter));
}
return result;
}
/// <summary>
/// Import the replacement coyote type and cache it in the CompilerTypeCache and make sure the type can be
/// fully resolved.
/// </summary>
private TypeReference ImportCompilerTypeReplacement(TypeReference originalType, Type coyoteType)
{
var result = originalType.Module.ImportReference(coyoteType);
this.CompilerTypeCache[originalType.FullName] = result;
TypeDefinition coyoteTypeDef = result.Resolve();
if (coyoteTypeDef == null)
{
throw new Exception(string.Format("Unexpected error resolving type: {0}", coyoteType.FullName));
}
return result;
}
/// <summary>
/// Returns the replacement type for the specified generic type, else null.
/// </summary>
private GenericInstanceType GetGenericTypeReplacement(GenericInstanceType type, IGenericParameterProvider provider)
{
GenericInstanceType result = type;
TypeReference genericType = this.GetCompilerTypeReplacement(type.ElementType, null);
if (type.ElementType.FullName != genericType.FullName)
{
// The generic type must be rewritten.
result = new GenericInstanceType(type.Module.ImportReference(genericType));
foreach (var arg in type.GenericArguments)
{
TypeReference newArgumentType;
if (arg.IsGenericParameter)
{
GenericParameter parameter = new GenericParameter(arg.Name, provider ?? result);
result.GenericParameters.Add(parameter);
newArgumentType = parameter;
}
else
{
newArgumentType = this.GetCompilerTypeReplacement(arg, provider);
}
result.GenericArguments.Add(newArgumentType);
}
}
return result;
}
/// <summary>
/// Checks if the specified type is the <see cref="SystemTasks.Task"/> type.
/// </summary>
private static bool IsSystemTaskType(TypeReference type) => type.Namespace == KnownNamespaces.SystemTasksName &&
(type.Name == typeof(SystemTasks.Task).Name || type.Name.StartsWith(GenericTaskTypeNamePrefix));
/// <summary>
/// Cache of known <see cref="SystemCompiler"/> type names.
/// </summary>
private static class KnownSystemTypes
{
internal static string AsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder).FullName;
internal static string GenericAsyncTaskMethodBuilderName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).Name;
internal static string GenericAsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).FullName;
internal static string TaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter).FullName;
internal static string GenericTaskAwaiterName { get; } = typeof(SystemCompiler.TaskAwaiter<>).Name;
internal static string GenericTaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter<>).FullName;
internal static string YieldAwaitableFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName;
internal static string YieldAwaiterFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName + "/YieldAwaiter";
}
/// <summary>
/// Cache of known namespace names.
/// </summary>
private static class KnownNamespaces
{
internal static string ControlledTasksName { get; } = typeof(ControlledTasks.ControlledTask).Namespace;
internal static string SystemTasksName { get; } = typeof(SystemTasks.Task).Namespace;
}
}
}

Просмотреть файл

@ -16,4 +16,7 @@
<ItemGroup>
<ProjectReference Include="..\..\Source\Core\Core.csproj" />
</ItemGroup>
<ItemGroup>
<Folder Include="Transforms\" />
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Coyote.Specifications;
using Xunit;
using Xunit.Abstractions;
using Monitor = System.Threading.Monitor;
namespace Microsoft.Coyote.BinaryRewriting.Tests.Tasks.Locks
{
public class CSharpLockTests : BaseProductionTest
{
private readonly object SyncObject1 = new object();
private string Value;
public CSharpLockTests(ITestOutputHelper output)
: base(output)
{
}
[Fact(Timeout = 5000)]
public void TestSimpleLock()
{
this.Test(() =>
{
lock (this.SyncObject1)
{
this.Value = "1";
this.TestReentrancy();
}
var expected = "2";
Specification.Assert(this.Value == expected, "Value is {0} instead of {1}.", this.Value, expected);
});
}
private void TestReentrancy()
{
lock (this.SyncObject1)
{
this.Value = "2";
}
}
[Fact(Timeout = 5000)]
public void TestWaitPulse()
{
this.Test(async () =>
{
var t1 = Task.Run(this.TakeTask);
var t2 = Task.Run(this.PutTask);
await Task.WhenAll(t1, t2);
var expected = "taken";
Specification.Assert(this.Value == expected, "Value is {0} instead of {1}.", this.Value, expected);
});
}
private void TakeTask()
{
lock (this.SyncObject1)
{
if (this.Value != "put")
{
Monitor.Wait(this.SyncObject1);
}
this.Value = "taken";
}
}
private void PutTask()
{
lock (this.SyncObject1)
{
this.Value = "put";
Monitor.Pulse(this.SyncObject1);
}
}
}
}