[RUNTIME][PYTHON] More compatibility in ndarray (#463)

This commit is contained in:
Tianqi Chen 2017-09-18 23:11:19 -07:00 коммит произвёл GitHub
Родитель 0220abbafa
Коммит 2607a83619
2 изменённых файлов: 13 добавлений и 1 удалений

Просмотреть файл

@ -165,6 +165,10 @@ class NDArrayBase(_NDArrayBase):
arr : NDArray
Reference to self.
"""
if isinstance(source_array, NDArrayBase):
source_array.copyto(self)
return self
if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=self.dtype)
@ -187,6 +191,14 @@ class NDArrayBase(_NDArrayBase):
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self
def __repr__(self):
res = "<tvm.NDArray shape={0}, {1}>\n".format(self.shape, self.context)
res += self.asnumpy().__repr__()
return res
def __str__(self):
return str(self.asnumpy())
def asnumpy(self):
"""Convert this array to numpy array

Просмотреть файл

@ -140,7 +140,7 @@ def array(arr, ctx=cpu(0)):
ret : NDArray
The created array
"""
if not isinstance(arr, _np.ndarray):
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)