bindings/python/doc: fix most Sphinx warnings and some misc. errors

This commit is contained in:
Mark Hillebrand 2017-03-14 21:59:39 +01:00
Parent fbc6ba086b
Commit 09c71fe83a
24 changed files with 129 additions and 124 deletions

.gitignore (vendored)
View file

@@ -218,6 +218,10 @@ bindings/python/cntk/libs/
bindings/python/cntk/cntk_py_wrap.cpp
bindings/python/cntk/cntk_py_wrap.h
bindings/python/dist/
+bindings/python/doc/cntk.*.rst
+bindings/python/doc/cntk.rst
+bindings/python/doc/modules.rst
+bindings/python/doc/_build
# Auto-generated sources from CNTK.proto
Source/CNTKv2LibraryDll/proto/CNTK.pb.cc

View file

@@ -190,10 +190,11 @@ class Value(cntk_py.Value):
dtype: data type (np.float32 or np.float64)
batch: batch input for `var`.
It can be:
* a pure Python structure (list of lists, ...),
* a list of NumPy arrays or SciPy sparse CSR matrices
* a :class:`~cntk.core.Value` object (e.g. returned by :func:`one_hot`)
-seq_starts (list of `bool`s or None): if None, every sequence is
+seq_starts (list of `bool`\ s or None): if None, every sequence is
treated as a new sequence. Otherwise, it is interpreted as a list of
Booleans that tell whether a sequence is a new sequence (`True`) or a
continuation of the sequence in the same slot of the previous
@@ -277,12 +278,13 @@ class Value(cntk_py.Value):
Args:
var (:class:`~cntk.ops.variables.Variable`): variable into which
``data`` is passed
-data: data for `var`
+data: data for `var`.
It can be:
* a single NumPy array denoting the full minibatch
* a list of NumPy arrays or SciPy sparse CSR matrices
* a single NumPy array denoting one parameter or constant
-seq_starts (list of `bool`s or None): if None, every sequence is
+seq_starts (list of `bool`\ s or None): if None, every sequence is
treated as a new sequence. Otherwise, it is interpreted as a list of
Booleans that tell whether a sequence is a new sequence (`True`) or a
continuation of the sequence in the same slot of the previous
@@ -446,6 +448,7 @@ class Value(cntk_py.Value):
'''
The mask matrix of this value. Each row denotes a sequence with its
elements describing the mask of the element:
* 2: beginning of sequence (e.g. an LSTM would be reset)
* 1: valid element
* 0: invalid element
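
For orientation, a minimal sketch of how ``seq_starts`` from the docstrings above is used; the feature shape and the ``input_variable`` import path are assumptions for illustration, not part of the diff::

    import numpy as np
    from cntk.core import Value
    from cntk.ops import input_variable

    x = input_variable(3)
    # a batch of two sequences, of lengths 2 and 1
    batch = [np.ones((2, 3), dtype=np.float32),
             np.zeros((1, 3), dtype=np.float32)]
    # first slot starts a new sequence; the second continues the
    # sequence held in the same slot of the previous minibatch
    val = Value.create(x, batch, seq_starts=[True, False])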

View file

@@ -63,9 +63,12 @@ def get_default_override(function, **kwargs):
Meant to be used inside functions that use this facility.
Args:
-function: the function that calls this, e.g.:
-``def Convolution(args, init=default_override_or(glorot_uniform()), activation=default_override_or(identity), pad=default_override_or(False)):
-init = _get_default_override(Convolution, init=init) # pass default under the same name``
+function: the function that calls this.
+For example::
+    def Convolution(args, init=default_override_or(glorot_uniform()), activation=default_override_or(identity), pad=default_override_or(False)):
+        init = _get_default_override(Convolution, init=init) # pass default under the same name
'''
# parameter checking and casting
if len(kwargs) != 1:

View file

@@ -231,6 +231,7 @@ def sanitize_var_map(op_arguments, arguments, precision=None,
data.
* any other type: if node has a unique input, arguments is
mapped to this input.
For nodes with more than one input, only dict is allowed.
In both cases, every sample in the data will be interpreted

View file

@@ -95,12 +95,14 @@ class MinibatchSource(cntk_py.MinibatchSource):
`randomization_window`. If `True`, the size of the randomization window is interpreted as a certain
number of samples, otherwise -- as a number of chunks. Similarly to `randomization_window`,
this parameter is ignored, when `randomize` is `False`
-epoch_size (`int`, defaults to cntk.io.INFINITELY_REPEAT): number of samples as a scheduling unit.
+epoch_size (`int`, defaults to :const:`~cntk.io.INFINITELY_REPEAT`): number of samples as a scheduling unit.
Parameters in the schedule change their values every `epoch_size`
samples. If no `epoch_size` is provided, this parameter is substituted
by the size of the full data sweep with infinite repeat, in which case the scheduling unit is
the entire data sweep (as indicated by the MinibatchSource) and parameters
-change their values on the sweep-by-sweep basis specified by the schedule. **Important:** `click here <https://github.com/Microsoft/CNTK/wiki/BrainScript-epochSize-and-Python-epoch_size-in-CNTK>`_ for a full description of this parameter.
+change their values on the sweep-by-sweep basis specified by the schedule.
+**Important:**
+Click `here <https://github.com/Microsoft/CNTK/wiki/BrainScript-epochSize-and-Python-epoch_size-in-CNTK>`_ for a full description of this parameter.
distributed_after (int, defaults to cntk.io.INFINITE_SAMPLES): sample count after which minibatch source becomes distributed
multithreaded_deserializer (`bool`, defaults to `None`): using multi threaded deserializer
frame_mode (`bool`, defaults to `False`): Specifies if data should be randomized and returned at the frame
@@ -181,7 +183,9 @@ class MinibatchSource(cntk_py.MinibatchSource):
Args:
minibatch_size_in_samples (int): number of samples to retrieve for
-the next minibatch. Must be > 0. **Important:** `click here <https://github.com/Microsoft/CNTK/wiki/BrainScript-minibatchSize-and-Python-minibatch_size_in_samples-in-CNTK>`_ for a full description of this parameter.
+the next minibatch. Must be > 0.
+**Important:**
+Click `here <https://github.com/Microsoft/CNTK/wiki/BrainScript-minibatchSize-and-Python-minibatch_size_in_samples-in-CNTK>`_ for a full description of this parameter.
input_map (dict): mapping of :class:`~cntk.ops.variables.Variable`
to :class:`~cntk.cntk_py.StreamInformation` which will be used to convert the
returned data.
@@ -417,6 +421,7 @@ def ImageDeserializer(filename, streams):
labels from a file of the form::
<full path to image> <tab> <numerical label (0-based class id)>
or::
sequenceId <tab> path <tab> label
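
A sketch of how these pieces typically fit together; the map-file name, the stream field names, and the class count are assumptions for illustration::

    from cntk.io import (MinibatchSource, ImageDeserializer, StreamDef,
                         StreamDefs, INFINITELY_REPEAT)

    # map.txt lines: <full path to image> <tab> <0-based class id>
    source = MinibatchSource(
        ImageDeserializer('map.txt', StreamDefs(
            features=StreamDef(field='image'),
            labels=StreamDef(field='label', shape=10))),
        epoch_size=INFINITELY_REPEAT)
    mb = source.next_minibatch(minibatch_size_in_samples=64)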

View file

@@ -35,9 +35,11 @@ def Sequential(layers, name=''):
def For(rng, constructor, name=''):
'''
Layer factory function to create a composite that applies a sequence of layers constructed with a constructor lambda(layer).
-E.g.
-For(range(3), lambda i: Dense(2000))
-For(range(3), lambda: Dense(2000))
+For example::
+    For(range(3), lambda i: Dense(2000))
+    For(range(3), lambda: Dense(2000))
'''
# Python 2.7 support requires us to use getargspec() instead of inspect
from inspect import getargspec
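
For reference, a small sketch of ``For`` composed inside ``Sequential``; the layer widths are arbitrary::

    from cntk.layers import Sequential, For, Dense

    # three Dense(2000) layers followed by a 10-way output layer
    model = Sequential([For(range(3), lambda: Dense(2000)),
                        Dense(10)])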

View file

@@ -28,7 +28,7 @@ def Dense(shape, activation=default_override_or(identity), init=default_override
Dense(shape, activation=identity, init=glorot_uniform(), input_rank=None, map_rank=None, bias=True, init_bias=0, name='')
Layer factory function to create an instance of a fully-connected linear layer of the form
-`activation(input @ W + b)` with weights `W` and bias `b`, and `activation` and `b` being optional.
+`activation(input @ W + b)` with weights `W` and bias `b`, and `activation` and `b` being optional.
`shape` may describe a tensor as well.
A ``Dense`` layer instance owns its parameter tensors `W` and `b`, and exposes them as attributes ``.W`` and ``.b``.
@@ -130,8 +130,9 @@ def Embedding(shape=None, init=default_override_or(glorot_uniform()), weights=No
The lookup table in this layer is learnable,
unless a user-specified one is supplied through the ``weights`` parameter.
-For example, to use an existing embedding table from a file in numpy format, use this:
-``Embedding(weights=np.load('PATH.npy'))``
+For example, to use an existing embedding table from a file in numpy format, use this::
+    Embedding(weights=np.load('PATH.npy'))
To initialize a learnable lookup table with a given numpy array that is to be used as
the initial value, pass that array to the ``init`` parameter (not ``weights``).
@@ -139,7 +140,7 @@ def Embedding(shape=None, init=default_override_or(glorot_uniform()), weights=No
An ``Embedding`` instance owns its weight parameter tensor `E`, and exposes it as an attribute ``.E``.
Example:
-# learnable embedding
+>>> # learnable embedding
>>> f = Embedding(5)
>>> x = Input(3)
>>> e = f(x)
@@ -148,7 +149,7 @@ def Embedding(shape=None, init=default_override_or(glorot_uniform()), weights=No
>>> f.E.shape
(3, 5)
-# user-supplied embedding
+>>> # user-supplied embedding
>>> f = Embedding(weights=[[.5, .3, .1, .4, .2], [.7, .6, .3, .2, .9]])
>>> f.E.value
array([[ 0.5, 0.3, 0.1, 0.4, 0.2],
@@ -653,19 +654,19 @@ def ConvolutionTranspose(filter_shape, # shape of receptive field, e.g. (
(3, 128, 3, 4)
Args:
-filter_shape ((`int` or `tuple` of `int`s)): shape (spatial extent) of the receptive field, *not* including the input feature-map depth. E.g. (3,3) for a 2D convolution.
-num_filters (`int`): number of filters (output feature-map depth), or ``()`` to denote scalar output items (output shape will have no depth axis).
+filter_shape (`int` or tuple of `int`\ s): shape (spatial extent) of the receptive field, *not* including the input feature-map depth. E.g. (3,3) for a 2D convolution.
+num_filters (int): number of filters (output feature-map depth), or ``()`` to denote scalar output items (output shape will have no depth axis).
activation (:class:`~cntk.ops.functions.Function`, optional): optional function to apply at the end, e.g. `relu`
-init (scalar or NumPy array or :mod:`cntk.initializer`, default `glorot_uniform()`): initial value of weights `W`
-pad (`bool` or `tuple` of `bool`s, default `False`): if `False`, then the filter will be shifted over the "valid"
+init (scalar or NumPy array or :mod:`cntk.initializer`, default :func:`glorot_uniform`): initial value of weights `W`
+pad (`bool` or tuple of `bool`\ s, default `False`): if `False`, then the filter will be shifted over the "valid"
area of input, that is, no value outside the area is used. If ``pad=True`` on the other hand,
the filter will be applied to all input positions, and positions outside the valid region will be considered containing zero.
Use a `tuple` to specify a per-axis value.
-strides (`int` or `tuple` of `int`s, default `): stride of the convolution (increment when sliding the filter over the input). Use a `tuple` to specify a per-axis value.
-sharing (`bool`, default True): weight sharing, must be True for now.
+strides (`int` or tuple of `int`\ s, default 1): stride of the convolution (increment when sliding the filter over the input). Use a `tuple` to specify a per-axis value.
+sharing (`bool`, default `True`): weight sharing, must be True for now.
bias (`bool`, optional, default `True`): the layer will have no bias if `False` is passed here
init_bias (scalar or NumPy array or :mod:`cntk.initializer`): initial value of weights `b`
-output_shape ((`int` or `tuple` of `int`s)): output shape. When strides > 2, the output shape is non-deterministic. User can specify the wanted output shape. Note the
+output_shape (`int` or tuple of `int`\ s): output shape. When strides > 2, the output shape is non-deterministic. User can specify the wanted output shape. Note the
specified shape must satisfy the condition that if a convolution is performed from the output with the same setting, the result must have the same shape as the input.
reduction_rank (`int`, default 1): must be 1 for now.
that is stored with tensor shape (H,W) instead of (1,H,W)
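
A usage sketch based on the signature documented above; all concrete shapes here are made up::

    from cntk.layers import ConvolutionTranspose

    # output_shape pins down the output size, which the docstring notes
    # is otherwise non-deterministic for larger strides
    f = ConvolutionTranspose((3, 3), num_filters=128, strides=(2, 2),
                             pad=True, output_shape=(28, 28))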

View file

@@ -218,10 +218,14 @@ def UnfoldFrom(generator_function, map_state_function=identity, until_predicate=
'''
Layer factory function to create a function that implements the unfold() anamorphism. It creates a function that, starting with a seed input,
applies 'generator_function' repeatedly and emits the sequence of results. Depending on the recurrent block,
-it may have this form:
-`result = f(... f(f([g(input), initial_state])) ... )`
-or this form:
-`result = f(g(input), ... f(g(input), f(g(input), initial_state)) ... )`
+it may have this form::
+    result = f(... f(f([g(input), initial_state])) ... )
+or this form::
+    result = f(g(input), ... f(g(input), f(g(input), initial_state)) ... )
where `f` is `generator_function`.
An example use of this is sequence-to-sequence decoding, where `g(input)` is the sequence encoder,
`initial_state` is the sentence-start symbol, and `f` is the decoder. The first
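
The second form above is easier to see in a plain-Python analogue of unfold; this illustrates the recurrence only, not the CNTK API::

    def unfold(f, g, x, initial_state, n):
        # result = f(g(x), ... f(g(x), f(g(x), initial_state)) ... )
        state = initial_state
        results = []
        for _ in range(n):
            state = f(g(x), state)
            results.append(state)
        return results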

View file

@@ -462,11 +462,10 @@ def momentum_sgd(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
-momentum (output of :func:`momentum_schedule` or
-:func:`momentum_as_time_constant_schedule`): momentum schedule.
+momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
-<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
-unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
+<https://github.com/Microsoft/CNTK/wiki/BrainScript-SGD-Block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
defaults to 0.0
@@ -514,10 +513,9 @@ def nesterov(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
parameters (list of parameters): list of network parameters to tune.
These can be obtained by the root operator's ``parameters``.
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
-momentum (output of :func:`momentum_schedule` or
-:func:`momentum_as_time_constant_schedule`): momentum schedule.
+momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
-<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+<https://github.com/Microsoft/CNTK/wiki/BrainScript-SGD-Block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
l1_regularization_weight (float, optional): the L1 regularization weight per sample,
@@ -626,7 +624,7 @@ def fsadagrad(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
-<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+<https://github.com/Microsoft/CNTK/wiki/BrainScript-SGD-Block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
variance_momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
@@ -680,7 +678,7 @@ def adam(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
-<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+<https://github.com/Microsoft/CNTK/wiki/BrainScript-SGD-Block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
variance_momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
@@ -742,7 +740,7 @@ def adam_sgd(parameters, lr, momentum, unit_gain=default_unit_gain_value(),
lr (output of :func:`learning_rate_schedule`): learning rate schedule.
momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): momentum schedule.
For additional information, please refer to the `wiki
-<https://github.com/Microsoft/CNTK/wiki/SGD-block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
+<https://github.com/Microsoft/CNTK/wiki/BrainScript-SGD-Block#converting-learning-rate-and-momentum-parameters-from-other-toolkits>`_.
unit_gain: when ``True``, momentum is interpreted as a unit-gain filter. Defaults
to the value returned by :func:`default_unit_gain_value`.
variance_momentum (output of :func:`momentum_schedule` or :func:`momentum_as_time_constant_schedule`): variance momentum schedule. Defaults
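
A sketch of how the schedules documented above feed into a learner factory; the model ``z`` and the post-rename ``cntk.learners`` module layout are assumptions::

    from cntk.learners import (momentum_sgd, learning_rate_schedule,
                               momentum_schedule, UnitType)

    lr = learning_rate_schedule(0.01, UnitType.sample)
    mm = momentum_schedule(0.9)
    learner = momentum_sgd(z.parameters, lr, mm,
                           l2_regularization_weight=0.001)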

View file

@@ -130,13 +130,13 @@ def plot(root, filename=None):
* for DOT output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`_
* for PNG, PDF, and SVG output: `pydot_ng <https://pypi.python.org/pypi/pydot-ng>`_
-and `graphviz <[http://graphviz.org](http://graphviz.org)>_ (GraphViz executable has to be in the system's PATH).
+and `graphviz <http://graphviz.org>`_ (GraphViz executable has to be in the system's PATH).
Args:
node (graph node): the node to start the journey from
filename (`str`, default None): file with extension '.dot', 'png', 'pdf', or 'svg'
-to denote what format should be written. If `None` then nothing
-will be plotted. Instead, and the returned string can be used to debug the graph.
+to denote what format should be written. If `None` then nothing
+will be plotted, and the returned string can be used to debug the graph.
Returns:
`str` describing the graph
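
A sketch of the two output modes described above; ``z`` stands in for some root node, and the module path follows this commit's package layout::

    from cntk.logging.graph import plot

    dot_text = plot(z)             # filename=None: only return the description
    plot(z, filename='model.png')  # requires pydot_ng and graphviz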

View file

@@ -30,7 +30,7 @@ def combine(operands, name=''):
'''
Create a new Function instance which just combines the outputs of the specified list of
'operands' Functions such that the 'Outputs' of the new 'Function' are union of the
-'Outputs' of each of the specified 'operands' Functions. E.g. When creating a classification
+'Outputs' of each of the specified 'operands' Functions. E.g., when creating a classification
model, typically the CrossEntropy loss Function and the ClassificationError Function comprise
the two roots of the computation graph which can be combined to create a single Function
with 2 outputs; viz. CrossEntropy loss and ClassificationError output.
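
The classification case described above looks roughly like this; the model ``z`` and the ``labels`` variable are assumed to be defined elsewhere::

    from cntk.ops import (combine, cross_entropy_with_softmax,
                          classification_error)

    ce = cross_entropy_with_softmax(z, labels)
    pe = classification_error(z, labels)
    criterion = combine([ce, pe])  # one Function with two outputs
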
@@ -881,9 +881,12 @@ def times(left, right, output_rank=1, infer_input_rank_to_map=TIMES_NO_INFERRED_
The operator '@' has been overloaded such that in Python 3.5 and later X @ W equals times(X, W).
For better performance on times operation on sequence which is followed by sequence.reduce_sum, use
-infer_input_rank_to_map=TIMES_REDUCE_SEQUENCE_AXIS_WITHOUT_INFERRED_INPUT_RANK, i.e. replace following:
+infer_input_rank_to_map=TIMES_REDUCE_SEQUENCE_AXIS_WITHOUT_INFERRED_INPUT_RANK, i.e. replace following::
sequence.reduce_sum(times(seq1, seq2))
-with:
+with::
times(seq1, seq2, infer_input_rank_to_map=TIMES_REDUCE_SEQUENCE_AXIS_WITHOUT_INFERRED_INPUT_RANK)
Example:
@@ -1962,7 +1965,6 @@ def slice(x, axis, begin_index, end_index, name=''):
... [4, 5, 6]]]],dtype=np.float32)})
array([[[[ 1.],
[ 4.]]]], dtype=float32)
-<BLANKLINE>
>>> # slice using constant
>>> data = np.asarray([[1, 2, -3],
@@ -1973,7 +1975,6 @@ def slice(x, axis, begin_index, end_index, name=''):
>>> C.slice(x, 1, 0, 1).eval()
array([[ 1.],
[ 4.]], dtype=float32)
-<BLANKLINE>
>>> # slice using the index overload
>>> data = np.asarray([[1, 2, -3],
@@ -2485,7 +2486,7 @@ def dropout(x, dropout_rate=0.0, name=''):
Args:
x: input tensor
dropout_rate (float, [0,1)): probability that an element of ``x`` will be set to zero
-name (:class:str, optional): the name of the Function instance in the network
+name (:class:`str`, optional): the name of the Function instance in the network
Returns:
:class:`~cntk.ops.functions.Function`
@@ -2545,7 +2546,7 @@ def output_variable(shape, dtype, dynamic_axes, name=''):
Args:
shape (tuple or int): the shape of the input tensor
-dtype (type): np.float32 or np.float64
+dtype (np.float32 or np.float64): data type
dynamic_axes (list or tuple): a list of dynamic axis (e.g., batch axis, time axis)
name (str, optional): the name of the Function instance in the network

View file

@@ -65,9 +65,10 @@ class Function(cntk_py.Function):
``@Function`` constructs a Function from a Python lambda
where the Function's input signature is defined by the lambda.
-Use this as a decorator, e.g.:
-``@Function
-def f(x): return x * x``
+Use this as a decorator, e.g.::
+    @Function
+    def f(x): return x * x
The above form creates a CNTK Function whose arguments are placeholder variables.
Such a function can only be combined with other symbolic functions.
@@ -76,19 +77,21 @@ class Function(cntk_py.Function):
of the arguments. In this case, the @Function decorator creates a CNTK Function
whose arguments are input variables.
-If you use Python 3, Functions with types are declared using Python annotation syntax, e.g.:
-``@Function
-def f(x:Tensor[13]):
-return x * x``
-If you are still working with Python 2.7, use CNTK's @Signature decorator instead:
-``@Function
-@Signature(Tensor[13])
-def f(x):
-return x * x``
+If you use Python 3, Functions with types are declared using Python annotation syntax, e.g.::
+    @Function
+    def f(x:Tensor[13]):
+        return x * x
+If you are working with Python 2.7, use CNTK's @Signature decorator instead::
+    @Function
+    @Signature(Tensor[13])
+    def f(x):
+        return x * x
``make_block=True`` is used to implement @BlockFunction(). If given the result will be wrapped
-in ``as_block()``, using the supplied ``op_name`` and ``name`` parameters, which are otherwise ignored.
+in ``as_block()``, using the supplied ``op_name`` and ``name`` parameters, which are otherwise ignored.
'''
f_name = f.__name__ # (only used for debugging and error messages)
@@ -471,6 +474,7 @@ class Function(cntk_py.Function):
input data.
* any other type: if node has a unique input, arguments is
mapped to this input.
For nodes with more than one input, only dict is allowed.
In both cases, every sample in the data will be interpreted
@@ -587,6 +591,7 @@ class Function(cntk_py.Function):
elements of the sequence are grouped along axis 0.
* any other type: if node has a unique input, arguments is
mapped to this input.
For nodes with more than one input, only dict is allowed.
In both cases, every sample in the data will be interpreted

View file

@@ -96,11 +96,14 @@ class Trainer(cntk_py.Trainer):
Args:
arguments: maps variables to their input data. Empty map signifies
-end of local training data.
+end of local training data.
The interpretation depends on the input type:
* `dict`: keys are input variable or names, and values are the input data.
* any other type: if node has a unique input, ``arguments`` is mapped to this input.
-For nodes with more than one input, only `dict` is allowed.
+For nodes with more than one input, only `dict` is allowed.
In both cases, every sample in the data will be interpreted
as a new sequence. To mark samples as continuations of the
previous sequence, specify ``arguments`` as `tuple`: the
@@ -179,8 +182,10 @@ class Trainer(cntk_py.Trainer):
* `dict`: keys are input variable or names, and values are the input data.
See :meth:`~cntk.ops.functions.Function.forward` for details on passing input data.
* any other type: if node has a unique input, ``arguments`` is mapped to this input.
-For nodes with more than one input, only `dict` is allowed.
+For nodes with more than one input, only `dict` is allowed.
In both cases, every sample in the data will be interpreted
as a new sequence. To mark samples as continuations of the
previous sequence, specify ``arguments`` as `tuple`: the
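
A sketch of the `dict` form of ``arguments``; the model ``z``, criterion functions ``ce``/``pe``, the learner, the input variables ``x``/``y``, and the Trainer argument order are all assumptions here::

    from cntk import Trainer

    trainer = Trainer(z, ce, pe, [learner])
    for _ in range(100):
        x_data, y_data = next_minibatch_arrays()  # hypothetical helper
        trainer.train_minibatch({x: x_data, y: y_data})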

View file

@@ -157,7 +157,7 @@ class TrainingSession(cntk_py.TrainingSession):
Perform training on a specified device.
Args:
-device (:class:~cntk.device.DeviceDescriptor): the device descriptor containing
+device (:class:`~cntk.device.DeviceDescriptor`): the device descriptor containing
the type and id of the device where training takes place.
'''

View file

@@ -204,9 +204,12 @@ def eval(op, arguments=None, precision=None, device=None, backward_pass=False, e
op (:class:`Function`): operation to evaluate
arguments: maps variables to their input data. The
interpretation depends on the input type:
* `dict`: keys are input variable or names, and values are the input data.
* any other type: if node has a unique input, ``arguments`` is mapped to this input.
-For nodes with more than one input, only `dict` is allowed.
+For nodes with more than one input, only `dict` is allowed.
In both cases, every sample in the data will be interpreted
as a new sequence. To mark samples as continuations of the
previous sequence, specify ``arguments`` as `tuple`: the
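
The two accepted shapes of ``arguments`` in a sketch, using a single-input node; the ``input_variable`` import path is an assumption::

    import numpy as np
    from cntk.ops import input_variable, sigmoid

    x = input_variable(2)
    f = sigmoid(x)
    data = np.asarray([[0.0, 1.0]], dtype=np.float32)
    f.eval({x: data})  # `dict`: input variable (or its name) -> data
    f.eval(data)       # any other type: mapped to the unique input of f
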
@@ -351,8 +354,9 @@ def Signature(*args, **kwargs):
**kwargs: types of arguments with optional names, e.g. `x=Tensor[42]`. Use this second form for
longer argument lists.
-Example:
-``# Python 3:
+Example::
+    # Python 3:
@Function
def f(x: Tensor[42]):
return sigmoid(x)

View file

@@ -1,17 +1,27 @@
REM Steps to recreate the docs:
setlocal
cd "%~dp0"
cd /d "%~dp0"
set PYTHONPATH=%CD%\..
echo PYTHONPATH=%PYTHONPATH%
set PATH=%CD%\..;%CD%\..\..\..\x64\Release;%PATH%
echo PATH=%PATH%
-sphinx-apidoc.exe ..\cntk -o . -f
+@REM TODO better align conf.py exclude with excluded paths here
+sphinx-apidoc.exe ..\cntk -o . -f ^
+..\cntk\tests ^
+..\cntk\debugging\tests ^
+..\cntk\internal\tests ^
+..\cntk\io\tests ^
+..\cntk\layers\tests ^
+..\cntk\learners\tests ^
+..\cntk\logging\tests ^
+..\cntk\losses\tests ^
+..\cntk\metrics\tests ^
+..\cntk\ops\tests ^
+..\cntk\train\tests ^
+..\cntk\utils\tests
if errorlevel 1 exit /b 1
.\make.bat html
if errorlevel 1 exit /b 1
echo start _build\html\index.html

View file

@@ -1,10 +0,0 @@
-cntk.learner package
-====================
-Module contents
----------------
-.. automodule:: cntk.learner
-   :members:
-   :undoc-members:
-   :show-inheritance:

View file

@@ -1,10 +0,0 @@
-cntk.trainer package
-====================
-Module contents
---------------
-.. automodule:: cntk.trainer
-   :members:
-   :undoc-members:
-   :show-inheritance:

View file

@@ -1,10 +0,0 @@
-cntk.training_session package
-=============================
-Module contents
---------------
-.. automodule:: cntk.training_session
-   :members:
-   :undoc-members:
-   :show-inheritance:

View file

@@ -1,3 +1,5 @@
+:orphan:
Concepts
========

View file

@@ -146,7 +146,7 @@ html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+#html_static_path = ['_static']
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied

View file

@@ -1,5 +1,5 @@
Debugging models
================
-.. automodule:: cntk.debug
+.. automodule:: cntk.debugging
:members:

View file

@@ -1,13 +0,0 @@
-Graph components
-===========================
-.. automodule:: cntk.ops.variables
-   :members:
-   :undoc-members:
-   :show-inheritance:
-.. automodule:: cntk.ops.functions
-   :members:
-   :undoc-members:
-   :show-inheritance:

View file

@@ -8,19 +8,19 @@ Tutorials
CNTK 102: `Feed Forward network`_ with NumPy
#. *Recognize hand written digits (OCR) with MNIST data*
-CNTK 103 Part A: `Data preparation <https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_103A_MNIST_DataLoader.ipynb>`_ , Part B: `Feed Forward classifier`_
+CNTK 103 Part A: `MNIST data preparation`_ , Part B: `Feed Forward classifier`_
#. *Learn how to predict the stock market*
CNTK 104: `Time Series basics`_ with finance data
#. *Compress (using autoencoder) hand written digits from MNIST data with no human input (unsupervised learning, FFN)*
-CNTK 105 Part A: `Data preparation <https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_103A_MNIST_DataLoader.ipynb>`_ , Part B: `Feed Forward autoencoder`_
+CNTK 105 Part A: `MNIST data preparation`_ , Part B: `Feed Forward autoencoder`_
#. *Forecasting using data from an IOT device*
CNTK 106: LSTM based forecasting - Part A: `with simulated data <https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_106A_LSTM_Timeseries_with_Simulated_Data.ipynb>`_, Part B: `with real IOT data <https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_106B_LSTM_Timeseries_with_IOT_Data.ipynb>`_
#. *Recognize objects in images from CIFAR-10 data (Convolutional Network, CNN)*
-CNTK 201 Part A: `Data preparation <https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_201A_CIFAR-10_DataLoader.ipynb>`_, Part B: `VGG and ResNet classifiers`_
+CNTK 201 Part A: `CIFAR data preparation`_, Part B: `VGG and ResNet classifiers`_
#. *Infer meaning from text snippets using LSTMs and word embeddings*
CNTK 202: `Language understanding`_ with ATIS3 text data
@@ -35,7 +35,7 @@ Tutorials
CNTK 205: `Artistic Style Transfer`_
#. *Produce realistic data (MNIST images) with no human input (unsupervised learning)*
-CNTK 206 Part A: `Data preparation <https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_103A_MNIST_DataLoader.ipynb>`_ , Part B: `Basic Generative Adversarial Networks (GAN)`_
+CNTK 206 Part A: `MNIST data preparation`_ , Part B: `Basic Generative Adversarial Networks (GAN)`_
#. *Training with Sampled Softmax*
CNTK 207: `Training with Sampled Softmax`_
@@ -46,12 +46,12 @@ For our Japanese users, you can find some of the `tutorials in Japanese`_.
.. _`Logistic Regression`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_101_LogisticRegression.ipynb
.. _`Feed Forward network`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_102_FeedForward.ipynb
-.. _`Data preparation`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_103A_MNIST_DataLoader.ipynb
+.. _`MNIST data preparation`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_103A_MNIST_DataLoader.ipynb
.. _`Feed Forward classifier`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_103B_MNIST_FeedForwardNetwork.ipynb
.. _`Time Series basics`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_104_Finance_Timeseries_Basic_with_Pandas_Numpy.ipynb
.. _`Feed Forward autoencoder`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_105_Basic_Autoencoder_for_Dimensionality_Reduction.ipynb
.. _`Basic LSTM based time series`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_106A_LSTM_Timeseries_with_Simulated_Data.ipynb
-.. _`data preparation`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_201A_CIFAR-10_DataLoader.ipynb
+.. _`CIFAR data preparation`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_201A_CIFAR-10_DataLoader.ipynb
.. _`VGG and ResNet classifiers`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_201B_CIFAR-10_ImageHandsOn.ipynb
.. _`Language understanding`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_202_Language_Understanding.ipynb
.. _`Reinforcement learning basics`: https://github.com/Microsoft/CNTK/blob/v2.0.beta12.0/Tutorials/CNTK_203_Reinforcement_Learning_Basics.ipynb