зеркало из https://github.com/microsoft/hat.git
HAT support for dynamic shared memory allocation for GPU
This commit is contained in:
Родитель
12e9f2f7f8
Коммит
ddbcd98896
|
@ -185,7 +185,6 @@ class CudaCallableFunc(CallableFunc):
|
|||
self.hat_func = func
|
||||
self.func_info = FunctionInfo(func)
|
||||
self.kernel = None
|
||||
self.launch_params = func.launch_parameters
|
||||
self.device_mem = None
|
||||
self.ptrs = None
|
||||
self.start_event = None
|
||||
|
@ -223,8 +222,8 @@ class CudaCallableFunc(CallableFunc):
|
|||
for _ in range(warmup_iters):
|
||||
err, = cuda.cuLaunchKernel(
|
||||
self.kernel,
|
||||
*self.launch_params, # [ grid[x-z], block[x-z] ]
|
||||
0, # dynamic shared memory
|
||||
*self.hat_func.launch_parameters, # [ grid[x-z], block[x-z] ]
|
||||
self.hat_func.dynamic_shared_mem_bytes,
|
||||
0, # stream
|
||||
self.ptrs.ctypes.data, # kernel arguments
|
||||
0, # extra (ignore)
|
||||
|
@ -242,8 +241,8 @@ class CudaCallableFunc(CallableFunc):
|
|||
for _ in range(iters):
|
||||
err, = cuda.cuLaunchKernel(
|
||||
self.kernel,
|
||||
*self.launch_params, # [ grid[x-z], block[x-z] ]
|
||||
0, # dynamic shared memory
|
||||
*self.hat_func.launch_parameters, # [ grid[x-z], block[x-z] ]
|
||||
self.hat_func.dynamic_shared_mem_bytes,
|
||||
0, # stream
|
||||
self.ptrs.ctypes.data, # kernel arguments
|
||||
0, # extra (ignore)
|
||||
|
|
|
@ -211,6 +211,7 @@ class Function(AuxiliarySupportedTable):
|
|||
|
||||
# optional
|
||||
launch_parameters: list = field(default_factory=list)
|
||||
dynamic_shared_mem_bytes: int = 0
|
||||
launches: str = ""
|
||||
provider: str = ""
|
||||
runtime: str = ""
|
||||
|
@ -231,6 +232,9 @@ class Function(AuxiliarySupportedTable):
|
|||
if self.launch_parameters:
|
||||
table.add("launch_parameters", self.launch_parameters)
|
||||
|
||||
if self.dynamic_shared_mem_bytes:
|
||||
table.add("dynamic_shared_mem_bytes", self.dynamic_shared_mem_bytes)
|
||||
|
||||
if self.launches:
|
||||
table.add("launches", self.launches)
|
||||
|
||||
|
@ -254,6 +258,8 @@ class Function(AuxiliarySupportedTable):
|
|||
|
||||
launch_parameters = function_table["launch_parameters"] if "launch_parameters" in function_table else []
|
||||
|
||||
dynamic_shared_mem_bytes = function_table["dynamic_shared_mem_bytes"] if "dynamic_shared_mem_bytes" in function_table else 0
|
||||
|
||||
launches = function_table["launches"] if "launches" in function_table else ""
|
||||
|
||||
provider = function_table["provider"] if "provider" in function_table else ""
|
||||
|
@ -269,6 +275,7 @@ class Function(AuxiliarySupportedTable):
|
|||
arguments=arguments,
|
||||
return_info=return_info,
|
||||
launch_parameters=launch_parameters,
|
||||
dynamic_shared_mem_bytes=dynamic_shared_mem_bytes,
|
||||
launches=launches,
|
||||
provider=provider,
|
||||
runtime=runtime,
|
||||
|
|
|
@ -95,7 +95,6 @@ class RocmCallableFunc(CallableFunc):
|
|||
self.hat_func = func
|
||||
self.func_info = FunctionInfo(func)
|
||||
self.kernel = None
|
||||
self.launch_params = func.launch_parameters
|
||||
self.device_mem = None
|
||||
self.ptrs = None
|
||||
self.stream = None
|
||||
|
@ -142,8 +141,8 @@ class RocmCallableFunc(CallableFunc):
|
|||
for _ in range(warmup_iters):
|
||||
hipModuleLaunchKernel(
|
||||
self.kernel,
|
||||
*self.launch_params, # [ grid[x-z], block[x-z] ]
|
||||
0, # dynamic shared memory
|
||||
*self.hat_func.launch_parameters, # [ grid[x-z], block[x-z] ]
|
||||
self.hat_func.dynamic_shared_mem_bytes,
|
||||
0, # stream
|
||||
self.data, # data
|
||||
)
|
||||
|
@ -156,8 +155,8 @@ class RocmCallableFunc(CallableFunc):
|
|||
for _ in range(iters):
|
||||
hipModuleLaunchKernel(
|
||||
self.kernel,
|
||||
*self.launch_params, # [ grid[x-z], block[x-z] ]
|
||||
0, # dynamic shared memory
|
||||
*self.hat_func.launch_parameters, # [ grid[x-z], block[x-z] ]
|
||||
self.hat_func.dynamic_shared_mem_bytes, # dynamic shared memory
|
||||
0, # stream
|
||||
self.data, # data
|
||||
)
|
||||
|
|
|
@ -101,6 +101,11 @@ version = "0.0.0.3"
|
|||
type = "array"
|
||||
optional = true
|
||||
|
||||
# The dynamic shared memory size in bytes to be allocated for this device function
|
||||
[types.functionType.dynamic_shared_mem_bytes]
|
||||
type = "integer"
|
||||
optional = true
|
||||
|
||||
# The function that is launched by this function
|
||||
[types.functionType.launches]
|
||||
type = "string"
|
||||
|
|
Загрузка…
Ссылка в новой задаче