зеркало из https://github.com/microsoft/hat.git
Adds support for verifying/benchmarking functions with layouts other than first major (#43)
* Adds support for function parameters with layouts other than first major
This commit is contained in:
Родитель
bce20498a3
Коммит
5367112387
|
@ -1,4 +1,5 @@
|
||||||
import ctypes
|
import ctypes
|
||||||
|
from functools import reduce
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -27,10 +28,12 @@ DTYPE_ENTRY = 1
|
||||||
class ArgInfo:
|
class ArgInfo:
|
||||||
"""Extracts necessary information from the description of a function argument in a hat file"""
|
"""Extracts necessary information from the description of a function argument in a hat file"""
|
||||||
hat_declared_type: str
|
hat_declared_type: str
|
||||||
numpy_shape: Tuple[int]
|
numpy_shape: Tuple[int, ...]
|
||||||
numpy_strides: Tuple[int]
|
numpy_strides: Tuple[int, ...]
|
||||||
numpy_dtype: type
|
numpy_dtype: type
|
||||||
element_num_bytes: int
|
element_num_bytes: int
|
||||||
|
element_strides: Tuple[int, ...]
|
||||||
|
total_byte_size: int
|
||||||
ctypes_pointer_type: Any
|
ctypes_pointer_type: Any
|
||||||
usage: hat_file.UsageType = None
|
usage: hat_file.UsageType = None
|
||||||
|
|
||||||
|
@ -45,7 +48,13 @@ class ArgInfo:
|
||||||
self.ctypes_pointer_type = ctypes.POINTER(ARG_TYPES[self.hat_declared_type][CTYPE_ENTRY])
|
self.ctypes_pointer_type = ctypes.POINTER(ARG_TYPES[self.hat_declared_type][CTYPE_ENTRY])
|
||||||
self.numpy_dtype = np.dtype(ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY])
|
self.numpy_dtype = np.dtype(ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY])
|
||||||
self.element_num_bytes = self.numpy_dtype.itemsize
|
self.element_num_bytes = self.numpy_dtype.itemsize
|
||||||
self.numpy_strides = tuple([self.element_num_bytes * x for x in param_description.affine_map])
|
self.element_strides = param_description.affine_map
|
||||||
|
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
|
||||||
|
|
||||||
|
def product(l):
|
||||||
|
return reduce(lambda x1, x2: x1*x2, l)
|
||||||
|
|
||||||
|
self.total_byte_size = self.element_num_bytes * self.numpy_shape[0] * product(self.element_strides)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Update this to take a HATFunction instead of arg_infos and function_name
|
# TODO: Update this to take a HATFunction instead of arg_infos and function_name
|
||||||
|
|
|
@ -84,17 +84,13 @@ def get_func_from_ptx(ptx, func_name):
|
||||||
return kernel
|
return kernel
|
||||||
|
|
||||||
|
|
||||||
def _arg_size(arg_info: ArgInfo):
|
|
||||||
return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def _cuda_transfer_mem(usage, func, source_args: List, dest_args: List, arg_infos: List[ArgInfo], stream=None):
|
def _cuda_transfer_mem(usage, func, source_args: List, dest_args: List, arg_infos: List[ArgInfo], stream=None):
|
||||||
for source_arg, dest_arg, arg_info in zip(source_args, dest_args, arg_infos):
|
for source_arg, dest_arg, arg_info in zip(source_args, dest_args, arg_infos):
|
||||||
if usage in arg_info.usage.value:
|
if usage in arg_info.usage.value:
|
||||||
if stream:
|
if stream:
|
||||||
err, = func(dest_arg, source_arg, _arg_size(arg_info), stream)
|
err, = func(dest_arg, source_arg, arg_info.total_byte_size, stream)
|
||||||
else:
|
else:
|
||||||
err, = func(dest_arg, source_arg, _arg_size(arg_info))
|
err, = func(dest_arg, source_arg, arg_info.total_byte_size)
|
||||||
ASSERT_DRV(err)
|
ASSERT_DRV(err)
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,7 +120,7 @@ def allocate_cuda_mem(arg_infos: List[ArgInfo], stream=None):
|
||||||
device_mem = []
|
device_mem = []
|
||||||
|
|
||||||
for arg in arg_infos:
|
for arg in arg_infos:
|
||||||
size = _arg_size(arg)
|
size = arg.total_byte_size
|
||||||
err, mem = cuda.cuMemAllocAsync(size, stream) if stream else cuda.cuMemAlloc(size)
|
err, mem = cuda.cuMemAllocAsync(size, stream) if stream else cuda.cuMemAlloc(size)
|
||||||
ASSERT_DRV(err)
|
ASSERT_DRV(err)
|
||||||
device_mem.append(mem)
|
device_mem.append(mem)
|
||||||
|
|
|
@ -33,23 +33,22 @@ from . import hat_package
|
||||||
from .arg_info import ArgInfo
|
from .arg_info import ArgInfo
|
||||||
|
|
||||||
|
|
||||||
def generate_input_sets_for_func(func: hat_file.Function,
|
def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
|
||||||
input_sets_minimum_size_MB: int = 0,
|
|
||||||
num_additional: int = 0):
|
|
||||||
parameters = list(map(ArgInfo, func.arguments))
|
parameters = list(map(ArgInfo, func.arguments))
|
||||||
shapes_to_sizes = [
|
shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters]
|
||||||
reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters
|
set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters))
|
||||||
]
|
|
||||||
set_size = reduce(
|
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
|
||||||
lambda x, y: x + y,
|
|
||||||
map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes,
|
def product(l):
|
||||||
parameters))
|
return reduce(lambda x1, x2: x1 * x2, l)
|
||||||
|
|
||||||
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 //
|
|
||||||
set_size) + 1 + num_additional
|
|
||||||
input_sets = [[
|
input_sets = [[
|
||||||
np.random.random(p.numpy_shape).astype(p.numpy_dtype)
|
np.lib.stride_tricks.as_strided(
|
||||||
for p in parameters
|
np.random.rand(p.numpy_shape[0] * product(p.element_strides)).astype(p.numpy_dtype),
|
||||||
|
shape=p.numpy_shape,
|
||||||
|
strides=p.numpy_strides
|
||||||
|
) for p in parameters
|
||||||
] for _ in range(num_input_sets)]
|
] for _ in range(num_input_sets)]
|
||||||
|
|
||||||
return input_sets[0] if len(input_sets) == 1 else input_sets
|
return input_sets[0] if len(input_sets) == 1 else input_sets
|
||||||
|
@ -57,16 +56,11 @@ def generate_input_sets_for_func(func: hat_file.Function,
|
||||||
|
|
||||||
def generate_input_sets_for_hat_file(hat_path):
|
def generate_input_sets_for_hat_file(hat_path):
|
||||||
t = hat_file.HATFile.Deserialize(hat_path)
|
t = hat_file.HATFile.Deserialize(hat_path)
|
||||||
return {
|
return {func_name: generate_input_sets_for_func(func_desc)
|
||||||
func_name: generate_input_sets_for_func(func_desc)
|
for func_name, func_desc in t.function_map.items()}
|
||||||
for func_name, func_desc in t.function_map.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load(
|
def load(hat_path, try_dynamic_load=True) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
|
||||||
hat_path,
|
|
||||||
try_dynamic_load=True
|
|
||||||
) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
|
|
||||||
"""
|
"""
|
||||||
Returns a HATPackage object loaded from the path provided. If
|
Returns a HATPackage object loaded from the path provided. If
|
||||||
`try_dynamic_load` is True, a non-empty dictionary object that can be used
|
`try_dynamic_load` is True, a non-empty dictionary object that can be used
|
||||||
|
|
|
@ -12,10 +12,6 @@ from .pyhip.hip import *
|
||||||
from .pyhip.hiprtc import *
|
from .pyhip.hiprtc import *
|
||||||
|
|
||||||
|
|
||||||
def _arg_size(arg_info: ArgInfo):
|
|
||||||
return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_rocm():
|
def initialize_rocm():
|
||||||
# Initialize ROCM Driver API
|
# Initialize ROCM Driver API
|
||||||
hipInit(0)
|
hipInit(0)
|
||||||
|
@ -47,7 +43,7 @@ def get_func_from_rocm_program(rocm_program, func_name):
|
||||||
def allocate_rocm_mem(arg_infos: List[ArgInfo]):
|
def allocate_rocm_mem(arg_infos: List[ArgInfo]):
|
||||||
device_mem = []
|
device_mem = []
|
||||||
for arg in arg_infos:
|
for arg in arg_infos:
|
||||||
mem = hipMalloc(_arg_size(arg))
|
mem = hipMalloc(arg.total_byte_size)
|
||||||
device_mem.append(mem)
|
device_mem.append(mem)
|
||||||
|
|
||||||
return device_mem
|
return device_mem
|
||||||
|
@ -61,13 +57,13 @@ def free_rocm_mem(args):
|
||||||
def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
|
def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
|
||||||
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
|
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
|
||||||
if 'input' in arg_info.usage.value:
|
if 'input' in arg_info.usage.value:
|
||||||
hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=_arg_size(arg_info))
|
hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=arg_info.total_byte_size)
|
||||||
|
|
||||||
|
|
||||||
def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
|
def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
|
||||||
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
|
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
|
||||||
if 'output' in arg_info.usage.value:
|
if 'output' in arg_info.usage.value:
|
||||||
hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=_arg_size(arg_info))
|
hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=arg_info.total_byte_size)
|
||||||
|
|
||||||
|
|
||||||
def device_args_to_ptr_list(device_args: List):
|
def device_args_to_ptr_list(device_args: List):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче