Adds handling in ArgInfo to treat empty shape as scalar arrays (#59)

This commit is contained in:
Kern Handa 2022-08-03 11:28:17 -07:00 коммит произвёл GitHub
Родитель 4eb8fdc6c0
Коммит 004632c406
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 11 добавлений и 4 удалений

Просмотреть файл

@ -58,11 +58,18 @@ class ArgInfo:
dtype_entry = ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY]
self.numpy_dtype = self._get_type(dtype_entry)
self.element_num_bytes = 2 if dtype_entry == "bfloat16" else self.numpy_dtype.itemsize
self.element_strides = param_description.affine_map
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
major_dim = self.element_strides.index(max(self.element_strides))
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
if self.numpy_shape:
self.element_strides = param_description.affine_map
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
major_dim = self.element_strides.index(max(self.element_strides))
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
else:
self.element_strides = self.numpy_strides = self.numpy_shape = [1]
self.total_element_count = 1
self.total_byte_size = self.element_num_bytes * self.total_element_count