not expose unused NDArrayViewConstructor

This commit is contained in:
Zhou Wang 2017-03-16 15:10:52 +01:00
Родитель a0e3ebfcdb
Коммит 59910e0d34
2 изменённых файлов: 32 добавлений и 35 удалений

Просмотреть файл

@ -2158,7 +2158,6 @@ namespace CNTK
return Create(sampleShape, sequences, {}, device, readOnly);
}
///
/// Create a new Value object containing a collection of variable length sequences.
///

Просмотреть файл

@ -1096,7 +1096,7 @@
// Is CreateBatch for OneHot really useful?
var input = new System.Collections.Generic.List<System.Collections.Generic.List<uint>>();
batch.ForEach(element => input.Add(new System.Collections.Generic.List<uint>(1) {element}));
return Create<T>(dimension, input, new System.Collections.Generic.List<bool>(0), device, readOnly);
}
@ -1135,6 +1135,36 @@
return Create<T>(dimension, batchOfSequences, sequenceStartFlags, device, readOnly);
}
private static Value Create<T>(uint dimension,
System.Collections.Generic.List<System.Collections.Generic.List<uint>> sequences,
System.Collections.Generic.List<bool> sequenceStartFlags,
DeviceDescriptor device,
bool readOnly = false)
{
var seqFlags = new BoolVector(sequenceStartFlags);
var inputSeqVector = new SizeTVectorVector();
var sizeTVectorRefList = new System.Collections.Generic.List<SizeTVector>();
foreach (var seq in sequences)
{
var s = new SizeTVector(seq);
sizeTVectorRefList.Add(s);
inputSeqVector.Add(s);
}
if (typeof(T).Equals(typeof(float)))
{
return Value.CreateOneHotFloat(dimension, inputSeqVector, seqFlags, device, readOnly);
}
else if (typeof(T).Equals(typeof(double)))
{
return Value.CreateOneHotDouble(dimension, inputSeqVector, seqFlags, device, readOnly);
}
else
{
throw new System.ArgumentException("The data type " + typeof(T).ToString() + " is not supported. Only float or double is supported by CNTK.");
}
}
// Create Value object from sparse input
public static Value CreateSequence<T>(NDShape sampleShape, uint sequenceLength,
int[] colStarts, int[] rowIndices, T[] nonZeroValues, uint numNonZeroValues,
bool sequenceStartFlag,
@ -1191,35 +1221,6 @@
return Value.CreateSequence<T>(dimension, sequenceLength, colStarts, rowIndices, nonZeroValues, numNonZeroValues, true, device, readOnly);
}
private static Value Create<T>(uint dimension,
System.Collections.Generic.List<System.Collections.Generic.List<uint>> sequences,
System.Collections.Generic.List<bool> sequenceStartFlags,
DeviceDescriptor device,
bool readOnly = false)
{
var seqFlags = new BoolVector(sequenceStartFlags);
var inputSeqVector = new SizeTVectorVector();
var sizeTVectorRefList = new System.Collections.Generic.List<SizeTVector>();
foreach (var seq in sequences)
{
var s = new SizeTVector(seq);
sizeTVectorRefList.Add(s);
inputSeqVector.Add(s);
}
if (typeof(T).Equals(typeof(float)))
{
return Value.CreateOneHotFloat(dimension, inputSeqVector, seqFlags, device, readOnly);
}
else if (typeof(T).Equals(typeof(double)))
{
return Value.CreateOneHotDouble(dimension, inputSeqVector, seqFlags, device, readOnly);
}
else
{
throw new System.ArgumentException("The data type " + typeof(T).ToString() + " is not supported. Only float or double is supported by CNTK.");
}
}
// Create value object from NDArrayView
public static Value Create(NDShape sampleShape,
System.Collections.Generic.List<NDArrayView> sequences,
@ -1320,9 +1321,6 @@
}
return;
}
%}
%extend CNTK::Value {
@ -1344,7 +1342,7 @@
//
%ignore CNTK::NDArrayView::NDArrayView(::CNTK::DataType dataType, const NDShape& viewShape, void* dataBuffer, size_t bufferSizeInBytes, const DeviceDescriptor& device, bool readOnly = false);
%ignore CNTK::NDArrayView::NDArrayView(::CNTK::DataType dataType, const NDShape& viewShape, const void* dataBuffer, size_t bufferSizeInBytes, const DeviceDescriptor& device);
%ignore CNTK::NDArrayView::NDArrayView(double value, DataType dataType = DataType::Float, const NDShape& viewShape = { 1 }, const DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice(), bool readOnly = false);
%extend CNTK::NDArrayView {
NDArrayView(const NDShape& viewShape, float *dataBuffer, size_t numBufferElements, const DeviceDescriptor& device, bool readOnly = false)