Mirror of https://github.com/microsoft/hat.git

Adds support for verifying/benchmarking functions with layouts other than first major (#43)

* Adds support for function parameters with layouts other than first major

Parent: bce20498a3
Commit: 5367112387
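For orientation, a minimal sketch of how a 2D parameter's layout shows up as an affine_map and how ArgInfo turns that into NumPy byte strides. The shape and affine_map values below are hypothetical, invented only for this illustration; in HAT terms, "first major" corresponds to the usual row-major layout of a 2D array.

import numpy as np

# Hypothetical 2D parameter of shape (M, N) = (16, 32) with float32 elements.
M, N = 16, 32
element_num_bytes = np.dtype(np.float32).itemsize    # 4 bytes per element

first_major_affine_map = [N, 1]    # row-major: offset = i * N + j
last_major_affine_map = [1, M]     # column-major: offset = i + j * M

def to_numpy_strides(affine_map):
    # Same conversion ArgInfo performs: element strides -> byte strides.
    return tuple(element_num_bytes * x for x in affine_map)

print(to_numpy_strides(first_major_affine_map))   # (128, 4)
print(to_numpy_strides(last_major_affine_map))    # (4, 64)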
@@ -1,4 +1,5 @@
 import ctypes
+from functools import reduce
 import numpy as np
 import sys
 from dataclasses import dataclass
@@ -27,10 +28,12 @@ DTYPE_ENTRY = 1
 class ArgInfo:
     """Extracts necessary information from the description of a function argument in a hat file"""
     hat_declared_type: str
-    numpy_shape: Tuple[int]
-    numpy_strides: Tuple[int]
+    numpy_shape: Tuple[int, ...]
+    numpy_strides: Tuple[int, ...]
     numpy_dtype: type
     element_num_bytes: int
+    element_strides: Tuple[int, ...]
+    total_byte_size: int
     ctypes_pointer_type: Any
     usage: hat_file.UsageType = None

@@ -45,7 +48,13 @@ class ArgInfo:
         self.ctypes_pointer_type = ctypes.POINTER(ARG_TYPES[self.hat_declared_type][CTYPE_ENTRY])
         self.numpy_dtype = np.dtype(ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY])
         self.element_num_bytes = self.numpy_dtype.itemsize
-        self.numpy_strides = tuple([self.element_num_bytes * x for x in param_description.affine_map])
+        self.element_strides = param_description.affine_map
+        self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
+
+        def product(l):
+            return reduce(lambda x1, x2: x1*x2, l)
+
+        self.total_byte_size = self.element_num_bytes * self.numpy_shape[0] * product(self.element_strides)


 # TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name
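To make the new total_byte_size computation concrete, here is a small worked example assuming a hypothetical float32 parameter of shape (16, 16) whose rows are padded out to 24 elements (affine_map [24, 1]); the numbers are invented for illustration and do not come from this commit.

import numpy as np
from functools import reduce

numpy_shape = (16, 16)
affine_map = [24, 1]                                  # hypothetical row-major layout, padded rows
element_num_bytes = np.dtype(np.float32).itemsize     # 4

element_strides = affine_map
numpy_strides = tuple(element_num_bytes * x for x in element_strides)    # (96, 4)

def product(l):
    return reduce(lambda x1, x2: x1 * x2, l)

# New size: covers the whole padded buffer the strides can reach.
total_byte_size = element_num_bytes * numpy_shape[0] * product(element_strides)
print(total_byte_size)                                # 4 * 16 * 24 = 1536 bytes

# A shape-product size would undercount a padded layout.
print(element_num_bytes * product(numpy_shape))       # 4 * 16 * 16 = 1024 bytes

This is why the loader changes below no longer size allocations and copies from the shape product alone.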
@@ -84,17 +84,13 @@ def get_func_from_ptx(ptx, func_name):
     return kernel


-def _arg_size(arg_info: ArgInfo):
-    return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
-
-
 def _cuda_transfer_mem(usage, func, source_args: List, dest_args: List, arg_infos: List[ArgInfo], stream=None):
     for source_arg, dest_arg, arg_info in zip(source_args, dest_args, arg_infos):
         if usage in arg_info.usage.value:
             if stream:
-                err, = func(dest_arg, source_arg, _arg_size(arg_info), stream)
+                err, = func(dest_arg, source_arg, arg_info.total_byte_size, stream)
             else:
-                err, = func(dest_arg, source_arg, _arg_size(arg_info))
+                err, = func(dest_arg, source_arg, arg_info.total_byte_size)
             ASSERT_DRV(err)


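The CUDA helpers above now take every byte count from ArgInfo.total_byte_size instead of the removed _arg_size helper, so transfers cover the full strided extent of each buffer. As a rough, driver-free sketch of the same control flow (FakeArgInfo, fake_memcpy, and transfer are invented stand-ins, not part of the library):

from dataclasses import dataclass
from typing import List

@dataclass
class FakeArgInfo:
    usage_value: str        # e.g. "input", "output", or "input_output"
    total_byte_size: int

def fake_memcpy(dest: bytearray, src: bytes, num_bytes: int):
    # Stand-in for a driver memcpy; returns an (err,) tuple like the bindings do.
    dest[:num_bytes] = src[:num_bytes]
    return (0,)

def transfer(usage: str, source_args: List, dest_args: List, arg_infos: List[FakeArgInfo]):
    # Copy only arguments whose usage includes the requested direction,
    # sizing each copy by the argument's full strided byte extent.
    for src, dst, info in zip(source_args, dest_args, arg_infos):
        if usage in info.usage_value:
            err, = fake_memcpy(dst, src, info.total_byte_size)
            assert err == 0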
@@ -124,7 +120,7 @@ def allocate_cuda_mem(arg_infos: List[ArgInfo], stream=None):
     device_mem = []

     for arg in arg_infos:
-        size = _arg_size(arg)
+        size = arg.total_byte_size
         err, mem = cuda.cuMemAllocAsync(size, stream) if stream else cuda.cuMemAlloc(size)
         ASSERT_DRV(err)
         device_mem.append(mem)
@@ -33,23 +33,22 @@ from . import hat_package
 from .arg_info import ArgInfo


-def generate_input_sets_for_func(func: hat_file.Function,
-                                 input_sets_minimum_size_MB: int = 0,
-                                 num_additional: int = 0):
+def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
     parameters = list(map(ArgInfo, func.arguments))
-    shapes_to_sizes = [
-        reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters
-    ]
-    set_size = reduce(
-        lambda x, y: x + y,
-        map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes,
-            parameters))
+    shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters]
+    set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters))

-    num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 //
-                      set_size) + 1 + num_additional
+    num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
+
+    def product(l):
+        return reduce(lambda x1, x2: x1 * x2, l)
+
     input_sets = [[
-        np.random.random(p.numpy_shape).astype(p.numpy_dtype)
-        for p in parameters
+        np.lib.stride_tricks.as_strided(
+            np.random.rand(p.numpy_shape[0] * product(p.element_strides)).astype(p.numpy_dtype),
+            shape=p.numpy_shape,
+            strides=p.numpy_strides
+        ) for p in parameters
     ] for _ in range(num_input_sets)]

     return input_sets[0] if len(input_sets) == 1 else input_sets
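As a self-contained sketch of the strided input generation above, assuming an invented 4x4 float32 parameter with a column-major affine_map of [1, 4] (not taken from any actual HAT file):

import numpy as np
from functools import reduce

shape = (4, 4)
element_strides = [1, 4]                              # hypothetical column-major layout
itemsize = np.dtype(np.float32).itemsize              # 4 bytes
numpy_strides = tuple(itemsize * s for s in element_strides)    # (4, 16)

def product(l):
    return reduce(lambda x1, x2: x1 * x2, l)

# Allocate a flat random buffer large enough for the view, then reinterpret it.
flat = np.random.rand(shape[0] * product(element_strides)).astype(np.float32)
strided = np.lib.stride_tricks.as_strided(flat, shape=shape, strides=numpy_strides)

print(strided.flags['F_CONTIGUOUS'])    # True: the generated input is column-major

Unlike np.random.random(shape), this keeps the underlying buffer laid out exactly as the function under test expects it.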
@@ -57,16 +56,11 @@ def generate_input_sets_for_func(func: hat_file.Function,

 def generate_input_sets_for_hat_file(hat_path):
     t = hat_file.HATFile.Deserialize(hat_path)
-    return {
-        func_name: generate_input_sets_for_func(func_desc)
-        for func_name, func_desc in t.function_map.items()
-    }
+    return {func_name: generate_input_sets_for_func(func_desc)
+            for func_name, func_desc in t.function_map.items()}


-def load(
-    hat_path,
-    try_dynamic_load=True
-) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
+def load(hat_path, try_dynamic_load=True) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
     """
     Returns a HATPackage object loaded from the path provided. If
     `try_dynamic_load` is True, a non-empty dictionary object that can be used
@@ -12,10 +12,6 @@ from .pyhip.hip import *
 from .pyhip.hiprtc import *


-def _arg_size(arg_info: ArgInfo):
-    return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
-
-
 def initialize_rocm():
     # Initialize ROCM Driver API
     hipInit(0)
@@ -47,7 +43,7 @@ def get_func_from_rocm_program(rocm_program, func_name):
 def allocate_rocm_mem(arg_infos: List[ArgInfo]):
     device_mem = []
     for arg in arg_infos:
-        mem = hipMalloc(_arg_size(arg))
+        mem = hipMalloc(arg.total_byte_size)
         device_mem.append(mem)

     return device_mem
@@ -61,13 +57,13 @@ def free_rocm_mem(args):
 def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
     for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
         if 'input' in arg_info.usage.value:
-            hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=_arg_size(arg_info))
+            hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=arg_info.total_byte_size)


 def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
     for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
         if 'output' in arg_info.usage.value:
-            hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=_arg_size(arg_info))
+            hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=arg_info.total_byte_size)


 def device_args_to_ptr_list(device_args: List):