Use NDShape.unknown instead of None

This commit is contained in:
Emad Barsoum 2016-11-10 09:48:16 -08:00
Родитель 74a7e0efd3
Коммит fdcc61a896
3 изменённых файлов: 6 добавлений и 12 удалений

Просмотреть файл

@ -387,13 +387,11 @@ fail:
%typecheck(1000) CNTK::NDShape const &, CNTK::NDShape {
// '1000' is the typecheck precedence code. It means: check after basic
// types, but before arrays. See: http://www.swig.org/Doc1.3/Typemaps.html#Typemaps_overloading
$1 = (($input == Py_None) || PyTuple_Check($input)) ? 1 : 0;
$1 = PyTuple_Check($input) ? 1 : 0;
}
%typemap(in) CNTK::NDShape const & {
if ($input == Py_None) {
$1 = new CNTK::NDShape(CNTK::NDShape::Unknown);
} else if (PyTuple_Check($input)) {
if (PyTuple_Check($input)) {
size_t rank = PyTuple_Size($input);
std::vector<size_t> dimensions(rank);
for (size_t i=0; i<rank; i++)

Просмотреть файл

@ -177,7 +177,7 @@ def Convolution(filter_shape, # e.g. (3,3)
# AveragePooling and GlobalAveragePooling
#
# Setting the filter_shape to None, mean global pooling.
from cntk.cntk_py import PoolingType_Max, PoolingType_Average
from cntk.cntk_py import PoolingType_Max, PoolingType_Average, NDShape
def Pooling(op, # PoolingType_Max or _Average
filter_shape, # e.g. (3,3)
strides=1,
@ -207,11 +207,11 @@ def AveragePooling(filter_shape, # e.g. (3,3)
# GlobalMaxPooling
def GlobalMaxPooling():
return Pooling(PoolingType_Max, None, pad=False)
return Pooling(PoolingType_Max, NDShape.unknown.dimensions(), pad=False)
# GlobalAveragePooling
def GlobalAveragePooling():
return Pooling(PoolingType_Average, None, pad=False)
return Pooling(PoolingType_Average, NDShape.unknown.dimensions(), pad=False)
# Recurrence() -- run a block recurrently over a time sequence
def Recurrence(over, go_backwards=False, initial_state=initial_state_default_or_None):

Просмотреть файл

@ -183,12 +183,8 @@ def get_temp_filename(directory=None):
def sanitize_shape(shape):
"""
If shape is scalar, it creates a tuple out of it. If the shape is None, then return it as it is, it will be mapped as NDShape::Unknown.
If shape is scalar, it creates a tuple out of it.
"""
# Unknown shape
if shape is None:
return shape
return _as_tuple(shape)