зеркало из https://github.com/microsoft/hat.git
Adds handling in ArgInfo to treat empty shape as scalar arrays (#59)
This commit is contained in:
Родитель
4eb8fdc6c0
Коммит
004632c406
|
@ -58,11 +58,18 @@ class ArgInfo:
|
|||
dtype_entry = ARG_TYPES[self.hat_declared_type][DTYPE_ENTRY]
|
||||
self.numpy_dtype = self._get_type(dtype_entry)
|
||||
self.element_num_bytes = 2 if dtype_entry == "bfloat16" else self.numpy_dtype.itemsize
|
||||
self.element_strides = param_description.affine_map
|
||||
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
|
||||
|
||||
major_dim = self.element_strides.index(max(self.element_strides))
|
||||
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
|
||||
if self.numpy_shape:
|
||||
self.element_strides = param_description.affine_map
|
||||
self.numpy_strides = tuple([self.element_num_bytes * x for x in self.element_strides])
|
||||
|
||||
major_dim = self.element_strides.index(max(self.element_strides))
|
||||
self.total_element_count = self.numpy_shape[major_dim] * self.element_strides[major_dim]
|
||||
|
||||
else:
|
||||
self.element_strides = self.numpy_strides = self.numpy_shape = [1]
|
||||
self.total_element_count = 1
|
||||
|
||||
self.total_byte_size = self.element_num_bytes * self.total_element_count
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче