Added NDConvolution node. Added ConvolveGeometry.
This commit is contained in:
Родитель
65de62c934
Коммит
cee4cfa6f8
|
@ -0,0 +1,152 @@
|
|||
# Parameters can be overwritten on the command line
|
||||
# for example: cntk configFile=myConfigFile RootDir=../..
|
||||
# For running from Visual Studio add
|
||||
# currentDirectory=$(SolutionDir)/<path to corresponding data folder>
|
||||
RootDir = ".."
|
||||
|
||||
ConfigDir = "$RootDir$/Config"
|
||||
DataDir = "$RootDir$/Data"
|
||||
OutputDir = "$RootDir$/Output"
|
||||
ModelDir = "$OutputDir$/Models"
|
||||
|
||||
deviceId = 0
|
||||
imageLayout = "cudnn"
|
||||
# override the above as follows when running on CPU:
|
||||
# deviceId = -1
|
||||
# imageLayout = "legacy"
|
||||
|
||||
command = train:test
|
||||
|
||||
precision = "float"
|
||||
modelPath = "$ModelDir$/04_NDConvolution"
|
||||
|
||||
# uncomment the following line to write logs to a file
|
||||
stderr = "$OutputDir$/04_NDConvolution_out"
|
||||
traceLevel=1
|
||||
numMBsToShowResult=500
|
||||
|
||||
prefetch=true
|
||||
useCuDnn=true
|
||||
|
||||
#######################################
|
||||
# TRAINING CONFIG #
|
||||
#######################################
|
||||
|
||||
train = [
|
||||
action = "train"
|
||||
|
||||
BrainScriptNetworkBuilder = [
|
||||
|
||||
useCuDnn = $useCuDnn$
|
||||
|
||||
// macros
|
||||
NDConvReLULayer(inp, kW, kH, inMap, outMap, hStride, vStride, wScale, bValue) = [ // ReLU non-linearity
|
||||
convW = Parameter(outMap, kW * kH * inMap, init="uniform", initValueScale=wScale, initOnCPUOnly=false)
|
||||
//conv = NDConvolution(convW, inp, kernelShape=(kW : kH : inMap), mapShape=(1 : 1 : outMap), stride=(hStride : vStride : 1),
|
||||
// sharing=(true : true : true), padding=(true : true : false), imageLayout=if useCuDnn then "cudnn" else "legacy")
|
||||
//conv = NDConvolution(convW, inp, (kW : kH : inMap), (1 : 1 : outMap), (hStride : vStride : 1),
|
||||
// (1 : 1 : 1), (1 : 1 : 1), (1 : 1 : 1), (1 : 1 : 1), imageLayout=if useCuDnn then "cudnn" else "legacy")
|
||||
conv = NDConvolution(convW, inp, (kW : kH : inMap), (1 : 1 : outMap), stride=1, sharing=true, autoPadding=false, lowerPad=1, imageLayout=if useCuDnn then "cudnn" else "legacy")
|
||||
convB = if useCuDnn
|
||||
then ParameterTensor((1 : 1 : outMap), init="fixedValue", value=bValue)
|
||||
else Parameter(outMap, 1, init="fixedValue", value=bValue)
|
||||
convPlusB = Plus(conv, convB);
|
||||
out = RectifiedLinear(convPlusB);
|
||||
]
|
||||
|
||||
DNNSigmoidLayer(inDim, outDim, x, parmScale) = [ // Sigmoid non-linearity
|
||||
W = Parameter(outDim, inDim, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
|
||||
b = Parameter(outDim, 1, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
|
||||
t = Times(W, x)
|
||||
z = Plus(t, b)
|
||||
out = Sigmoid(z)
|
||||
]
|
||||
|
||||
DNNLayer(inDim, outDim, x, parmScale) = [ // no non-linearity, as input for SoftMax
|
||||
W = Parameter(outDim, inDim, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
|
||||
b = Parameter(outDim, 1, init="uniform", initValueScale=parmScale, initOnCPUOnly=false)
|
||||
t = Times(W, x)
|
||||
out = Plus(t, b)
|
||||
]
|
||||
|
||||
imageW = 28
|
||||
imageH = 28
|
||||
labelDim = 10
|
||||
|
||||
features = ImageInput(imageW, imageH, 1, imageLayout=if useCuDnn then "cudnn" else "legacy", tag="feature")
|
||||
featScale = Constant(0.00390625)
|
||||
featScaled = Scale(featScale, features)
|
||||
labels = Input(labelDim, tag="label")
|
||||
|
||||
# conv1
|
||||
kW1 = 5
|
||||
kH1 = 5
|
||||
cMap1 = 16
|
||||
hStride1 = 1
|
||||
vStride1 = 1
|
||||
# weight[cMap1, kW1 * kH1 * inputChannels]
|
||||
conv1_act = NDConvReLULayer(featScaled, kW1, kH1, 1, cMap1, hStride1, vStride1, 10, 1).out
|
||||
|
||||
h1Dim = 128
|
||||
# DNNSigmoidLayer and DNNLayer are defined in Macros.ndl
|
||||
h1 = DNNSigmoidLayer(12544, h1Dim, conv1_act, 1).out
|
||||
ol = DNNLayer(h1Dim, labelDim, h1, 1).out
|
||||
|
||||
ce = CrossEntropyWithSoftmax(labels, ol, tag="criterion")
|
||||
err = ErrorPrediction(labels, ol, tag="eval")
|
||||
outputNodes = ol
|
||||
]
|
||||
|
||||
SGD = [
|
||||
epochSize = 60000
|
||||
minibatchSize = 32
|
||||
learningRatesPerMB = 0.5
|
||||
momentumPerMB = 0*10:0.7
|
||||
maxEpochs = 15
|
||||
]
|
||||
|
||||
reader = [
|
||||
readerType = "UCIFastReader"
|
||||
# To get the data (Train-28x28.txt) please run `python mnist_convert.py`
|
||||
# from the 'AdditionalFiles' folder. See REAMDE.md for details.
|
||||
file = "$DataDir$/Train-28x28.txt"
|
||||
|
||||
features = [
|
||||
dim = 784
|
||||
start = 1
|
||||
]
|
||||
|
||||
labels = [
|
||||
dim = 1
|
||||
start = 0
|
||||
labelDim = 10
|
||||
labelMappingFile = "$DataDir$/labelsmap.txt"
|
||||
]
|
||||
]
|
||||
]
|
||||
|
||||
#######################################
|
||||
# TEST CONFIG #
|
||||
#######################################
|
||||
|
||||
test = [
|
||||
action = test
|
||||
minibatchSize = 16
|
||||
|
||||
reader = [
|
||||
readerType = "UCIFastReader"
|
||||
file = "$DataDir$/Test-28x28.txt"
|
||||
|
||||
features = [
|
||||
dim = 784
|
||||
start = 1
|
||||
]
|
||||
|
||||
labels = [
|
||||
dim = 1
|
||||
start = 0
|
||||
labelDim = 10
|
||||
labelMappingFile = "$DataDir$/labelsmap.txt"
|
||||
]
|
||||
]
|
||||
]
|
|
@ -63,6 +63,9 @@ L"ParameterTensor(dims, learningRateMultiplier = 1.0, init = 'uniform'/*|fixedVa
|
|||
L"WeightedLogistic(label, probability, instanceWeight, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability : instanceWeight) /*plus the function args*/ ]\n"
|
||||
L"ReconcileMBLayout(dataInput, layoutInput, tag='') = new ComputationNode [ operation = 'ReconcileMBLayout' ; inputs = (dataInput : layoutInput) /*plus the function args*/ ]\n"
|
||||
L"Convolution(weightNode, inputValueNode, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, zeroPadding = false, maxTempMemSizeInSamples = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'Convolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
|
||||
L"NDConvolution(weightNode, inputValueNode, kernelDims, mapDims, stride=1, sharing = true, autoPadding = true, lowerPad = 0, upperPad = 0, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'NDConvolution' ; inputs = (weightNode : inputValueNode); kernelShape = new TensorShape [ dims = kernelDims ] ; mapShape = new TensorShape [ dims = mapDims ] ; strideShape = new TensorShape [ dims = stride ] ; dimSharing = new BoolVector [ items = sharing ] ; dimPadding = new BoolVector [ items = autoPadding ] ; dimPadLower = new TensorShape [ dims = lowerPad ] ; dimPadUpper = new TensorShape [ dims = upperPad ] /*plus the function args*/ ]\n"
|
||||
//L"NDConvolution(weightNode, inputValueNode, kernelShape = new TensorShape [ /*dims*/ ], mapShape = new TensorShape[ /*dims*/ ], stride = new TensorShape [ /*dims*/ ], sharing = new BoolVector [ /*dims*/ ], padding = new BoolVector [ /*dims*/ ], lowerPad = new TensorShape [ /*dims*/ ], upperPad = new TensorShape [ /*dims*/ ], imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'NDConvolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
|
||||
//L"NDConvolution(weightNode, inputValueNode, kernelShape, mapShape, stride, sharing, padding, lowerPad, upperPad, imageLayout='CHW', maxTempMemSizeInSamples = 0, tag='') = new ComputationNode [ operation = 'NDConvolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
|
||||
L"MaxPooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'MaxPooling' ; inputs = input /*plus the function args*/ ]\n"
|
||||
L"AveragePooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'AveragePooling' ; inputs = input /*plus the function args*/ ]\n"
|
||||
// TODO: define DelayedValue, with negative delay for future; cannot do this yet, need to be able to say something like delay = -(^.delay)
|
||||
|
|
|
@ -127,6 +127,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateNode(const std::wstring& node
|
|||
if (nodeType == OperationNameOf(AveragePoolingNode)) return New<AveragePoolingNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(BatchNormalizationNode)) return New<BatchNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(ConvolutionNode)) return New<ConvolutionNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(NDConvolutionNode)) return New<NDConvolutionNode<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(SparseInputValue)) return New<SparseInputValue<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(InputValue)) return New<InputValue<ElemType>>(forward<_Types>(_Args)...);
|
||||
else if (nodeType == OperationNameOf(LearnableParameter)) return New<LearnableParameter<ElemType>>(forward<_Types>(_Args)...);
|
||||
|
|
|
@ -563,5 +563,6 @@ public:
|
|||
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedTensorShape> registerTensorShape(L"TensorShape");
|
||||
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedVector<int>> registerIntVector(L"IntVector");
|
||||
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedVector<size_t>> registerSizeVector(L"SizeVector");
|
||||
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<BoxedVector<bool>> registerBoolVector(L"BoolVector");
|
||||
|
||||
}}}
|
||||
|
|
|
@ -25,6 +25,157 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
class NDConvolutionNode : public ComputationNode<ElemType>, public NumInputs<2>
|
||||
{
|
||||
typedef ComputationNode<ElemType> Base;
|
||||
UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName()
|
||||
{
|
||||
return L"NDConvolution";
|
||||
}
|
||||
|
||||
public:
|
||||
NDConvolutionNode(DEVICEID_TYPE deviceId, const wstring& name)
|
||||
: Base(deviceId, name)
|
||||
{
|
||||
}
|
||||
NDConvolutionNode(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& kernelShape, const TensorShape& mapShape, const TensorShape& strideShape,
|
||||
const std::vector<bool>& sharing, const std::vector<bool>& autoPadding, const TensorShape& lowerPad, const TensorShape& upperPad,
|
||||
ImageLayoutKind imageLayout, size_t maxTempMemSizeInSamples)
|
||||
: Base(deviceId, name), m_kernel(kernelShape), m_stride(strideShape), m_sharing(sharing),
|
||||
m_autoPad(autoPadding), m_lowerPad(lowerPad), m_upperPad(upperPad)
|
||||
{
|
||||
}
|
||||
NDConvolutionNode(const ScriptableObjects::IConfigRecordPtr configp)
|
||||
: NDConvolutionNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"kernelShape"), configp->Get(L"mapShape"), configp->Get(L"strideShape"),
|
||||
configp->Get(L"dimSharing"), configp->Get(L"dimPadding"), configp->Get(L"dimPadLower"), configp->Get(L"dimPadUpper"),
|
||||
ImageLayoutKindFrom(configp->Get(L"imageLayout")), configp->Get(L"maxTempMemSizeInSamples"))
|
||||
{
|
||||
AttachInputs(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
||||
public:
|
||||
void Save(File& fstream) const override
|
||||
{
|
||||
Base::Save(fstream);
|
||||
}
|
||||
|
||||
void Load(File& fstream, size_t modelVersion) override
|
||||
{
|
||||
Base::Load(fstream, modelVersion);
|
||||
}
|
||||
|
||||
void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
|
||||
{
|
||||
Base::CopyTo(nodeP, newName, flags);
|
||||
if (flags & CopyNodeFlags::copyNodeValue)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
void BackpropTo(const size_t inputIndex, const FrameRange& fr) override
|
||||
{
|
||||
}
|
||||
|
||||
virtual bool OutputUsedInComputingInputNodesGradients() const override
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
void ForwardProp(const FrameRange& fr) override
|
||||
{
|
||||
}
|
||||
|
||||
void Validate(bool isFinalValidationPass) override
|
||||
{
|
||||
Base::Validate(isFinalValidationPass);
|
||||
InferMBLayoutFromInputsForStandardCase();
|
||||
|
||||
bool validate = isFinalValidationPass;
|
||||
if (validate && m_imageLayoutKind != ImageLayoutKind::CHW)
|
||||
{
|
||||
InvalidArgument(
|
||||
"NDConvolution supports only cuDNN (CHW) data layout. "
|
||||
"Please specify imageLayout=\"cudnn\" in NDConvolution node in your BrainScript "
|
||||
"and make sure input data layout is CHW");
|
||||
}
|
||||
|
||||
auto dimsInput = GetInputSampleLayout(1);
|
||||
if (validate)
|
||||
{
|
||||
if (dimsInput.GetRank() != m_kernel.GetRank())
|
||||
InvalidArgument("Convolution input and kernel tensors must have the same rank.");
|
||||
if (m_stride.GetRank() != 1 && dimsInput.GetRank() != m_stride.GetRank())
|
||||
InvalidArgument("Convolution stride tensor must have rank 1 or the same as the input tensor.");
|
||||
if (m_sharing.size() != 1 && dimsInput.GetRank() != m_sharing.size())
|
||||
InvalidArgument("Convolution sharing tensor must have rank 1 or the same as the input tensor.");
|
||||
if (m_autoPad.size() != 1 && dimsInput.GetRank() != m_autoPad.size())
|
||||
InvalidArgument("Convolution padding tensor must have rank 1 or the same as the input tensor.");
|
||||
if (m_lowerPad.GetRank() != 1 && dimsInput.GetRank() != m_lowerPad.GetRank())
|
||||
InvalidArgument("Convolution lower pad tensor must have rank 1 or the same as the input tensor.");
|
||||
if (m_upperPad.GetRank() != 1 && dimsInput.GetRank() != m_upperPad.GetRank())
|
||||
InvalidArgument("Convolution upper pad tensor must have rank 1 or the same as the input tensor.");
|
||||
}
|
||||
|
||||
SmallVector<size_t> dimsOutput(dimsInput.GetRank());
|
||||
for (size_t i = 0; i < dimsInput.GetRank(); i++)
|
||||
{
|
||||
assert(dimsInput[i] >= 1);
|
||||
if (validate && m_kernel[i] > dimsInput[i])
|
||||
InvalidArgument("NDConvolution operation requires that kernel dim %d <= input dim %d.", (int)m_kernel[i], (int)dimsInput[i]);
|
||||
|
||||
size_t delta = m_stride[m_stride.GetRank() == 1 ? 0 : i];
|
||||
if (validate && delta > m_kernel[i])
|
||||
InvalidArgument("NDConvolution operation requires that stride %d <= input dim %d.", (int)m_stride[i], (int)dimsInput[i]);
|
||||
|
||||
size_t dim = dimsInput[i];
|
||||
bool autoPad = m_autoPad[m_autoPad.size() == 1 ? 0 : i];
|
||||
if (autoPad)
|
||||
{
|
||||
dim -= 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t lo = m_lowerPad[m_lowerPad.size() == 1 ? 0 : i];
|
||||
size_t hi = m_upperPad[m_upperPad.size() == 1 ? 0 : i];
|
||||
dim += lo + hi;
|
||||
}
|
||||
}
|
||||
|
||||
SetDims(TensorShape(dimsOutput), true);
|
||||
|
||||
if (isFinalValidationPass)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) override
|
||||
{
|
||||
Base::RequestMatricesBeforeForwardProp(matrixPool);
|
||||
}
|
||||
|
||||
void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
|
||||
{
|
||||
Base::RequestMatricesBeforeBackprop(matrixPool);
|
||||
}
|
||||
|
||||
void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
|
||||
{
|
||||
Base::ReleaseMatricesAfterBackprop(matrixPool);
|
||||
}
|
||||
|
||||
private:
|
||||
ImageLayoutKind m_imageLayoutKind;
|
||||
|
||||
TensorShape m_kernel;
|
||||
TensorShape m_stride;
|
||||
std::vector<bool> m_sharing;
|
||||
std::vector<bool> m_autoPad;
|
||||
TensorShape m_lowerPad;
|
||||
TensorShape m_upperPad;
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ConvolutionNode (convolutionWeights, inputFeature)
|
||||
// -----------------------------------------------------------------------
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "Matrix.h"
|
||||
#include "TensorShape.h" // for ImageLayoutKind
|
||||
#include "ConvolveGeometry.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "Basics.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// Notes:
|
||||
// * ConvolveGeometry represents the application of one or more rectangular "kernels" (all of the same size)
|
||||
// to a rectangular input to produce a rectangular output.
|
||||
// * A "cell" in the rectangular input is identified by a single coordinate called a "col" (for column).
|
||||
// * A "cell" in the rectangular output is identified by a single coordinate called a "row".
|
||||
// * The kernels may involve weights, in which case MpRowIwht indicates the starting index of the weights
|
||||
// used for a given output cell.
|
||||
// The overall idea of ConvolveGeometry is to precompute maps that can be used to apply convolutions of
|
||||
// arbitrary configuration and dimension. In such case the generic implementation becomes very simple and invariant
|
||||
// wrt convolution configuration and dimensionality. For specific cases like 2D convolutions and full sharing,
|
||||
// highly optimized implementations (e.g. cuDNN) are used.
|
||||
class ConvolveGeometry final
|
||||
{
|
||||
public:
|
||||
using IntVec = std::vector<int>;
|
||||
|
||||
// Maps from a "row" (index of output cell) to its base "col" (index of input cell). For a given row,
|
||||
// the cols that contribute to it are { MpRowCol[row] + Indices[i0 + 1 + i] | 0 <= i < Indices[i0] },
|
||||
// where i0 = MpRowIndices[row].
|
||||
const IntVec& MpRowCol() const { return m_mpRowCol; }
|
||||
|
||||
// Maps from a "row" (index of output cell) to where to start in the weights array. Each run of weights
|
||||
// consists of KernelSize weights.
|
||||
const IntVec& MpRowIwht() const { return m_mpRowIwht; }
|
||||
|
||||
// Maps from a "row" (index of output cell) to its starting index in Runs. A run consists of:
|
||||
// * skip count (to skip that many weights)
|
||||
// * item count
|
||||
// * relative indices into source (item count of these)
|
||||
// * masks (all 1's or all 0's) (item count of these)
|
||||
// For items that are masked out (0 mask), the index stored is the next valid index.
|
||||
// This ensures that accessing the corresponding neuron value doesn't fault and that
|
||||
// backprop operations write the correct value last (any previous writes won't change
|
||||
// the value).
|
||||
// NOTE: The first (zeroth) run is always the "full" kernel run. Also, MpRowRun can be empty,
|
||||
// indicating that all values are zero (all outputs use the "full" kernel run).
|
||||
const IntVec& MpRowRun() const { return m_mpRowRun; }
|
||||
const IntVec& Runs() const { return m_runs; }
|
||||
|
||||
// Maps from a "row" (index of output cell) to its starting index in Indices. Note that "Runs" is intended
|
||||
// for kernels that have weights, while "Indices" is intended for kernels that don't need to access weights.
|
||||
// As a result, the encoding in Indices is simpler and more direct.
|
||||
// A run in Indices consists of:
|
||||
// * item count
|
||||
// * relative indices into source (item count of these)
|
||||
// NOTE: The first run of indices is always the "full" kernel run. Also, MpRowIndices can be empty,
|
||||
// indicating that all values are zero (all outputs use the "full" kernel run).
|
||||
const IntVec& MpRowIndices() const { return m_mpRowIndices; }
|
||||
const IntVec& Indices() const { return m_indices; }
|
||||
|
||||
// The indices of the first ("top-left-most") "kernel-center" cell in the source.
|
||||
const IntVec& Start() const { return m_start; }
|
||||
int StartIndex() const { return m_startIndex; }
|
||||
|
||||
ConvolveGeometry(const TensorShape& input, const TensorShape& kernel)
|
||||
{
|
||||
assert(input.GetRank() == kernel.GetRank());
|
||||
}
|
||||
|
||||
private:
|
||||
IntVec m_mpRowCol;
|
||||
IntVec m_mpRowIwht;
|
||||
IntVec m_mpRowRun;
|
||||
IntVec m_runs;
|
||||
IntVec m_mpRowIndices;
|
||||
IntVec m_indices;
|
||||
IntVec m_start;
|
||||
int m_startIndex;
|
||||
};
|
||||
|
||||
} } }
|
|
@ -158,6 +158,7 @@
|
|||
<ClInclude Include="..\Common\Include\fileutil.h" />
|
||||
<ClInclude Include="CommonMatrix.h" />
|
||||
<ClInclude Include="ConvolutionEngine.h" />
|
||||
<ClInclude Include="ConvolveGeometry.h" />
|
||||
<ClInclude Include="CPUMatrix.h" />
|
||||
<ClInclude Include="MatrixQuantizerImpl.h" />
|
||||
<ClInclude Include="TensorOps.h" />
|
||||
|
|
|
@ -97,6 +97,9 @@
|
|||
<ClInclude Include="MatrixQuantizerImpl.h">
|
||||
<Filter>1bitSGD</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="ConvolveGeometry.h">
|
||||
<Filter>Convolution</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="GPUMatrix.h">
|
||||
|
|
Загрузка…
Ссылка в новой задаче