Adds support for verifying/benchmarking functions with layouts other than first major (#43)

* Adds support for function parameters with layouts other than first major
This commit is contained in:
Kern Handa 2022-04-14 22:15:08 -07:00 коммит произвёл GitHub
Родитель bce20498a3
Коммит 5367112387
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 34 добавлений и 39 удалений

Просмотреть файл

@ -1,4 +1,5 @@
import ctypes import ctypes
from functools import reduce
import numpy as np import numpy as np
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
@ -27,10 +28,12 @@ DTYPE_ENTRY = 1
class ArgInfo: class ArgInfo:
"""Extracts necessary information from the description of a function argument in a hat file""" """Extracts necessary information from the description of a function argument in a hat file"""
hat_declared_type: str hat_declared_type: str
numpy_shape: Tuple[int] numpy_shape: Tuple[int, ...]
numpy_strides: Tuple[int] numpy_strides: Tuple[int, ...]
numpy_dtype: type numpy_dtype: type
element_num_bytes: int element_num_bytes: int
element_strides: Tuple[int, ...]
total_byte_size: int
ctypes_pointer_type: Any ctypes_pointer_type: Any
usage: hat_file.UsageType = None usage: hat_file.UsageType = None
@ -45,7 +48,13 @@ class ArgInfo:
self.ctypes_pointer_type = ctypes.POINTER(ARG_TYPES[self.hat_declared_type][CTYPE_ENTRY]) self.ctypes_pointer_type = ctypes.POINTER(ARG_TYPES[self.hat_declared_type][CTYPE_ENTRY])
self.numpy_dtype = np.dtype(ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY]) self.numpy_dtype = np.dtype(ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY])
self.element_num_bytes = self.numpy_dtype.itemsize self.element_num_bytes = self.numpy_dtype.itemsize
self.numpy_strides = tuple([self.element_num_bytes * x for x in param_description.affine_map]) self.element_strides = param_description.affine_map
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
def product(l):
return reduce(lambda x1, x2: x1*x2, l)
self.total_byte_size = self.element_num_bytes * self.numpy_shape[0] * product(self.element_strides)
# TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name # TODO: Update this to take a HATFunction instead, instead of arg_infos and function_name

Просмотреть файл

@ -84,17 +84,13 @@ def get_func_from_ptx(ptx, func_name):
return kernel return kernel
def _arg_size(arg_info: ArgInfo):
return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
def _cuda_transfer_mem(usage, func, source_args: List, dest_args: List, arg_infos: List[ArgInfo], stream=None): def _cuda_transfer_mem(usage, func, source_args: List, dest_args: List, arg_infos: List[ArgInfo], stream=None):
for source_arg, dest_arg, arg_info in zip(source_args, dest_args, arg_infos): for source_arg, dest_arg, arg_info in zip(source_args, dest_args, arg_infos):
if usage in arg_info.usage.value: if usage in arg_info.usage.value:
if stream: if stream:
err, = func(dest_arg, source_arg, _arg_size(arg_info), stream) err, = func(dest_arg, source_arg, arg_info.total_byte_size, stream)
else: else:
err, = func(dest_arg, source_arg, _arg_size(arg_info)) err, = func(dest_arg, source_arg, arg_info.total_byte_size)
ASSERT_DRV(err) ASSERT_DRV(err)
@ -124,7 +120,7 @@ def allocate_cuda_mem(arg_infos: List[ArgInfo], stream=None):
device_mem = [] device_mem = []
for arg in arg_infos: for arg in arg_infos:
size = _arg_size(arg) size = arg.total_byte_size
err, mem = cuda.cuMemAllocAsync(size, stream) if stream else cuda.cuMemAlloc(size) err, mem = cuda.cuMemAllocAsync(size, stream) if stream else cuda.cuMemAlloc(size)
ASSERT_DRV(err) ASSERT_DRV(err)
device_mem.append(mem) device_mem.append(mem)

Просмотреть файл

@ -33,23 +33,22 @@ from . import hat_package
from .arg_info import ArgInfo from .arg_info import ArgInfo
def generate_input_sets_for_func(func: hat_file.Function, def generate_input_sets_for_func(func: hat_file.Function, input_sets_minimum_size_MB: int = 0, num_additional: int = 0):
input_sets_minimum_size_MB: int = 0,
num_additional: int = 0):
parameters = list(map(ArgInfo, func.arguments)) parameters = list(map(ArgInfo, func.arguments))
shapes_to_sizes = [ shapes_to_sizes = [reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters]
reduce(lambda x, y: x * y, p.numpy_shape) for p in parameters set_size = reduce(lambda x, y: x + y, map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, parameters))
]
set_size = reduce( num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 // set_size) + 1 + num_additional
lambda x, y: x + y,
map(lambda size, p: size * p.element_num_bytes, shapes_to_sizes, def product(l):
parameters)) return reduce(lambda x1, x2: x1 * x2, l)
num_input_sets = (input_sets_minimum_size_MB * 1024 * 1024 //
set_size) + 1 + num_additional
input_sets = [[ input_sets = [[
np.random.random(p.numpy_shape).astype(p.numpy_dtype) np.lib.stride_tricks.as_strided(
for p in parameters np.random.rand(p.numpy_shape[0] * product(p.element_strides)).astype(p.numpy_dtype),
shape=p.numpy_shape,
strides=p.numpy_strides
) for p in parameters
] for _ in range(num_input_sets)] ] for _ in range(num_input_sets)]
return input_sets[0] if len(input_sets) == 1 else input_sets return input_sets[0] if len(input_sets) == 1 else input_sets
@ -57,16 +56,11 @@ def generate_input_sets_for_func(func: hat_file.Function,
def generate_input_sets_for_hat_file(hat_path): def generate_input_sets_for_hat_file(hat_path):
t = hat_file.HATFile.Deserialize(hat_path) t = hat_file.HATFile.Deserialize(hat_path)
return { return {func_name: generate_input_sets_for_func(func_desc)
func_name: generate_input_sets_for_func(func_desc) for func_name, func_desc in t.function_map.items()}
for func_name, func_desc in t.function_map.items()
}
def load( def load(hat_path, try_dynamic_load=True) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
hat_path,
try_dynamic_load=True
) -> Tuple[hat_package.HATPackage, Union[hat_package.AttributeDict, None]]:
""" """
Returns a HATPackage object loaded from the path provided. If Returns a HATPackage object loaded from the path provided. If
`try_dynamic_load` is True, a non-empty dictionary object that can be used `try_dynamic_load` is True, a non-empty dictionary object that can be used

Просмотреть файл

@ -12,10 +12,6 @@ from .pyhip.hip import *
from .pyhip.hiprtc import * from .pyhip.hiprtc import *
def _arg_size(arg_info: ArgInfo):
return arg_info.element_num_bytes * reduce(lambda x, y: x * y, arg_info.numpy_shape)
def initialize_rocm(): def initialize_rocm():
# Initialize ROCM Driver API # Initialize ROCM Driver API
hipInit(0) hipInit(0)
@ -47,7 +43,7 @@ def get_func_from_rocm_program(rocm_program, func_name):
def allocate_rocm_mem(arg_infos: List[ArgInfo]): def allocate_rocm_mem(arg_infos: List[ArgInfo]):
device_mem = [] device_mem = []
for arg in arg_infos: for arg in arg_infos:
mem = hipMalloc(_arg_size(arg)) mem = hipMalloc(arg.total_byte_size)
device_mem.append(mem) device_mem.append(mem)
return device_mem return device_mem
@ -61,13 +57,13 @@ def free_rocm_mem(args):
def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]): def transfer_mem_host_to_rocm(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos): for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'input' in arg_info.usage.value: if 'input' in arg_info.usage.value:
hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=_arg_size(arg_info)) hipMemcpy_htod(dst=device_arg, src=host_arg.ctypes.data, count=arg_info.total_byte_size)
def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]): def transfer_mem_rocm_to_host(device_args: List, host_args: List[np.array], arg_infos: List[ArgInfo]):
for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos): for device_arg, host_arg, arg_info in zip(device_args, host_args, arg_infos):
if 'output' in arg_info.usage.value: if 'output' in arg_info.usage.value:
hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=_arg_size(arg_info)) hipMemcpy_dtoh(dst=host_arg.ctypes.data, src=device_arg, count=arg_info.total_byte_size)
def device_args_to_ptr_list(device_args: List): def device_args_to_ptr_list(device_args: List):