Merged PR 2644: add a visitor pattern for code rewriting transforms.
add a visitor pattern for code rewriting transforms. Related work items: #4442, #4496, #4644
This commit is contained in:
Родитель
4505028803
Коммит
e9a8a22771
|
@ -24,6 +24,11 @@ namespace Microsoft.Coyote.Tasks
|
|||
/// </summary>
|
||||
protected readonly object SyncObject;
|
||||
|
||||
/// <summary>
|
||||
/// Whether the lock was taken.
|
||||
/// </summary>
|
||||
private bool LockTaken;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="SynchronizedBlock"/> class.
|
||||
/// </summary>
|
||||
|
@ -47,7 +52,7 @@ namespace Microsoft.Coyote.Tasks
|
|||
/// <returns>The synchronized block.</returns>
|
||||
protected virtual SynchronizedBlock EnterLock()
|
||||
{
|
||||
SystemMonitor.Enter(this.SyncObject);
|
||||
SystemMonitor.Enter(this.SyncObject, ref this.LockTaken);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -97,7 +102,7 @@ namespace Microsoft.Coyote.Tasks
|
|||
/// </summary>
|
||||
protected virtual void Dispose(bool disposing)
|
||||
{
|
||||
if (disposing)
|
||||
if (disposing && this.LockTaken)
|
||||
{
|
||||
SystemMonitor.Exit(this.SyncObject);
|
||||
}
|
||||
|
|
|
@ -8,10 +8,7 @@ using System.Linq;
|
|||
using Microsoft.Coyote.IO;
|
||||
using Mono.Cecil;
|
||||
using Mono.Cecil.Cil;
|
||||
using ControlledTasks = Microsoft.Coyote.SystematicTesting.Interception;
|
||||
using CoyoteTasks = Microsoft.Coyote.Tasks;
|
||||
using SystemCompiler = System.Runtime.CompilerServices;
|
||||
using SystemTasks = System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.Coyote.Rewriting
|
||||
{
|
||||
|
@ -35,21 +32,9 @@ namespace Microsoft.Coyote.Rewriting
|
|||
private readonly Configuration Configuration;
|
||||
|
||||
/// <summary>
|
||||
/// Cache from <see cref="SystemTasks"/> type names to types being replaced
|
||||
/// in the module that is currently being rewritten.
|
||||
/// List of transforms we are applying while rewriting.
|
||||
/// </summary>
|
||||
private readonly Dictionary<string, TypeReference> TaskTypeCache;
|
||||
|
||||
/// <summary>
|
||||
/// Cache from <see cref="SystemCompiler"/> type names to types being replaced
|
||||
/// in the module that is currently being rewritten.
|
||||
/// </summary>
|
||||
private readonly Dictionary<string, TypeReference> CompilerTypeCache;
|
||||
|
||||
/// <summary>
|
||||
/// Cached generic task type name prefix.
|
||||
/// </summary>
|
||||
private const string GenericTaskTypeNamePrefix = "Task`";
|
||||
private readonly List<AssemblyTransform> Transforms = new List<AssemblyTransform>();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="AssemblyRewriter"/> class.
|
||||
|
@ -58,8 +43,8 @@ namespace Microsoft.Coyote.Rewriting
|
|||
private AssemblyRewriter(Configuration configuration)
|
||||
{
|
||||
this.Configuration = configuration;
|
||||
this.TaskTypeCache = new Dictionary<string, TypeReference>();
|
||||
this.CompilerTypeCache = new Dictionary<string, TypeReference>();
|
||||
this.Transforms.Add(new TaskTransform());
|
||||
this.Transforms.Add(new LockTransform());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -167,8 +152,11 @@ namespace Microsoft.Coyote.Rewriting
|
|||
/// </summary>
|
||||
private void RewriteModule(ModuleDefinition module)
|
||||
{
|
||||
this.TaskTypeCache.Clear();
|
||||
this.CompilerTypeCache.Clear();
|
||||
foreach (var t in this.Transforms)
|
||||
{
|
||||
t.VisitModule(module);
|
||||
}
|
||||
|
||||
foreach (var type in module.GetTypes())
|
||||
{
|
||||
this.RewriteType(type);
|
||||
|
@ -181,13 +169,17 @@ namespace Microsoft.Coyote.Rewriting
|
|||
private void RewriteType(TypeDefinition type)
|
||||
{
|
||||
Debug.WriteLine($"....... Rewriting type '{type.FullName}'");
|
||||
|
||||
foreach (var t in this.Transforms)
|
||||
{
|
||||
t.VisitType(type);
|
||||
}
|
||||
|
||||
foreach (var field in type.Fields)
|
||||
{
|
||||
if (this.TryGetCompilerTypeReplacement(field.FieldType, null, out TypeReference newFieldType))
|
||||
foreach (var t in this.Transforms)
|
||||
{
|
||||
Debug.WriteLine($"......... [-] field '{field}'");
|
||||
field.FieldType = newFieldType;
|
||||
Debug.WriteLine($"......... [+] field '{field}'");
|
||||
t.VisitField(field);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -204,464 +196,49 @@ namespace Microsoft.Coyote.Rewriting
|
|||
{
|
||||
Debug.WriteLine($"......... Rewriting method '{method.FullName}'");
|
||||
|
||||
if (this.TryGetCompilerTypeReplacement(method.ReturnType, method, out TypeReference newReturnType))
|
||||
foreach (var t in this.Transforms)
|
||||
{
|
||||
Debug.WriteLine($"........... [-] return type '{method.ReturnType}'");
|
||||
method.ReturnType = newReturnType;
|
||||
Debug.WriteLine($"........... [+] return type '{method.ReturnType}'");
|
||||
t.VisitMethod(method);
|
||||
}
|
||||
|
||||
// Only non-abstract method bodies can be rewritten.
|
||||
if (!method.IsAbstract)
|
||||
{
|
||||
// TODO: Check if method.Body.ExceptionHandlers needs to be rewritten.
|
||||
|
||||
ILProcessor processor = method.Body.GetILProcessor();
|
||||
foreach (var variable in method.Body.Variables)
|
||||
{
|
||||
if (this.TryGetCompilerTypeReplacement(variable.VariableType, method, out TypeReference newVariableType))
|
||||
foreach (var t in this.Transforms)
|
||||
{
|
||||
Debug.WriteLine($"........... [-] variable '{variable.VariableType}'");
|
||||
variable.VariableType = newVariableType;
|
||||
Debug.WriteLine($"........... [+] variable '{variable.VariableType}'");
|
||||
t.VisitVariable(variable);
|
||||
}
|
||||
}
|
||||
|
||||
Instruction instruction = method.Body.Instructions.FirstOrDefault();
|
||||
while (instruction != null)
|
||||
// Do exception handlers before the method instructions because they are a
|
||||
// higher level concept and it's handy to pre-process them before seeing the
|
||||
// raw instructions.
|
||||
if (method.Body.HasExceptionHandlers)
|
||||
{
|
||||
if (instruction.OpCode == OpCodes.Stfld || instruction.OpCode == OpCodes.Ldfld || instruction.OpCode == OpCodes.Ldflda)
|
||||
foreach (var t in this.Transforms)
|
||||
{
|
||||
if (instruction.Operand is FieldDefinition fd &&
|
||||
this.TryGetCompilerTypeReplacement(fd.FieldType, method, out TypeReference newFieldType))
|
||||
foreach (var handler in method.Body.ExceptionHandlers)
|
||||
{
|
||||
Debug.WriteLine($"........... [-] {instruction}");
|
||||
fd.FieldType = newFieldType;
|
||||
Debug.WriteLine($"........... [+] {instruction}");
|
||||
}
|
||||
else if (instruction.Operand is FieldReference fr &&
|
||||
this.TryGetCompilerTypeReplacement(fr.FieldType, method, out newFieldType))
|
||||
{
|
||||
Debug.WriteLine($"........... [-] {instruction}");
|
||||
fr.FieldType = newFieldType;
|
||||
Debug.WriteLine($"........... [+] {instruction}");
|
||||
t.VisitExceptionHandler(handler);
|
||||
}
|
||||
}
|
||||
else
|
||||
}
|
||||
|
||||
// in this case run each transform as separate passes over the method body
|
||||
// so they don't trip over each other's edits.
|
||||
foreach (var t in this.Transforms)
|
||||
{
|
||||
Instruction instruction = method.Body.Instructions.FirstOrDefault();
|
||||
while (instruction != null)
|
||||
{
|
||||
ILProcessingOperation operation = this.GetILProcessingOperation(instruction, method);
|
||||
if (operation.Type is ILProcessingOperationType.Replace)
|
||||
{
|
||||
Debug.WriteLine($"........... [-] {instruction}");
|
||||
operation.Instructions[0].Offset = instruction.Offset;
|
||||
processor.Replace(instruction, operation.Instructions[0]);
|
||||
instruction = operation.Instructions[0];
|
||||
|
||||
Debug.WriteLine($"........... [+] {instruction}");
|
||||
for (int idx = 1; idx < operation.Instructions.Length; idx++)
|
||||
{
|
||||
Debug.WriteLine($"........... [+] {operation.Instructions[idx]}");
|
||||
processor.InsertAfter(instruction, operation.Instructions[idx]);
|
||||
instruction = instruction.Next;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
instruction = instruction.Next;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns an <see cref="ILProcessingOperation"/> for the specified instruction.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// If the returned operation has type <see cref="ILProcessingOperationType.None"/>, then there is nothing to rewrite.
|
||||
/// </remarks>
|
||||
private ILProcessingOperation GetILProcessingOperation(Instruction instruction, MethodDefinition method)
|
||||
{
|
||||
// TODO: check what we need to deal with `OpCodes.Calli`, and if we need to.
|
||||
ILProcessingOperation operation = ILProcessingOperation.None;
|
||||
if (instruction.OpCode == OpCodes.Initobj)
|
||||
{
|
||||
operation = this.GetInitobjProcessingOperation(instruction, method);
|
||||
}
|
||||
else if ((instruction.OpCode == OpCodes.Call || instruction.OpCode == OpCodes.Callvirt) &&
|
||||
instruction.Operand is MethodReference methodReference)
|
||||
{
|
||||
operation = this.GetCallProcessingOperation(instruction.OpCode, methodReference, method);
|
||||
}
|
||||
|
||||
return operation;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns an <see cref="ILProcessingOperation"/> for the specified <see cref="OpCodes.Initobj"/> instruction.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// If the returned operation has type <see cref="ILProcessingOperationType.None"/>, then there is nothing to rewrite.
|
||||
/// </remarks>
|
||||
private ILProcessingOperation GetInitobjProcessingOperation(Instruction instruction, MethodDefinition method)
|
||||
{
|
||||
ILProcessingOperation operation = ILProcessingOperation.None;
|
||||
TypeReference type = instruction.Operand as TypeReference;
|
||||
if (this.TryGetCompilerTypeReplacement(type, method, out TypeReference newType))
|
||||
{
|
||||
var newInstruction = Instruction.Create(instruction.OpCode, newType);
|
||||
operation = new ILProcessingOperation(ILProcessingOperationType.Replace, newInstruction);
|
||||
}
|
||||
|
||||
return operation;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns an <see cref="ILProcessingOperation"/> for the specified non-generic <see cref="OpCodes.Call"/>
|
||||
/// or <see cref="OpCodes.Callvirt"/> instruction.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// If the returned operation has type <see cref="ILProcessingOperationType.None"/>, then there is nothing to rewrite.
|
||||
/// </remarks>
|
||||
private ILProcessingOperation GetCallProcessingOperation(OpCode opCode, MethodReference method, IGenericParameterProvider provider)
|
||||
{
|
||||
TypeReference newType = null;
|
||||
if (IsSystemTaskType(method.DeclaringType))
|
||||
{
|
||||
// Special rules apply for methods under the Task namespace.
|
||||
if (method.Name == nameof(SystemTasks.Task.Run) ||
|
||||
method.Name == nameof(SystemTasks.Task.Delay) ||
|
||||
method.Name == nameof(SystemTasks.Task.WhenAll) ||
|
||||
method.Name == nameof(SystemTasks.Task.WhenAny) ||
|
||||
method.Name == nameof(SystemTasks.Task.Yield) ||
|
||||
method.Name == nameof(SystemTasks.Task.GetAwaiter))
|
||||
{
|
||||
newType = this.GetTaskTypeReplacement(method.DeclaringType);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
newType = this.GetCompilerTypeReplacement(method.DeclaringType, provider);
|
||||
if (newType == method.DeclaringType)
|
||||
{
|
||||
newType = null;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check if "new type is null" check is required.
|
||||
if (newType is null || !this.TryGetReplacementMethod(newType, method, out MethodReference newMethod))
|
||||
{
|
||||
// There is nothing to rewrite, return with the none operation.
|
||||
return ILProcessingOperation.None;
|
||||
}
|
||||
|
||||
OpCode newOpCode = opCode;
|
||||
if (newMethod.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter))
|
||||
{
|
||||
// The OpCode must change for the GetAwaiter method.
|
||||
newOpCode = OpCodes.Call;
|
||||
}
|
||||
|
||||
// Create and return the new instruction.
|
||||
Instruction newInstruction = Instruction.Create(newOpCode, newMethod);
|
||||
return new ILProcessingOperation(ILProcessingOperationType.Replace, newInstruction);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a method from the specified type that can replace the original method, if any.
|
||||
/// </summary>
|
||||
private bool TryGetReplacementMethod(TypeReference replacementType, MethodReference originalMethod, out MethodReference result)
|
||||
{
|
||||
result = null;
|
||||
|
||||
TypeDefinition replacementTypeDef = replacementType.Resolve();
|
||||
if (replacementTypeDef == null)
|
||||
{
|
||||
throw new Exception(string.Format("Error resolving type: {0}", replacementType.FullName));
|
||||
}
|
||||
|
||||
MethodDefinition original = originalMethod.Resolve();
|
||||
|
||||
bool isGetControlledAwaiter = false;
|
||||
foreach (var replacement in replacementTypeDef.Methods)
|
||||
{
|
||||
// TODO: make sure all necessery checks are in place!
|
||||
if (!(!replacement.IsConstructor &&
|
||||
replacement.Name == original.Name &&
|
||||
replacement.ReturnType.IsGenericInstance == original.ReturnType.IsGenericInstance &&
|
||||
replacement.IsPublic == original.IsPublic &&
|
||||
replacement.IsPrivate == original.IsPrivate &&
|
||||
replacement.IsAssembly == original.IsAssembly &&
|
||||
replacement.IsFamilyAndAssembly == original.IsFamilyAndAssembly))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
isGetControlledAwaiter = replacement.DeclaringType.Namespace == KnownNamespaces.ControlledTasksName &&
|
||||
replacement.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter);
|
||||
if (!isGetControlledAwaiter)
|
||||
{
|
||||
// Only check that the parameters match for non-controlled awaiter methods.
|
||||
// That is because we do special rewriting for this method.
|
||||
if (!CheckMethodParametersMatch(replacement, original))
|
||||
{
|
||||
continue;
|
||||
instruction = t.VisitInstruction(instruction);
|
||||
instruction = instruction.Next;
|
||||
}
|
||||
}
|
||||
|
||||
// Import the method in the module that is being rewritten.
|
||||
result = originalMethod.Module.ImportReference(replacement);
|
||||
break;
|
||||
}
|
||||
|
||||
if (result is null)
|
||||
{
|
||||
// TODO: raise an error.
|
||||
return false;
|
||||
}
|
||||
|
||||
result.DeclaringType = replacementType;
|
||||
if (originalMethod is GenericInstanceMethod genericMethod)
|
||||
{
|
||||
var newGenericMethod = new GenericInstanceMethod(result);
|
||||
|
||||
// The method is generic, so populate it with generic argument types and parameters.
|
||||
foreach (var arg in genericMethod.GenericArguments)
|
||||
{
|
||||
TypeReference newArgumentType = this.GetCompilerTypeReplacement(arg, newGenericMethod);
|
||||
newGenericMethod.GenericArguments.Add(newArgumentType);
|
||||
}
|
||||
|
||||
result = newGenericMethod;
|
||||
}
|
||||
else if (isGetControlledAwaiter && originalMethod.DeclaringType is GenericInstanceType genericType)
|
||||
{
|
||||
// Special processing applies in this case, because we are converting the `Task<T>.GetAwaiter`
|
||||
// non-generic instance method to the `ControlledTask.GetAwaiter<T>` generic static method.
|
||||
var newGenericMethod = new GenericInstanceMethod(result);
|
||||
|
||||
// There is only a single argument type, which must be added to ma.
|
||||
newGenericMethod.GenericArguments.Add(this.GetCompilerTypeReplacement(genericType.GenericArguments[0], newGenericMethod));
|
||||
|
||||
// The single generic argument type in the task parameter must be rewritten to the same
|
||||
// generic argument type as the one in the return type of `GetAwaiter<T>`.
|
||||
var parameterType = newGenericMethod.Parameters[0].ParameterType as GenericInstanceType;
|
||||
parameterType.GenericArguments[0] = (newGenericMethod.ReturnType as GenericInstanceType).GenericArguments[0];
|
||||
|
||||
result = newGenericMethod;
|
||||
}
|
||||
|
||||
// Rewrite the parameters of the method, if any.
|
||||
for (int idx = 0; idx < originalMethod.Parameters.Count; idx++)
|
||||
{
|
||||
ParameterDefinition parameter = originalMethod.Parameters[idx];
|
||||
TypeReference newParameterType = this.GetCompilerTypeReplacement(parameter.ParameterType, result);
|
||||
ParameterDefinition newParameter = new ParameterDefinition(parameter.Name, parameter.Attributes, newParameterType);
|
||||
result.Parameters[idx] = newParameter;
|
||||
}
|
||||
|
||||
if (result.ReturnType.Namespace != KnownNamespaces.ControlledTasksName)
|
||||
{
|
||||
result.ReturnType = this.GetCompilerTypeReplacement(originalMethod.ReturnType, result);
|
||||
}
|
||||
|
||||
return originalMethod.FullName != result.FullName;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if the the parameters of the two methods match.
|
||||
/// </summary>
|
||||
private static bool CheckMethodParametersMatch(MethodDefinition left, MethodDefinition right)
|
||||
{
|
||||
if (left.Parameters.Count != right.Parameters.Count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int idx = 0; idx < right.Parameters.Count; idx++)
|
||||
{
|
||||
var originalParam = right.Parameters[0];
|
||||
var replacementParam = left.Parameters[0];
|
||||
// TODO: make sure all necessery checks are in place!
|
||||
if ((replacementParam.ParameterType.FullName != originalParam.ParameterType.FullName) ||
|
||||
(replacementParam.Name != originalParam.Name) ||
|
||||
(replacementParam.IsIn && !originalParam.IsIn) ||
|
||||
(replacementParam.IsOut && !originalParam.IsOut))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the replacement type for the specified <see cref="SystemTasks"/> type, else null.
|
||||
/// </summary>
|
||||
private TypeReference GetTaskTypeReplacement(TypeReference type)
|
||||
{
|
||||
TypeReference result;
|
||||
|
||||
string fullName = type.FullName;
|
||||
if (this.TaskTypeCache.ContainsKey(fullName))
|
||||
{
|
||||
result = this.TaskTypeCache[fullName];
|
||||
if (result.Module != type.Module)
|
||||
{
|
||||
result = type.Module.ImportReference(result);
|
||||
this.TaskTypeCache[fullName] = result;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
result = type.Module.ImportReference(typeof(ControlledTasks.ControlledTask));
|
||||
this.TaskTypeCache[fullName] = result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to return the replacement type for the specified <see cref="SystemCompiler"/> type, if such a type exists.
|
||||
/// </summary>
|
||||
private bool TryGetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider, out TypeReference result)
|
||||
{
|
||||
result = this.GetCompilerTypeReplacement(type, provider);
|
||||
return result.FullName != type.FullName;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the replacement type for the specified <see cref="SystemCompiler"/> type, else null.
|
||||
/// </summary>
|
||||
private TypeReference GetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider)
|
||||
{
|
||||
TypeReference result = type;
|
||||
|
||||
string fullName = type.FullName;
|
||||
if (this.CompilerTypeCache.ContainsKey(fullName))
|
||||
{
|
||||
result = this.CompilerTypeCache[fullName];
|
||||
if (result.Module != type.Module)
|
||||
{
|
||||
result = type.Module.ImportReference(result);
|
||||
this.CompilerTypeCache[fullName] = result;
|
||||
}
|
||||
}
|
||||
else if (type.IsGenericInstance &&
|
||||
(type.Name == KnownSystemTypes.GenericAsyncTaskMethodBuilderName ||
|
||||
type.Name == KnownSystemTypes.GenericTaskAwaiterName))
|
||||
{
|
||||
result = this.GetGenericTypeReplacement(type as GenericInstanceType, provider);
|
||||
if (result.FullName != fullName)
|
||||
{
|
||||
result = type.Module.ImportReference(result);
|
||||
this.CompilerTypeCache[fullName] = result;
|
||||
}
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.AsyncTaskMethodBuilderFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.GenericAsyncTaskMethodBuilderFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder<>));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.TaskAwaiterFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.GenericTaskAwaiterFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter<>));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.YieldAwaitableFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.YieldAwaiterFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable.YieldAwaiter));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Import the replacement coyote type and cache it in the CompilerTypeCache and make sure the type can be
|
||||
/// fully resolved.
|
||||
/// </summary>
|
||||
private TypeReference ImportCompilerTypeReplacement(TypeReference originalType, Type coyoteType)
|
||||
{
|
||||
var result = originalType.Module.ImportReference(coyoteType);
|
||||
this.CompilerTypeCache[originalType.FullName] = result;
|
||||
|
||||
TypeDefinition coyoteTypeDef = result.Resolve();
|
||||
if (coyoteTypeDef == null)
|
||||
{
|
||||
throw new Exception(string.Format("Unexpected error resolving type: {0}", coyoteType.FullName));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the replacement type for the specified generic type, else null.
|
||||
/// </summary>
|
||||
private GenericInstanceType GetGenericTypeReplacement(GenericInstanceType type, IGenericParameterProvider provider)
|
||||
{
|
||||
GenericInstanceType result = type;
|
||||
TypeReference genericType = this.GetCompilerTypeReplacement(type.ElementType, null);
|
||||
if (type.ElementType.FullName != genericType.FullName)
|
||||
{
|
||||
// The generic type must be rewritten.
|
||||
result = new GenericInstanceType(type.Module.ImportReference(genericType));
|
||||
foreach (var arg in type.GenericArguments)
|
||||
{
|
||||
TypeReference newArgumentType;
|
||||
if (arg.IsGenericParameter)
|
||||
{
|
||||
GenericParameter parameter = new GenericParameter(arg.Name, provider ?? result);
|
||||
result.GenericParameters.Add(parameter);
|
||||
newArgumentType = parameter;
|
||||
}
|
||||
else
|
||||
{
|
||||
newArgumentType = this.GetCompilerTypeReplacement(arg, provider);
|
||||
}
|
||||
|
||||
result.GenericArguments.Add(newArgumentType);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if the specified type is the <see cref="SystemTasks.Task"/> type.
|
||||
/// </summary>
|
||||
private static bool IsSystemTaskType(TypeReference type) => type.Namespace == KnownNamespaces.SystemTasksName &&
|
||||
(type.Name == typeof(SystemTasks.Task).Name || type.Name.StartsWith(GenericTaskTypeNamePrefix));
|
||||
|
||||
/// <summary>
|
||||
/// Cache of known <see cref="SystemCompiler"/> type names.
|
||||
/// </summary>
|
||||
private static class KnownSystemTypes
|
||||
{
|
||||
internal static string AsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder).FullName;
|
||||
internal static string GenericAsyncTaskMethodBuilderName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).Name;
|
||||
internal static string GenericAsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).FullName;
|
||||
internal static string TaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter).FullName;
|
||||
internal static string GenericTaskAwaiterName { get; } = typeof(SystemCompiler.TaskAwaiter<>).Name;
|
||||
internal static string GenericTaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter<>).FullName;
|
||||
internal static string YieldAwaitableFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName;
|
||||
internal static string YieldAwaiterFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName + "/YieldAwaiter";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cache of known namespace names.
|
||||
/// </summary>
|
||||
private static class KnownNamespaces
|
||||
{
|
||||
internal static string ControlledTasksName { get; } = typeof(ControlledTasks.ControlledTask).Namespace;
|
||||
internal static string SystemTasksName { get; } = typeof(SystemTasks.Task).Namespace;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
using Mono.Cecil.Cil;
|
||||
|
||||
namespace Microsoft.Coyote.Rewriting
|
||||
{
|
||||
/// <summary>
|
||||
/// An IL processing operation to perform.
|
||||
/// </summary>
|
||||
internal struct ILProcessingOperation
|
||||
{
|
||||
/// <summary>
|
||||
/// A cached no-op IL processing operation.
|
||||
/// </summary>
|
||||
internal static ILProcessingOperation None { get; } = new ILProcessingOperation(ILProcessingOperationType.None);
|
||||
|
||||
/// <summary>
|
||||
/// A cached array containing a single instruction to optimize for memory in the common scenario.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// This is not thread safe, but we are not running rewriting in parallel.
|
||||
/// </remarks>
|
||||
private static Instruction[] SingleInstruction { get; } = new Instruction[1];
|
||||
|
||||
/// <summary>
|
||||
/// The type of IL processing to perform.
|
||||
/// </summary>
|
||||
internal readonly ILProcessingOperationType Type;
|
||||
|
||||
/// <summary>
|
||||
/// The new instructions to add.
|
||||
/// </summary>
|
||||
internal readonly Instruction[] Instructions;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ILProcessingOperation"/> struct.
|
||||
/// </summary>
|
||||
private ILProcessingOperation(ILProcessingOperationType type)
|
||||
{
|
||||
this.Type = type;
|
||||
this.Instructions = null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ILProcessingOperation"/> struct.
|
||||
/// </summary>
|
||||
internal ILProcessingOperation(ILProcessingOperationType type, Instruction instruction)
|
||||
{
|
||||
this.Type = type;
|
||||
SingleInstruction[0] = instruction;
|
||||
this.Instructions = SingleInstruction;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="ILProcessingOperation"/> struct.
|
||||
/// </summary>
|
||||
internal ILProcessingOperation(ILProcessingOperationType type, params Instruction[] instructions)
|
||||
{
|
||||
this.Type = type;
|
||||
this.Instructions = instructions;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace Microsoft.Coyote.Rewriting
|
||||
{
|
||||
/// <summary>
|
||||
/// The type of IL processing to perform.
|
||||
/// </summary>
|
||||
internal enum ILProcessingOperationType
|
||||
{
|
||||
/// <summary>
|
||||
/// Do not change the current instruction.
|
||||
/// </summary>
|
||||
None,
|
||||
|
||||
/// <summary>
|
||||
/// Insert the specified instructions after the current instruction.
|
||||
/// </summary>
|
||||
InsertAfter,
|
||||
|
||||
/// <summary>
|
||||
/// Insert the specified instructions before the current instruction.
|
||||
/// </summary>
|
||||
InsertBefore,
|
||||
|
||||
/// <summary>
|
||||
/// Replace the current instruction with the specified instructions.
|
||||
/// </summary>
|
||||
Replace
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
using Mono.Cecil;
|
||||
using Mono.Cecil.Cil;
|
||||
|
||||
namespace Microsoft.Coyote.Rewriting
|
||||
{
|
||||
/// <summary>
|
||||
/// An abstract interface for transforming code using a visitor pattern.
|
||||
/// This is used by the <see cref="AssemblyRewriter"/> to manage multiple different
|
||||
/// transforms in a single pass over an assembly.
|
||||
/// </summary>
|
||||
internal abstract class AssemblyTransform
|
||||
{
|
||||
/// <summary>
|
||||
/// Notify visitor we are starting a new Module.
|
||||
/// </summary>
|
||||
internal abstract void VisitModule(ModuleDefinition module);
|
||||
|
||||
/// <summary>
|
||||
/// Notify visitor we are visiting a new TypeDefinition.
|
||||
/// </summary>
|
||||
internal abstract void VisitType(TypeDefinition typeDef);
|
||||
|
||||
/// <summary>
|
||||
/// Notify visitor we are visiting a field inside the TypeDefinition just given to VisitType.
|
||||
/// </summary>
|
||||
internal abstract void VisitField(FieldDefinition field);
|
||||
|
||||
/// <summary>
|
||||
/// Notify visitor we are visiting a method inside the TypeDefinition just given to VisitType.
|
||||
/// </summary>
|
||||
internal abstract void VisitMethod(MethodDefinition method);
|
||||
|
||||
/// <summary>
|
||||
/// Notify visitor we are visiting a variable declaration inside the MethodDefinition just given to VisitMethod.
|
||||
/// </summary>
|
||||
internal abstract void VisitVariable(VariableDefinition variable);
|
||||
|
||||
/// <summary>
|
||||
/// Visit an IL instruction inside the MethodDefinition body, and get back a possibly transformed instruction.
|
||||
/// </summary>
|
||||
/// <param name="instruction">The instruction to visit.</param>
|
||||
/// <returns>Return the last modified instruction or the same one if it was not changed.</returns>
|
||||
internal abstract Instruction VisitInstruction(Instruction instruction);
|
||||
|
||||
/// <summary>
|
||||
/// Visit an <see cref="ExceptionHandler"/> inside the MethodDefinition. In the case of nested try/catch blocks
|
||||
/// the inner block is visited first before the outer block.
|
||||
/// </summary>
|
||||
internal abstract void VisitExceptionHandler(ExceptionHandler handler);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,484 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using Microsoft.Coyote.IO;
|
||||
using Microsoft.Coyote.Specifications;
|
||||
using Mono.Cecil;
|
||||
using Mono.Cecil.Cil;
|
||||
|
||||
namespace Microsoft.Coyote.Rewriting
|
||||
{
|
||||
internal class LockTransform : AssemblyTransform
|
||||
{
|
||||
private ModuleDefinition Module;
|
||||
private TypeDefinition TypeDef;
|
||||
private MethodDefinition Method;
|
||||
private ILProcessor Processor;
|
||||
private List<SyncObjectMapping> Mapping;
|
||||
private const string MonitorClassName = "System.Threading.Monitor";
|
||||
|
||||
/// <summary>
|
||||
/// Maintains a mapping from the "syncobject" used in Monitor::Enter to the
|
||||
/// actual instance of SynchronizedBlock that is now wrapping that sync object.
|
||||
/// </summary>
|
||||
private class SyncObjectMapping
|
||||
{
|
||||
internal FieldDefinition SyncObjectField;
|
||||
internal VariableDefinition SyncObjectVariable;
|
||||
internal VariableDefinition SyncBlockVariable;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal override void VisitModule(ModuleDefinition module)
|
||||
{
|
||||
this.Module = module;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal override void VisitType(TypeDefinition typeDef)
|
||||
{
|
||||
this.TypeDef = typeDef;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal override void VisitField(FieldDefinition field)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal override void VisitMethod(MethodDefinition method)
|
||||
{
|
||||
this.Method = method;
|
||||
this.Processor = method.Body.GetILProcessor();
|
||||
this.Mapping = new List<SyncObjectMapping>();
|
||||
}
|
||||
|
||||
internal override void VisitVariable(VariableDefinition variable)
|
||||
{
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal override Instruction VisitInstruction(Instruction instruction)
|
||||
{
|
||||
if (this.Method == null)
|
||||
{
|
||||
return instruction;
|
||||
}
|
||||
|
||||
if (instruction.OpCode == OpCodes.Call && instruction.Operand is MethodReference method)
|
||||
{
|
||||
const string PulseMethod = nameof(System.Threading.Monitor.Pulse);
|
||||
const string PulseAllMethod = nameof(System.Threading.Monitor.PulseAll);
|
||||
const string WaitMethod = nameof(System.Threading.Monitor.Wait);
|
||||
|
||||
if (method.DeclaringType.FullName == MonitorClassName && method.Name == PulseMethod)
|
||||
{
|
||||
Debug.WriteLine($"......... [-] call '{method}'");
|
||||
var pulseMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod(PulseMethod));
|
||||
this.ReplaceLoadSyncBlock(instruction);
|
||||
var newInstruction = Instruction.Create(OpCodes.Callvirt, pulseMethod);
|
||||
Debug.WriteLine($"......... [+] call '{pulseMethod}'");
|
||||
this.Processor.Replace(instruction, newInstruction);
|
||||
instruction = newInstruction;
|
||||
}
|
||||
else if (method.DeclaringType.FullName == MonitorClassName && method.Name == PulseAllMethod)
|
||||
{
|
||||
Debug.WriteLine($"......... [-] call '{method}'");
|
||||
var pulseAllMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod(PulseAllMethod));
|
||||
this.ReplaceLoadSyncBlock(instruction);
|
||||
var newInstruction = Instruction.Create(OpCodes.Callvirt, pulseAllMethod);
|
||||
Debug.WriteLine($"......... [+] call '{pulseAllMethod}'");
|
||||
this.Processor.Replace(instruction, newInstruction);
|
||||
instruction = newInstruction;
|
||||
}
|
||||
else if (method.DeclaringType.FullName == MonitorClassName && method.Name == "Wait")
|
||||
{
|
||||
// public static bool Wait(object obj);
|
||||
// public static bool Wait(object obj, int millisecondsTimeout);
|
||||
// public static bool Wait(object obj, int millisecondsTimeout, bool exitContext);
|
||||
// public static bool Wait(object obj, TimeSpan timeout);
|
||||
// public static bool Wait(object obj, TimeSpan timeout, bool exitContext);
|
||||
if (method.Parameters.Count == 1)
|
||||
{
|
||||
Debug.WriteLine($"......... [-] call '{method}'");
|
||||
var waitMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod(WaitMethod, Array.Empty<Type>()));
|
||||
this.ReplaceLoadSyncBlock(instruction);
|
||||
var newInstruction = Instruction.Create(OpCodes.Callvirt, waitMethod);
|
||||
Debug.WriteLine($"......... [+] call '{waitMethod}'");
|
||||
this.Processor.Replace(instruction, newInstruction);
|
||||
instruction = newInstruction;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return instruction;
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
internal override void VisitExceptionHandler(ExceptionHandler handler)
|
||||
{
|
||||
// a C# lock statement uses try/finally block, where Monitor.Enter is at the beginning of the Try
|
||||
// and Monitor.Exit is in the finally block. If the Monitor does not follow this pattern then it
|
||||
// is probably something else.
|
||||
|
||||
// The C# lock statement only has finally block, so there should be no CatchType.
|
||||
if (handler.CatchType == null && MatchLockEnter(handler.TryStart, handler.TryEnd))
|
||||
{
|
||||
if (MatchLockExit(handler.HandlerStart, handler.HandlerEnd))
|
||||
{
|
||||
this.RewriteLock(handler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Resolve the FieldDefinition referenced in the given load instruction if any.
|
||||
/// </summary>
|
||||
/// <returns>A FieldDefinition or null.</returns>
|
||||
private static FieldDefinition GetFieldDefinition(Instruction loadInstruction)
|
||||
{
|
||||
FieldDefinition fd = null;
|
||||
if (loadInstruction.Operand is FieldReference fr)
|
||||
{
|
||||
fd = fr.Resolve();
|
||||
}
|
||||
else if (loadInstruction.Operand is FieldDefinition fdef)
|
||||
{
|
||||
fd = fdef;
|
||||
}
|
||||
|
||||
return fd;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Rewrite "ldarg.0, ldfld syncobject" with "ldloc.n" to load the local SynchronizedBlock instead
|
||||
/// of the sync object, because the SynchronizedBlock methods (Pulse, Wait, etc) are instance methods
|
||||
/// not static methods and they do not take the syncobject as an argument because the SynchronizedBlock
|
||||
/// stores the sync object so it doesn't need it here.
|
||||
/// </summary>
|
||||
internal void ReplaceLoadSyncBlock(Instruction methodCall)
|
||||
{
|
||||
Instruction loadSyncObject = methodCall.Previous;
|
||||
if (loadSyncObject.OpCode == OpCodes.Ldfld)
|
||||
{
|
||||
var fd = GetFieldDefinition(loadSyncObject);
|
||||
if (fd != null)
|
||||
{
|
||||
var syncBlock = this.FindSyncBlockVar(fd);
|
||||
if (syncBlock != null)
|
||||
{
|
||||
Instruction loadThis = loadSyncObject.Previous;
|
||||
if (loadThis.OpCode == OpCodes.Ldarg_0)
|
||||
{
|
||||
this.Processor.Remove(loadThis);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Expecting load 'this' instruction here...");
|
||||
}
|
||||
|
||||
this.Processor.Replace(loadSyncObject, CreateLoadOp(syncBlock));
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: this code only works if the Pulse, PulseAll, or Wait method is called inside
|
||||
// the same method containing the C# lock statement. This is normally the case but if
|
||||
// someone decided to get clever and call a helper method and that helper method calls
|
||||
// Wait then I have a problem because the SynchronizedBlock local variable will not be
|
||||
// available in that new method...
|
||||
throw new NotImplementedException(string.Format(
|
||||
"Cannot find the matching SynchronizedBlock for the synchronizing object '{0}' in {1}",
|
||||
loadSyncObject, this.Method.FullName));
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (IsLoadOp(loadSyncObject.OpCode))
|
||||
{
|
||||
// TODO: can this happen? I haven't seen it yet...
|
||||
throw new NotImplementedException("Monitor method called using local variable instead of ldfld");
|
||||
}
|
||||
}
|
||||
|
||||
internal static bool MatchLockEnter(Instruction start, Instruction end)
|
||||
{
|
||||
if (CountInstructions(start, end) >= 3)
|
||||
{
|
||||
Instruction a = start;
|
||||
Instruction b = a.Next;
|
||||
Instruction c = b.Next;
|
||||
|
||||
if (IsLoadOp(a.OpCode) && // C# always creates local variable for the sync object
|
||||
b.OpCode == OpCodes.Ldloca_S && // the LockTaken boolean flag should be local
|
||||
c.OpCode == OpCodes.Call && c.Operand is MethodReference method && method.DeclaringType.FullName == MonitorClassName && method.Name == "Enter")
|
||||
{
|
||||
// looks like a C# lock then!
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
internal static bool MatchLockExit(Instruction start, Instruction end)
|
||||
{
|
||||
if (CountInstructions(start, end) >= 4)
|
||||
{
|
||||
Instruction a = start;
|
||||
Instruction b = a.Next;
|
||||
Instruction c = b.Next;
|
||||
Instruction d = c.Next;
|
||||
|
||||
if (IsLoadOp(a.OpCode) && // checking the LockTaken boolean flag
|
||||
b.OpCode == OpCodes.Brfalse_S && // no exit if we didn't get the lock!
|
||||
IsLoadOp(c.OpCode) && // loading the sync object
|
||||
d.OpCode == OpCodes.Call && d.Operand is MethodReference method && method.DeclaringType.FullName == MonitorClassName && method.Name == "Exit")
|
||||
{
|
||||
// looks like a C# lock release then!
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Replaces Monitor.Enter with SynchronizedBlock.Lock and Monitor.Exit with SynchronizedBlock.Dispose().
|
||||
/// </summary>
|
||||
private void RewriteLock(ExceptionHandler handler)
|
||||
{
|
||||
// Then we can replace the LockTaken boolean local variable with a new SynchronizedBlock object and all this
|
||||
// happens outside the try statement, so InsertBefore the start of the try block.
|
||||
Instruction a = handler.TryStart; // local variable for the sync object
|
||||
Instruction b = a.Next; // the LockTaken boolean flag should be local
|
||||
Instruction c = b.Next; // the Monitor.Enter call.
|
||||
|
||||
var syncObjectVar = this.GetLocalVariable(a.OpCode, a.Operand);
|
||||
var lockIndexVar = this.GetLocalVariable(b.OpCode, b.Operand);
|
||||
|
||||
// creates: using(var m = SynchronizedBlock.Lock(syncObject)) { ...
|
||||
|
||||
Debug.WriteLine($"......... [-] variable '{lockIndexVar.VariableType}'");
|
||||
Debug.WriteLine($"......... [-] call '{c}'");
|
||||
lockIndexVar.VariableType = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock)); // re-purpose this variable.
|
||||
this.InitializeNull(lockIndexVar);
|
||||
Debug.WriteLine($"......... [+] variable '{lockIndexVar.VariableType}'");
|
||||
|
||||
Instruction load_sync_object = CreateLoadOp(syncObjectVar);
|
||||
this.Processor.InsertAfter(a.Previous, load_sync_object);
|
||||
|
||||
var lockMethod = this.Module.ImportReference(typeof(Microsoft.Coyote.Tasks.SynchronizedBlock).GetMethod("Lock"));
|
||||
|
||||
Instruction create_sync_block = Instruction.Create(OpCodes.Call, lockMethod);
|
||||
|
||||
Debug.WriteLine($"......... [+] call '{create_sync_block}'");
|
||||
this.Processor.InsertAfter(load_sync_object, create_sync_block);
|
||||
|
||||
Instruction stfield = CreateStoreOp(lockIndexVar);
|
||||
this.Processor.InsertAfter(create_sync_block, stfield);
|
||||
|
||||
this.AddMapping(syncObjectVar, lockIndexVar);
|
||||
|
||||
// and now we can remove the first 3 instructions of the try block that were calling Monitor.Enter.
|
||||
|
||||
this.Processor.Remove(a);
|
||||
this.Processor.Remove(b);
|
||||
this.Processor.Remove(c);
|
||||
|
||||
// fix the finally block.
|
||||
Instruction d = handler.HandlerStart; // ldloc LockTaken becomes ldloc SynchronizedBlock object
|
||||
Instruction e = d.Next; // brfalse, this remains the same
|
||||
Instruction f = e.Next; // ldloc sync object, becomes ldloc syncblock variable.
|
||||
Instruction g = f.Next; // call Monitor.Exit.
|
||||
|
||||
// Create: m.Dispose()
|
||||
Instruction load_sync_block = CreateLoadOp(lockIndexVar); // remember lockIndexVar has been re-purposed to store the SynchronizedBlock object
|
||||
this.Processor.Replace(f, load_sync_block); // we need to load the object we are calling Dispose on. not the sync object
|
||||
var disposeMethod = this.Module.ImportReference(typeof(System.IDisposable).GetMethod("Dispose"));
|
||||
Instruction dispose = Instruction.Create(OpCodes.Callvirt, disposeMethod);
|
||||
|
||||
Debug.WriteLine($"......... [-] call '{g}'");
|
||||
Debug.WriteLine($"......... [+] call '{dispose}'");
|
||||
this.Processor.Replace(g, dispose);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Make sure a variable is initialized to null (when variable was a bool it was initialized to integer 0 instead).
|
||||
/// </summary>
|
||||
private void InitializeNull(VariableDefinition v)
|
||||
{
|
||||
Instruction instruction = this.Method.Body.Instructions.FirstOrDefault();
|
||||
while (instruction != null)
|
||||
{
|
||||
var op = instruction.OpCode;
|
||||
if (IsStoreOp(op))
|
||||
{
|
||||
VariableDefinition v2 = this.GetLocalVariable(op, instruction.Operand);
|
||||
if (v2.Index == v.Index)
|
||||
{
|
||||
// found the store operation for this local variable, so make sure it is initialized to null
|
||||
if (instruction.Previous != null && instruction.Previous.OpCode != OpCodes.Ldnull)
|
||||
{
|
||||
this.Processor.Replace(instruction.Previous, Instruction.Create(OpCodes.Ldnull));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
instruction = instruction.Next;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Remember the connection between syncObject and SynchronizedBlock and find which
|
||||
/// FieldDefinition stores the syncObject.
|
||||
/// </summary>
|
||||
private void AddMapping(VariableDefinition syncObjectVar, VariableDefinition syncBlockVar)
|
||||
{
|
||||
// find the ldarg, ldfld, stloc.x for the syncObject so we know what it's FieldDefinition is.
|
||||
var mapping = new SyncObjectMapping()
|
||||
{
|
||||
SyncObjectVariable = syncObjectVar,
|
||||
SyncBlockVariable = syncBlockVar
|
||||
};
|
||||
|
||||
Instruction instruction = this.Method.Body.Instructions.FirstOrDefault();
|
||||
while (instruction != null)
|
||||
{
|
||||
var op = instruction.OpCode;
|
||||
if (IsStoreOp(op))
|
||||
{
|
||||
VariableDefinition v = this.GetLocalVariable(op, instruction.Operand);
|
||||
if (v.Index == syncObjectVar.Index)
|
||||
{
|
||||
// found the store operation for this local variable!
|
||||
if (instruction.Previous != null && instruction.Previous.OpCode == OpCodes.Ldfld)
|
||||
{
|
||||
var fd = GetFieldDefinition(instruction.Previous);
|
||||
if (fd != null)
|
||||
{
|
||||
mapping.SyncObjectField = fd;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
instruction = instruction.Next;
|
||||
}
|
||||
|
||||
if (mapping.SyncObjectField == null)
|
||||
{
|
||||
// TODO: hmmm, perhaps the location of the sync object is more complicated...
|
||||
}
|
||||
|
||||
this.Mapping.Add(mapping);
|
||||
}
|
||||
|
||||
private VariableDefinition FindSyncBlockVar(FieldDefinition fd)
|
||||
{
|
||||
foreach (var item in this.Mapping)
|
||||
{
|
||||
if (item.SyncObjectField == fd)
|
||||
{
|
||||
return item.SyncBlockVariable;
|
||||
}
|
||||
}
|
||||
|
||||
Debug.WriteLine("### Cannot find SynchronizedBlock created for this Synchronizing Object");
|
||||
return null;
|
||||
}
|
||||
|
||||
internal static bool IsLoadOp(OpCode op)
|
||||
{
|
||||
return op == OpCodes.Ldloc_S || op == OpCodes.Ldloc_0 || op == OpCodes.Ldloc_1 || op == OpCodes.Ldloc_2 || op == OpCodes.Ldloc_3;
|
||||
}
|
||||
|
||||
internal static bool IsStoreOp(OpCode op)
|
||||
{
|
||||
return op == OpCodes.Stloc_S || op == OpCodes.Stloc_0 || op == OpCodes.Stloc_1 || op == OpCodes.Stloc_2 || op == OpCodes.Stloc_3;
|
||||
}
|
||||
|
||||
internal VariableDefinition GetLocalVariable(OpCode op, object operand)
|
||||
{
|
||||
if ((op == OpCodes.Ldloc || op == OpCodes.Ldloc_S || op == OpCodes.Ldloca_S || op == OpCodes.Stloc_S || op == OpCodes.Stloc) && operand is VariableDefinition vdef)
|
||||
{
|
||||
return vdef;
|
||||
}
|
||||
|
||||
if (op == OpCodes.Ldloc_0 || op == OpCodes.Stloc_0)
|
||||
{
|
||||
return this.Method.Body.Variables[0];
|
||||
}
|
||||
|
||||
if (op == OpCodes.Ldloc_1 || op == OpCodes.Stloc_1)
|
||||
{
|
||||
return this.Method.Body.Variables[1];
|
||||
}
|
||||
|
||||
if (op == OpCodes.Ldloc_2 || op == OpCodes.Stloc_2)
|
||||
{
|
||||
return this.Method.Body.Variables[2];
|
||||
}
|
||||
|
||||
if (op == OpCodes.Ldloc_3 || op == OpCodes.Stloc_3)
|
||||
{
|
||||
return this.Method.Body.Variables[3];
|
||||
}
|
||||
|
||||
throw new InvalidOperationException();
|
||||
}
|
||||
|
||||
private static Instruction CreateLoadOp(VariableDefinition var)
|
||||
{
|
||||
switch (var.Index)
|
||||
{
|
||||
case 0:
|
||||
return Instruction.Create(OpCodes.Ldloc_0);
|
||||
case 1:
|
||||
return Instruction.Create(OpCodes.Ldloc_1);
|
||||
case 2:
|
||||
return Instruction.Create(OpCodes.Ldloc_2);
|
||||
case 3:
|
||||
return Instruction.Create(OpCodes.Ldloc_3);
|
||||
default:
|
||||
return Instruction.Create(OpCodes.Ldloc, var);
|
||||
}
|
||||
}
|
||||
|
||||
internal static Instruction CreateStoreOp(VariableDefinition var)
|
||||
{
|
||||
switch (var.Index)
|
||||
{
|
||||
case 0:
|
||||
return Instruction.Create(OpCodes.Stloc_0);
|
||||
case 1:
|
||||
return Instruction.Create(OpCodes.Stloc_1);
|
||||
case 2:
|
||||
return Instruction.Create(OpCodes.Stloc_2);
|
||||
case 3:
|
||||
return Instruction.Create(OpCodes.Stloc_3);
|
||||
default:
|
||||
return Instruction.Create(OpCodes.Stloc_S, var);
|
||||
}
|
||||
}
|
||||
|
||||
internal static int CountInstructions(Instruction start, Instruction end)
|
||||
{
|
||||
int count = 0;
|
||||
while (start != end)
|
||||
{
|
||||
count++;
|
||||
start = start.Next;
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,541 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.Coyote.IO;
|
||||
using Mono.Cecil;
|
||||
using Mono.Cecil.Cil;
|
||||
using ControlledTasks = Microsoft.Coyote.SystematicTesting.Interception;
|
||||
using CoyoteTasks = Microsoft.Coyote.Tasks;
|
||||
using SystemCompiler = System.Runtime.CompilerServices;
|
||||
using SystemTasks = System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.Coyote.Rewriting
|
||||
{
|
||||
internal class TaskTransform : AssemblyTransform
|
||||
{
|
||||
/// <summary>
|
||||
/// Cache from <see cref="SystemTasks"/> type names to types being replaced
|
||||
/// in the module that is currently being rewritten.
|
||||
/// </summary>
|
||||
private readonly Dictionary<string, TypeReference> TaskTypeCache;
|
||||
|
||||
/// <summary>
|
||||
/// Cache from <see cref="SystemCompiler"/> type names to types being replaced
|
||||
/// in the module that is currently being rewritten.
|
||||
/// </summary>
|
||||
private readonly Dictionary<string, TypeReference> CompilerTypeCache;
|
||||
|
||||
/// <summary>
|
||||
/// The current module being transformed.
|
||||
/// </summary>
|
||||
private ModuleDefinition Module;
|
||||
|
||||
/// <summary>
|
||||
/// The current type being transformed.
|
||||
/// </summary>
|
||||
private TypeDefinition TypeDef;
|
||||
|
||||
/// <summary>
|
||||
/// The current method being transformed.
|
||||
/// </summary>
|
||||
private MethodDefinition Method;
|
||||
|
||||
/// <summary>
|
||||
/// A helper class for editing method body.
|
||||
/// </summary>
|
||||
private ILProcessor Processor;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TaskTransform"/> class.
|
||||
/// </summary>
|
||||
internal TaskTransform()
|
||||
{
|
||||
this.TaskTypeCache = new Dictionary<string, TypeReference>();
|
||||
this.CompilerTypeCache = new Dictionary<string, TypeReference>();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cached generic task type name prefix.
|
||||
/// </summary>
|
||||
private const string GenericTaskTypeNamePrefix = "Task`";
|
||||
|
||||
internal override void VisitModule(ModuleDefinition module)
|
||||
{
|
||||
this.Module = module;
|
||||
this.TaskTypeCache.Clear();
|
||||
this.CompilerTypeCache.Clear();
|
||||
}
|
||||
|
||||
internal override void VisitType(TypeDefinition typeDef)
|
||||
{
|
||||
this.TypeDef = typeDef;
|
||||
this.Method = null;
|
||||
this.Processor = null;
|
||||
}
|
||||
|
||||
internal override void VisitField(FieldDefinition field)
|
||||
{
|
||||
if (this.TryGetCompilerTypeReplacement(field.FieldType, this.TypeDef, out TypeReference newFieldType))
|
||||
{
|
||||
Debug.WriteLine($"......... [-] field '{field}'");
|
||||
field.FieldType = newFieldType;
|
||||
Debug.WriteLine($"......... [+] field '{field}'");
|
||||
}
|
||||
}
|
||||
|
||||
internal override void VisitMethod(MethodDefinition method)
|
||||
{
|
||||
this.Method = null;
|
||||
|
||||
// Only non-abstract method bodies can be rewritten.
|
||||
if (!method.IsAbstract)
|
||||
{
|
||||
this.Method = method;
|
||||
this.Processor = method.Body.GetILProcessor();
|
||||
}
|
||||
|
||||
// bugbug: what if this is an override of an inherited virtual method? For example, what if there
|
||||
// is an external base class that is a Task like type that implements a virtual GetAwaiter() that
|
||||
// is overridden by this method?
|
||||
if (this.TryGetCompilerTypeReplacement(method.ReturnType, method, out TypeReference newReturnType))
|
||||
{
|
||||
Debug.WriteLine($"........... [-] return type '{method.ReturnType}'");
|
||||
method.ReturnType = newReturnType;
|
||||
Debug.WriteLine($"........... [+] return type '{method.ReturnType}'");
|
||||
}
|
||||
}
|
||||
|
||||
internal override void VisitExceptionHandler(ExceptionHandler handler)
|
||||
{
|
||||
}
|
||||
|
||||
internal override void VisitVariable(VariableDefinition variable)
|
||||
{
|
||||
if (this.Method == null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.TryGetCompilerTypeReplacement(variable.VariableType, this.Method, out TypeReference newVariableType))
|
||||
{
|
||||
Debug.WriteLine($"........... [-] variable '{variable.VariableType}'");
|
||||
variable.VariableType = newVariableType;
|
||||
Debug.WriteLine($"........... [+] variable '{variable.VariableType}'");
|
||||
}
|
||||
}
|
||||
|
||||
internal override Instruction VisitInstruction(Instruction instruction)
|
||||
{
|
||||
if (this.Method == null)
|
||||
{
|
||||
return instruction;
|
||||
}
|
||||
|
||||
// TODO: what about ldsfld, for static fields?
|
||||
if (instruction.OpCode == OpCodes.Stfld || instruction.OpCode == OpCodes.Ldfld || instruction.OpCode == OpCodes.Ldflda)
|
||||
{
|
||||
if (instruction.Operand is FieldDefinition fd &&
|
||||
this.TryGetCompilerTypeReplacement(fd.FieldType, this.Method, out TypeReference newFieldType))
|
||||
{
|
||||
Debug.WriteLine($"........... [-] {instruction}");
|
||||
fd.FieldType = newFieldType;
|
||||
Debug.WriteLine($"........... [+] {instruction}");
|
||||
}
|
||||
else if (instruction.Operand is FieldReference fr &&
|
||||
this.TryGetCompilerTypeReplacement(fr.FieldType, this.Method, out newFieldType))
|
||||
{
|
||||
Debug.WriteLine($"........... [-] {instruction}");
|
||||
fr.FieldType = newFieldType;
|
||||
Debug.WriteLine($"........... [+] {instruction}");
|
||||
}
|
||||
}
|
||||
else if (instruction.OpCode == OpCodes.Initobj)
|
||||
{
|
||||
instruction = this.VisitInitobjProcessingOperation(instruction, this.Method);
|
||||
}
|
||||
else if ((instruction.OpCode == OpCodes.Call || instruction.OpCode == OpCodes.Callvirt) &&
|
||||
instruction.Operand is MethodReference methodReference)
|
||||
{
|
||||
instruction = this.VisitCallProcessingOperation(instruction, methodReference, this.Method);
|
||||
}
|
||||
|
||||
// return the last modified instruction or the same one if it was not changed.
|
||||
return instruction;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Transform the specified non-generic <see cref="OpCodes.Call"/> or <see cref="OpCodes.Callvirt"/> instruction.
|
||||
/// </summary>
|
||||
/// <returns>The unmodified instruction or the last newly inserted instruction.</returns>
|
||||
private Instruction VisitCallProcessingOperation(Instruction instruction, MethodReference method, IGenericParameterProvider provider)
|
||||
{
|
||||
TypeReference newType = null;
|
||||
var opCode = instruction.OpCode;
|
||||
|
||||
if (IsSystemTaskType(method.DeclaringType))
|
||||
{
|
||||
// Special rules apply for methods under the Task namespace.
|
||||
if (method.Name == nameof(SystemTasks.Task.Run) ||
|
||||
method.Name == nameof(SystemTasks.Task.Delay) ||
|
||||
method.Name == nameof(SystemTasks.Task.WhenAll) ||
|
||||
method.Name == nameof(SystemTasks.Task.WhenAny) ||
|
||||
method.Name == nameof(SystemTasks.Task.Yield) ||
|
||||
method.Name == nameof(SystemTasks.Task.GetAwaiter))
|
||||
{
|
||||
newType = this.GetTaskTypeReplacement(method.DeclaringType);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
newType = this.GetCompilerTypeReplacement(method.DeclaringType, provider);
|
||||
if (newType == method.DeclaringType)
|
||||
{
|
||||
newType = null;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check if "new type is null" check is required.
|
||||
if (newType is null || !this.TryGetReplacementMethod(newType, method, out MethodReference newMethod))
|
||||
{
|
||||
// There is nothing to rewrite, return with the none operation.
|
||||
return instruction;
|
||||
}
|
||||
|
||||
OpCode newOpCode = opCode;
|
||||
if (newMethod.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter))
|
||||
{
|
||||
// The OpCode must change for the GetAwaiter method.
|
||||
newOpCode = OpCodes.Call;
|
||||
}
|
||||
|
||||
// Create and return the new instruction.
|
||||
Instruction newInstruction = Instruction.Create(newOpCode, newMethod);
|
||||
Debug.WriteLine($"........... [-] {instruction}");
|
||||
this.Processor.Replace(instruction, newInstruction);
|
||||
Debug.WriteLine($"........... [+] {newInstruction}");
|
||||
return newInstruction;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Transform the specified <see cref="OpCodes.Initobj"/> instruction.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Return the unmodified instruction, or the newly replaced instruction.
|
||||
/// </remarks>
|
||||
private Instruction VisitInitobjProcessingOperation(Instruction instruction, MethodDefinition method)
|
||||
{
|
||||
TypeReference type = instruction.Operand as TypeReference;
|
||||
if (this.TryGetCompilerTypeReplacement(type, method, out TypeReference newType))
|
||||
{
|
||||
var newInstruction = Instruction.Create(instruction.OpCode, newType);
|
||||
Debug.WriteLine($"........... [-] {instruction}");
|
||||
this.Processor.Replace(instruction, newInstruction);
|
||||
Debug.WriteLine($"........... [+] {newInstruction}");
|
||||
instruction = newInstruction;
|
||||
}
|
||||
|
||||
return instruction;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a method from the specified type that can replace the original method, if any.
|
||||
/// </summary>
|
||||
private bool TryGetReplacementMethod(TypeReference replacementType, MethodReference originalMethod, out MethodReference result)
|
||||
{
|
||||
result = null;
|
||||
|
||||
TypeDefinition replacementTypeDef = replacementType.Resolve();
|
||||
if (replacementTypeDef == null)
|
||||
{
|
||||
throw new Exception(string.Format("Error resolving type: {0}", replacementType.FullName));
|
||||
}
|
||||
|
||||
MethodDefinition original = originalMethod.Resolve();
|
||||
|
||||
bool isGetControlledAwaiter = false;
|
||||
foreach (var replacement in replacementTypeDef.Methods)
|
||||
{
|
||||
// TODO: make sure all necessary checks are in place!
|
||||
if (!(!replacement.IsConstructor &&
|
||||
replacement.Name == original.Name &&
|
||||
replacement.ReturnType.IsGenericInstance == original.ReturnType.IsGenericInstance &&
|
||||
replacement.IsPublic == original.IsPublic &&
|
||||
replacement.IsPrivate == original.IsPrivate &&
|
||||
replacement.IsAssembly == original.IsAssembly &&
|
||||
replacement.IsFamilyAndAssembly == original.IsFamilyAndAssembly))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
isGetControlledAwaiter = replacement.DeclaringType.Namespace == KnownNamespaces.ControlledTasksName &&
|
||||
replacement.Name == nameof(ControlledTasks.ControlledTask.GetAwaiter);
|
||||
if (!isGetControlledAwaiter)
|
||||
{
|
||||
// Only check that the parameters match for non-controlled awaiter methods.
|
||||
// That is because we do special rewriting for this method.
|
||||
if (!CheckMethodParametersMatch(replacement, original))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Import the method in the module that is being rewritten.
|
||||
result = originalMethod.Module.ImportReference(replacement);
|
||||
break;
|
||||
}
|
||||
|
||||
if (result is null)
|
||||
{
|
||||
// TODO: raise an error.
|
||||
return false;
|
||||
}
|
||||
|
||||
result.DeclaringType = replacementType;
|
||||
if (originalMethod is GenericInstanceMethod genericMethod)
|
||||
{
|
||||
var newGenericMethod = new GenericInstanceMethod(result);
|
||||
|
||||
// The method is generic, so populate it with generic argument types and parameters.
|
||||
foreach (var arg in genericMethod.GenericArguments)
|
||||
{
|
||||
TypeReference newArgumentType = this.GetCompilerTypeReplacement(arg, newGenericMethod);
|
||||
newGenericMethod.GenericArguments.Add(newArgumentType);
|
||||
}
|
||||
|
||||
result = newGenericMethod;
|
||||
}
|
||||
else if (isGetControlledAwaiter && originalMethod.DeclaringType is GenericInstanceType genericType)
|
||||
{
|
||||
// Special processing applies in this case, because we are converting the `Task<T>.GetAwaiter`
|
||||
// non-generic instance method to the `ControlledTask.GetAwaiter<T>` generic static method.
|
||||
var newGenericMethod = new GenericInstanceMethod(result);
|
||||
|
||||
// There is only a single argument type, which must be added to ma.
|
||||
newGenericMethod.GenericArguments.Add(this.GetCompilerTypeReplacement(genericType.GenericArguments[0], newGenericMethod));
|
||||
|
||||
// The single generic argument type in the task parameter must be rewritten to the same
|
||||
// generic argument type as the one in the return type of `GetAwaiter<T>`.
|
||||
var parameterType = newGenericMethod.Parameters[0].ParameterType as GenericInstanceType;
|
||||
parameterType.GenericArguments[0] = (newGenericMethod.ReturnType as GenericInstanceType).GenericArguments[0];
|
||||
|
||||
result = newGenericMethod;
|
||||
}
|
||||
|
||||
// Rewrite the parameters of the method, if any.
|
||||
for (int idx = 0; idx < originalMethod.Parameters.Count; idx++)
|
||||
{
|
||||
ParameterDefinition parameter = originalMethod.Parameters[idx];
|
||||
TypeReference newParameterType = this.GetCompilerTypeReplacement(parameter.ParameterType, result);
|
||||
ParameterDefinition newParameter = new ParameterDefinition(parameter.Name, parameter.Attributes, newParameterType);
|
||||
result.Parameters[idx] = newParameter;
|
||||
}
|
||||
|
||||
if (result.ReturnType.Namespace != KnownNamespaces.ControlledTasksName)
|
||||
{
|
||||
result.ReturnType = this.GetCompilerTypeReplacement(originalMethod.ReturnType, result);
|
||||
}
|
||||
|
||||
return originalMethod.FullName != result.FullName;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if the parameters of the two methods match.
|
||||
/// </summary>
|
||||
private static bool CheckMethodParametersMatch(MethodDefinition left, MethodDefinition right)
|
||||
{
|
||||
if (left.Parameters.Count != right.Parameters.Count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int idx = 0; idx < right.Parameters.Count; idx++)
|
||||
{
|
||||
var originalParam = right.Parameters[0];
|
||||
var replacementParam = left.Parameters[0];
|
||||
// TODO: make sure all necessary checks are in place!
|
||||
if ((replacementParam.ParameterType.FullName != originalParam.ParameterType.FullName) ||
|
||||
(replacementParam.Name != originalParam.Name) ||
|
||||
(replacementParam.IsIn && !originalParam.IsIn) ||
|
||||
(replacementParam.IsOut && !originalParam.IsOut))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the replacement type for the specified <see cref="SystemTasks"/> type, else null.
|
||||
/// </summary>
|
||||
private TypeReference GetTaskTypeReplacement(TypeReference type)
|
||||
{
|
||||
TypeReference result;
|
||||
|
||||
string fullName = type.FullName;
|
||||
if (this.TaskTypeCache.ContainsKey(fullName))
|
||||
{
|
||||
result = this.TaskTypeCache[fullName];
|
||||
if (result.Module != type.Module)
|
||||
{
|
||||
result = type.Module.ImportReference(result);
|
||||
this.TaskTypeCache[fullName] = result;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
result = type.Module.ImportReference(typeof(ControlledTasks.ControlledTask));
|
||||
this.TaskTypeCache[fullName] = result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tries to return the replacement type for the specified <see cref="SystemCompiler"/> type, if such a type exists.
|
||||
/// </summary>
|
||||
private bool TryGetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider, out TypeReference result)
|
||||
{
|
||||
result = this.GetCompilerTypeReplacement(type, provider);
|
||||
return result.FullName != type.FullName;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the replacement type for the specified <see cref="SystemCompiler"/> type, else null.
|
||||
/// </summary>
|
||||
private TypeReference GetCompilerTypeReplacement(TypeReference type, IGenericParameterProvider provider)
|
||||
{
|
||||
TypeReference result = type;
|
||||
|
||||
string fullName = type.FullName;
|
||||
if (this.CompilerTypeCache.ContainsKey(fullName))
|
||||
{
|
||||
result = this.CompilerTypeCache[fullName];
|
||||
if (result.Module != type.Module)
|
||||
{
|
||||
result = type.Module.ImportReference(result);
|
||||
this.CompilerTypeCache[fullName] = result;
|
||||
}
|
||||
}
|
||||
else if (type.IsGenericInstance &&
|
||||
(type.Name == KnownSystemTypes.GenericAsyncTaskMethodBuilderName ||
|
||||
type.Name == KnownSystemTypes.GenericTaskAwaiterName))
|
||||
{
|
||||
result = this.GetGenericTypeReplacement(type as GenericInstanceType, provider);
|
||||
if (result.FullName != fullName)
|
||||
{
|
||||
result = type.Module.ImportReference(result);
|
||||
this.CompilerTypeCache[fullName] = result;
|
||||
}
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.AsyncTaskMethodBuilderFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.GenericAsyncTaskMethodBuilderFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(ControlledTasks.AsyncTaskMethodBuilder<>));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.TaskAwaiterFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.GenericTaskAwaiterFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.TaskAwaiter<>));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.YieldAwaitableFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable));
|
||||
}
|
||||
else if (fullName == KnownSystemTypes.YieldAwaiterFullName)
|
||||
{
|
||||
result = this.ImportCompilerTypeReplacement(type, typeof(CoyoteTasks.YieldAwaitable.YieldAwaiter));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Import the replacement coyote type and cache it in the CompilerTypeCache and make sure the type can be
|
||||
/// fully resolved.
|
||||
/// </summary>
|
||||
private TypeReference ImportCompilerTypeReplacement(TypeReference originalType, Type coyoteType)
|
||||
{
|
||||
var result = originalType.Module.ImportReference(coyoteType);
|
||||
this.CompilerTypeCache[originalType.FullName] = result;
|
||||
|
||||
TypeDefinition coyoteTypeDef = result.Resolve();
|
||||
if (coyoteTypeDef == null)
|
||||
{
|
||||
throw new Exception(string.Format("Unexpected error resolving type: {0}", coyoteType.FullName));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns the replacement type for the specified generic type, else null.
|
||||
/// </summary>
|
||||
private GenericInstanceType GetGenericTypeReplacement(GenericInstanceType type, IGenericParameterProvider provider)
|
||||
{
|
||||
GenericInstanceType result = type;
|
||||
TypeReference genericType = this.GetCompilerTypeReplacement(type.ElementType, null);
|
||||
if (type.ElementType.FullName != genericType.FullName)
|
||||
{
|
||||
// The generic type must be rewritten.
|
||||
result = new GenericInstanceType(type.Module.ImportReference(genericType));
|
||||
foreach (var arg in type.GenericArguments)
|
||||
{
|
||||
TypeReference newArgumentType;
|
||||
if (arg.IsGenericParameter)
|
||||
{
|
||||
GenericParameter parameter = new GenericParameter(arg.Name, provider ?? result);
|
||||
result.GenericParameters.Add(parameter);
|
||||
newArgumentType = parameter;
|
||||
}
|
||||
else
|
||||
{
|
||||
newArgumentType = this.GetCompilerTypeReplacement(arg, provider);
|
||||
}
|
||||
|
||||
result.GenericArguments.Add(newArgumentType);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if the specified type is the <see cref="SystemTasks.Task"/> type.
|
||||
/// </summary>
|
||||
private static bool IsSystemTaskType(TypeReference type) => type.Namespace == KnownNamespaces.SystemTasksName &&
|
||||
(type.Name == typeof(SystemTasks.Task).Name || type.Name.StartsWith(GenericTaskTypeNamePrefix));
|
||||
|
||||
/// <summary>
|
||||
/// Cache of known <see cref="SystemCompiler"/> type names.
|
||||
/// </summary>
|
||||
private static class KnownSystemTypes
|
||||
{
|
||||
internal static string AsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder).FullName;
|
||||
internal static string GenericAsyncTaskMethodBuilderName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).Name;
|
||||
internal static string GenericAsyncTaskMethodBuilderFullName { get; } = typeof(SystemCompiler.AsyncTaskMethodBuilder<>).FullName;
|
||||
internal static string TaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter).FullName;
|
||||
internal static string GenericTaskAwaiterName { get; } = typeof(SystemCompiler.TaskAwaiter<>).Name;
|
||||
internal static string GenericTaskAwaiterFullName { get; } = typeof(SystemCompiler.TaskAwaiter<>).FullName;
|
||||
internal static string YieldAwaitableFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName;
|
||||
internal static string YieldAwaiterFullName { get; } = typeof(SystemCompiler.YieldAwaitable).FullName + "/YieldAwaiter";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cache of known namespace names.
|
||||
/// </summary>
|
||||
private static class KnownNamespaces
|
||||
{
|
||||
internal static string ControlledTasksName { get; } = typeof(ControlledTasks.ControlledTask).Namespace;
|
||||
internal static string SystemTasksName { get; } = typeof(SystemTasks.Task).Namespace;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,4 +16,7 @@
|
|||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\Source\Core\Core.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Folder Include="Transforms\" />
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Coyote.Specifications;
|
||||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
using Monitor = System.Threading.Monitor;
|
||||
|
||||
namespace Microsoft.Coyote.BinaryRewriting.Tests.Tasks.Locks
|
||||
{
|
||||
public class CSharpLockTests : BaseProductionTest
|
||||
{
|
||||
private readonly object SyncObject1 = new object();
|
||||
private string Value;
|
||||
|
||||
public CSharpLockTests(ITestOutputHelper output)
|
||||
: base(output)
|
||||
{
|
||||
}
|
||||
|
||||
[Fact(Timeout = 5000)]
|
||||
public void TestSimpleLock()
|
||||
{
|
||||
this.Test(() =>
|
||||
{
|
||||
lock (this.SyncObject1)
|
||||
{
|
||||
this.Value = "1";
|
||||
this.TestReentrancy();
|
||||
}
|
||||
|
||||
var expected = "2";
|
||||
Specification.Assert(this.Value == expected, "Value is {0} instead of {1}.", this.Value, expected);
|
||||
});
|
||||
}
|
||||
|
||||
private void TestReentrancy()
|
||||
{
|
||||
lock (this.SyncObject1)
|
||||
{
|
||||
this.Value = "2";
|
||||
}
|
||||
}
|
||||
|
||||
[Fact(Timeout = 5000)]
|
||||
public void TestWaitPulse()
|
||||
{
|
||||
this.Test(async () =>
|
||||
{
|
||||
var t1 = Task.Run(this.TakeTask);
|
||||
var t2 = Task.Run(this.PutTask);
|
||||
|
||||
await Task.WhenAll(t1, t2);
|
||||
|
||||
var expected = "taken";
|
||||
Specification.Assert(this.Value == expected, "Value is {0} instead of {1}.", this.Value, expected);
|
||||
});
|
||||
}
|
||||
|
||||
private void TakeTask()
|
||||
{
|
||||
lock (this.SyncObject1)
|
||||
{
|
||||
if (this.Value != "put")
|
||||
{
|
||||
Monitor.Wait(this.SyncObject1);
|
||||
}
|
||||
|
||||
this.Value = "taken";
|
||||
}
|
||||
}
|
||||
|
||||
private void PutTask()
|
||||
{
|
||||
lock (this.SyncObject1)
|
||||
{
|
||||
this.Value = "put";
|
||||
Monitor.Pulse(this.SyncObject1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче