gix gpu problem when copying ndarrayview
This commit is contained in:
Родитель
34e3ee090d
Коммит
6d564a5e50
|
@ -26,7 +26,7 @@ def _sanitize_value(shape, value, dtype, device, is_param=False):
|
|||
|
||||
#TODO: check whether this copy operation from cpu to gpu is not needed
|
||||
if device.type() != 0:
|
||||
ndav_cpu = utils.create_NDArrayView_from_NumPy(value)
|
||||
ndav_cpu = utils.create_NDArrayView_from_NumPy(value, dev=DeviceDescriptor.cpu_device())
|
||||
ndav = utils.create_NDArrayView(value.shape, data_type=cntk_dtype, dev=device)
|
||||
ndav.copy_from(ndav_cpu)
|
||||
else:
|
||||
|
|
|
@ -10,7 +10,7 @@ import os
|
|||
from cntk.ops import *
|
||||
|
||||
def linear_layer(input, output_dim):
|
||||
input_dim = input.shape()[0]
|
||||
input_dim = input.shape().dimensions()[0]
|
||||
times_param = parameter(shape=(input_dim, output_dim))
|
||||
bias_param = parameter(shape=(output_dim))
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче