Refactoring to use transforms
This commit is contained in:
Родитель
ecdd9ae99f
Коммит
6cd863e228
|
@ -1,3 +1,9 @@
|
|||
//
|
||||
// <copyright company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
|
||||
#include "stdafx.h"
|
||||
#define DATAREADER_EXPORTS // creating the exports here
|
||||
#include "DataReader.h"
|
||||
|
@ -6,13 +12,226 @@
|
|||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <locale>
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
//-------------------
|
||||
// Transforms
|
||||
|
||||
class ITransform
|
||||
{
|
||||
public:
|
||||
virtual void Init(const ConfigParameters& config) = 0;
|
||||
virtual void Apply(cv::Mat& mat) = 0;
|
||||
|
||||
ITransform() {};
|
||||
virtual ~ITransform() {};
|
||||
public:
|
||||
ITransform(const ITransform&) = delete;
|
||||
ITransform& operator=(const ITransform&) = delete;
|
||||
ITransform(ITransform&&) = delete;
|
||||
ITransform& operator=(ITransform&&) = delete;
|
||||
};
|
||||
|
||||
class CropTransform : public ITransform
|
||||
{
|
||||
public:
|
||||
CropTransform(unsigned int seed) : m_rng(seed), m_rndUniInt(0, INT_MAX)
|
||||
{
|
||||
}
|
||||
|
||||
void Init(const ConfigParameters& config)
|
||||
{
|
||||
m_cropType = ParseCropType(config("cropType", ""));
|
||||
m_cropRatio = std::stof(config("cropRatio", "1"));
|
||||
if (!(0 < m_cropRatio && m_cropRatio <= 1.0f))
|
||||
RuntimeError("Invalid cropRatio value: %f.", m_cropRatio);
|
||||
if (!config.ExistsCurrent("hflip"))
|
||||
m_hFlip = m_cropType == CropType::Random;
|
||||
else
|
||||
m_hFlip = std::stoi(config("hflip")) != 0;
|
||||
}
|
||||
|
||||
void Apply(cv::Mat& mat)
|
||||
{
|
||||
mat = mat(GetCropRect(m_cropType, mat.rows, mat.cols, m_cropRatio));
|
||||
if (m_hFlip && (m_rndUniInt(m_rng) % 2) != 0)
|
||||
cv::flip(mat, mat, 1);
|
||||
}
|
||||
|
||||
private:
|
||||
enum class CropType { Center = 0, Random = 1 };
|
||||
|
||||
CropType ParseCropType(const std::string& src)
|
||||
{
|
||||
auto AreEqual = [](const std::string& s1, const std::string& s2) -> bool
|
||||
{
|
||||
return std::equal(s1.begin(), s1.end(), s2.begin(), [](const char& a, const char& b) { return std::tolower(a) == std::tolower(b); });
|
||||
};
|
||||
|
||||
if (src.empty() || AreEqual(src, "center"))
|
||||
return CropType::Center;
|
||||
if (AreEqual(src, "random"))
|
||||
return CropType::Random;
|
||||
|
||||
RuntimeError("Invalid crop type: %s.", src.c_str());
|
||||
}
|
||||
|
||||
cv::Rect GetCropRect(CropType type, int crow, int ccol, float cropRatio)
|
||||
{
|
||||
assert(crow > 0);
|
||||
assert(ccol > 0);
|
||||
assert(0 < cropRatio && cropRatio <= 1.0f);
|
||||
|
||||
int cropSize = static_cast<int>(std::min(crow, ccol) * cropRatio);
|
||||
int xOff = -1;
|
||||
int yOff = -1;
|
||||
|
||||
switch (type)
|
||||
{
|
||||
case CropType::Center:
|
||||
xOff = (ccol - cropSize) / 2;
|
||||
yOff = (crow - cropSize) / 2;
|
||||
break;
|
||||
case CropType::Random:
|
||||
xOff = m_rndUniInt(m_rng) % std::max(ccol - cropSize, 1);
|
||||
yOff = m_rndUniInt(m_rng) % std::max(crow - cropSize, 1);
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
|
||||
assert(0 <= xOff && xOff <= ccol - cropSize);
|
||||
assert(0 <= yOff && yOff <= crow - cropSize);
|
||||
return cv::Rect(xOff, yOff, cropSize, cropSize);
|
||||
}
|
||||
|
||||
private:
|
||||
std::default_random_engine m_rng;
|
||||
std::uniform_int_distribution<int> m_rndUniInt;
|
||||
|
||||
CropType m_cropType;
|
||||
float m_cropRatio;
|
||||
bool m_hFlip;
|
||||
};
|
||||
|
||||
class ScaleTransform : public ITransform
|
||||
{
|
||||
public:
|
||||
ScaleTransform(int dataType, unsigned int seed) : m_dataType(dataType), m_rng(seed), m_rndUniInt(0, INT_MAX)
|
||||
{
|
||||
assert(m_dataType == CV_32F || m_dataType == CV_64F);
|
||||
|
||||
m_interpMap.emplace("nearest", cv::INTER_NEAREST);
|
||||
m_interpMap.emplace("linear", cv::INTER_LINEAR);
|
||||
m_interpMap.emplace("cubic", cv::INTER_CUBIC);
|
||||
m_interpMap.emplace("lanczos", cv::INTER_LANCZOS4);
|
||||
}
|
||||
|
||||
void Init(const ConfigParameters& config)
|
||||
{
|
||||
m_imgWidth = config("width");
|
||||
m_imgHeight = config("height");
|
||||
m_imgChannels = config("channels");
|
||||
size_t cfeat = m_imgWidth * m_imgHeight * m_imgChannels;
|
||||
if (cfeat == 0 || cfeat > std::numeric_limits<size_t>().max() / 2)
|
||||
RuntimeError("Invalid image dimensions.");
|
||||
|
||||
m_interp.clear();
|
||||
std::stringstream ss{ config("interpolations", "") };
|
||||
for (std::string token = ""; std::getline(ss, token, ':');)
|
||||
{
|
||||
std::transform(token.begin(), token.end(), token.begin(), std::tolower);
|
||||
StrToIntMapT::const_iterator res = m_interpMap.find(token);
|
||||
if (res != m_interpMap.end())
|
||||
m_interp.push_back((*res).second);
|
||||
}
|
||||
|
||||
if (m_interp.size() == 0)
|
||||
m_interp.push_back(cv::INTER_LINEAR);
|
||||
}
|
||||
|
||||
void Apply(cv::Mat& mat)
|
||||
{
|
||||
// If matrix has not been converted to the right type, do it now as rescaling requires floating point type.
|
||||
if (mat.type() != m_dataType)
|
||||
mat.convertTo(mat, m_dataType);
|
||||
|
||||
assert(m_interp.size() > 0);
|
||||
cv::resize(mat, mat, cv::Size(static_cast<int>(m_imgWidth), static_cast<int>(m_imgHeight)), 0, 0,
|
||||
m_interp[m_rndUniInt(m_rng) % m_interp.size()]);
|
||||
}
|
||||
|
||||
private:
|
||||
std::default_random_engine m_rng;
|
||||
std::uniform_int_distribution<int> m_rndUniInt;
|
||||
|
||||
int m_dataType;
|
||||
|
||||
using StrToIntMapT = std::unordered_map<std::string, int>;
|
||||
StrToIntMapT m_interpMap;
|
||||
std::vector<int> m_interp;
|
||||
|
||||
size_t m_imgWidth;
|
||||
size_t m_imgHeight;
|
||||
size_t m_imgChannels;
|
||||
};
|
||||
|
||||
class MeanTransform : public ITransform
|
||||
{
|
||||
public:
|
||||
MeanTransform()
|
||||
{
|
||||
}
|
||||
|
||||
void Init(const ConfigParameters& config)
|
||||
{
|
||||
m_meanFile = config(L"meanFile", L"");
|
||||
if (!m_meanFile.empty())
|
||||
{
|
||||
cv::FileStorage fs;
|
||||
// REVIEW alexeyk: this sort of defeats the purpose of using wstring at all...
|
||||
auto fname = msra::strfun::utf8(m_meanFile);
|
||||
fs.open(fname, cv::FileStorage::READ);
|
||||
if (!fs.isOpened())
|
||||
RuntimeError("Could not open file: " + fname);
|
||||
fs["MeanImg"] >> m_meanImg;
|
||||
int cchan;
|
||||
fs["Channel"] >> cchan;
|
||||
int crow;
|
||||
fs["Row"] >> crow;
|
||||
int ccol;
|
||||
fs["Col"] >> ccol;
|
||||
if (cchan * crow * ccol != m_meanImg.channels() * m_meanImg.rows * m_meanImg.cols)
|
||||
RuntimeError("Invalid data in file: " + fname);
|
||||
fs.release();
|
||||
m_meanImg = m_meanImg.reshape(cchan, crow);
|
||||
}
|
||||
}
|
||||
|
||||
void Apply(cv::Mat& mat)
|
||||
{
|
||||
assert(m_meanImg.size() == cv::Size(0, 0) || (m_meanImg.size() == mat.size() && m_meanImg.channels()));
|
||||
|
||||
// REVIEW alexeyk: check type conversion (float/double).
|
||||
if (m_meanImg.size() == mat.size())
|
||||
mat = mat - m_meanImg;
|
||||
}
|
||||
|
||||
private:
|
||||
std::wstring m_meanFile;
|
||||
cv::Mat m_meanImg;
|
||||
};
|
||||
|
||||
//-------------------
|
||||
// ImageReader
|
||||
|
||||
template<class ElemType>
|
||||
ImageReader<ElemType>::ImageReader() : m_seed(0), m_rng(m_seed), m_rndUniInt(0, INT_MAX)
|
||||
{
|
||||
m_transforms.push_back(std::make_unique<CropTransform>(m_seed));
|
||||
m_transforms.push_back(std::make_unique<ScaleTransform>(sizeof(ElemType) == 4 ? CV_32F : CV_64F, m_seed));
|
||||
m_transforms.push_back(std::make_unique<MeanTransform>());
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
|
@ -40,12 +259,9 @@ void ImageReader<ElemType>::Init(const ConfigParameters& config)
|
|||
m_imgHeight = featSect.second("height");
|
||||
m_imgChannels = featSect.second("channels");
|
||||
m_featDim = m_imgWidth * m_imgHeight * m_imgChannels;
|
||||
m_meanFile = featSect.second(L"meanFile", L"");
|
||||
|
||||
m_cropType = ParseCropType(featSect.second("cropType", ""));
|
||||
m_cropRatio = std::stof(featSect.second("cropRatio", "1"));
|
||||
if (!(0 < m_cropRatio && m_cropRatio <= 1.0f))
|
||||
RuntimeError("Invalid cropRatio value: %f.", m_cropRatio);
|
||||
for (auto& t: m_transforms)
|
||||
t->Init(featSect.second);
|
||||
|
||||
SectionT labSect{ gettter("labelDim") };
|
||||
m_labName = msra::strfun::utf16(labSect.first);
|
||||
|
@ -124,9 +340,11 @@ bool ImageReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix<ElemType>
|
|||
{
|
||||
const auto& p = files[i + m_mbStart];
|
||||
auto img = cv::imread(p.first, cv::IMREAD_COLOR);
|
||||
for (auto& t: m_transforms)
|
||||
t->Apply(img);
|
||||
// Crop
|
||||
cv::Mat cropped;
|
||||
CropTransform(img, cropped);
|
||||
//cv::Mat cropped;
|
||||
//CropTransform(img, cropped);
|
||||
//int w = img.cols;
|
||||
//int h = img.rows;
|
||||
//int cropSize = std::min(w, h);
|
||||
|
@ -134,9 +352,13 @@ bool ImageReader<ElemType>::GetMinibatch(std::map<std::wstring, Matrix<ElemType>
|
|||
//int yOff = (h - cropSize) / 2;
|
||||
//cv::Mat cropped{ img(cv::Rect(xOff, yOff, cropSize, cropSize)) };
|
||||
|
||||
cropped.convertTo(img, CV_32F);
|
||||
// Scale
|
||||
cv::resize(img, img, cv::Size(static_cast<int>(m_imgWidth), static_cast<int>(m_imgHeight)), 0, 0, cv::INTER_LINEAR);
|
||||
//cropped.convertTo(img, CV_32F);
|
||||
//img.convertTo(img, CV_32F);
|
||||
//// Scale
|
||||
//cv::resize(img, img, cv::Size(static_cast<int>(m_imgWidth), static_cast<int>(m_imgHeight)), 0, 0, cv::INTER_LINEAR);
|
||||
|
||||
// Subtract mean
|
||||
//SubMeanTransform(img, img);
|
||||
|
||||
assert(img.isContinuous());
|
||||
auto data = reinterpret_cast<ElemType*>(img.ptr());
|
||||
|
@ -181,64 +403,6 @@ void ImageReader<ElemType>::SetRandomSeed(unsigned int seed)
|
|||
m_rng.seed(m_seed);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
typename ImageReader<ElemType>::CropType ImageReader<ElemType>::ParseCropType(const std::string& src)
|
||||
{
|
||||
auto AreEqual = [](const std::string& s1, const std::string& s2) -> bool
|
||||
{
|
||||
return std::equal(s1.begin(), s1.end(), s2.begin(), [](const char& a, const char& b) { return std::tolower(a) == std::tolower(b); });
|
||||
};
|
||||
|
||||
if (src.empty() || AreEqual(src, "center"))
|
||||
return CropType::Center;
|
||||
if (AreEqual(src, "random"))
|
||||
return CropType::Random;
|
||||
|
||||
RuntimeError("Invalid crop type: %s.", src.c_str());
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
cv::Rect ImageReader<ElemType>::GetCropRect(CropType type, int crow, int ccol, float cropRatio)
|
||||
{
|
||||
assert(crow > 0);
|
||||
assert(ccol > 0);
|
||||
assert(0 < cropRatio && cropRatio <= 1.0f);
|
||||
|
||||
int cropSize = static_cast<int>(std::min(crow, ccol) * cropRatio);
|
||||
int xOff = -1;
|
||||
int yOff = -1;
|
||||
|
||||
switch (type)
|
||||
{
|
||||
case CropType::Center:
|
||||
xOff = (ccol - cropSize) / 2;
|
||||
yOff = (crow - cropSize) / 2;
|
||||
break;
|
||||
case CropType::Random:
|
||||
xOff = m_rndUniInt(m_rng) % (ccol - cropSize);
|
||||
yOff = m_rndUniInt(m_rng) % (crow - cropSize);
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
|
||||
assert(0 <= xOff && xOff <= ccol - cropSize);
|
||||
assert(0 <= yOff && yOff <= crow - cropSize);
|
||||
return cv::Rect(xOff, yOff, cropSize, cropSize);
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void ImageReader<ElemType>::CropTransform(const cv::Mat& src, cv::Mat& dst)
|
||||
{
|
||||
// REVIEW alexeyk: optimize resizing?
|
||||
dst = src(GetCropRect(m_cropType, src.rows, src.cols, m_cropRatio)).clone();
|
||||
}
|
||||
|
||||
template<class ElemType>
|
||||
void ImageReader<ElemType>::SubMeanTransform(const cv::Mat& , cv::Mat& )
|
||||
{
|
||||
}
|
||||
|
||||
template class ImageReader<double>;
|
||||
template class ImageReader<float>;
|
||||
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
//
|
||||
// <copyright file="UCIFastReader.h" company="Microsoft">
|
||||
// <copyright company="Microsoft">
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// </copyright>
|
||||
//
|
||||
// ImageReader.h - Include file for the image reader
|
||||
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include <memory>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "DataReader.h"
|
||||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
// REVIEW alexeyk: can't put it into ImageReader itself as ImageReader is a template.
|
||||
class ITransform;
|
||||
|
||||
template<class ElemType>
|
||||
class ImageReader : public IDataReader<ElemType>
|
||||
{
|
||||
|
@ -36,13 +39,7 @@ public:
|
|||
void SetRandomSeed(unsigned int seed) override;
|
||||
|
||||
private:
|
||||
enum class CropType { Center = 0, Random = 1 };
|
||||
|
||||
CropType ParseCropType(const std::string& src);
|
||||
cv::Rect GetCropRect(CropType type, int crow, int ccol, float cropRatio);
|
||||
void CropTransform(const cv::Mat& src, cv::Mat& dst);
|
||||
|
||||
void SubMeanTransform(const cv::Mat& src, cv::Mat& dst);
|
||||
std::vector<std::unique_ptr<ITransform>> m_transforms;
|
||||
|
||||
private:
|
||||
std::default_random_engine m_rng;
|
||||
|
@ -70,10 +67,5 @@ private:
|
|||
std::vector<ElemType> m_labBuf;
|
||||
|
||||
unsigned int m_seed;
|
||||
|
||||
CropType m_cropType;
|
||||
float m_cropRatio;
|
||||
|
||||
std::wstring m_meanFile;
|
||||
};
|
||||
}}}
|
||||
|
|
|
@ -22,12 +22,12 @@ Train=[
|
|||
SGD=[
|
||||
epochSize=0
|
||||
minibatchSize=128
|
||||
learningRatesPerMB=0.01*20:0.003*15:0.001
|
||||
learningRatesPerMB=0.01*20:0.003*12:0.001
|
||||
momentumPerMB=0.9
|
||||
maxEpochs=60
|
||||
gradUpdateType=None
|
||||
L2RegWeight=0.0005
|
||||
dropoutRate=0*10:0.5
|
||||
dropoutRate=0*5:0.5
|
||||
|
||||
numMBsToShowResult=10
|
||||
]
|
||||
|
@ -41,6 +41,9 @@ Train=[
|
|||
width=224
|
||||
height=224
|
||||
channels=3
|
||||
cropType=Random
|
||||
cropRatio=0.9
|
||||
meanFile=$WorkDir$/ImageNet1K_mean.xml
|
||||
]
|
||||
labels=[
|
||||
labelDim=1000
|
||||
|
@ -66,6 +69,8 @@ Test=[
|
|||
width=224
|
||||
height=224
|
||||
channels=3
|
||||
cropType=Center
|
||||
meanFile=$WorkDir$/ImageNet1K_mean.xml
|
||||
]
|
||||
labels=[
|
||||
labelDim=1000
|
||||
|
|
|
@ -7,9 +7,10 @@ ndlMnistMacros = [
|
|||
ImageC = 3
|
||||
LabelDim = 1000
|
||||
|
||||
#features = ImageInput(ImageW, ImageH, ImageC, tag = feature)
|
||||
#featOffs = Const(128, rows = 150528)
|
||||
#featScaled = Minus(features, featOffs)
|
||||
features = ImageInput(ImageW, ImageH, ImageC, tag = feature)
|
||||
featOffs = Const(128, rows = 150528)
|
||||
featScaled = Minus(features, featOffs)
|
||||
labels = Input(LabelDim, tag = label)
|
||||
|
||||
conv1WScale = 0.95
|
||||
|
@ -38,7 +39,8 @@ DNN=[
|
|||
hStride1 = 3
|
||||
vStride1 = 3
|
||||
# weight[cMap1, kW1 * kH1 * ImageC]
|
||||
conv1_act = ConvReLULayer(featScaled, cMap1, 363, kW1, kH1, hStride1, vStride1, conv1WScale, conv1BValue)
|
||||
#conv1_act = ConvReLULayer(featScaled, cMap1, 363, kW1, kH1, hStride1, vStride1, conv1WScale, conv1BValue)
|
||||
conv1_act = ConvReLULayer(features, cMap1, 363, kW1, kH1, hStride1, vStride1, conv1WScale, conv1BValue)
|
||||
|
||||
# pool1
|
||||
pool1W = 3
|
||||
|
|
Загрузка…
Ссылка в новой задаче