2015-08-05 19:23:33 +03:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "DistGradHeader.h"
#include "MPIWrapper.h"

#include <vector>
|
|
|
|
|
|
|
|
namespace Microsoft { namespace MSR { namespace CNTK {
|
|
|
|
|
2016-01-18 11:36:14 +03:00
|
|
|
template <class ElemType>
|
|
|
|
class IDistGradAggregator
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
IDistGradAggregator(MPIWrapper* mpi)
|
|
|
|
: m_mpi(mpi)
|
2015-08-05 19:23:33 +03:00
|
|
|
{
|
2016-01-18 11:36:14 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
virtual ~IDistGradAggregator()
|
|
|
|
{
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns a boolean indicating if any samples were processed
|
|
|
|
virtual bool AggregateGradients(const std::vector<Matrix<ElemType>*>& gradients, DistGradHeader* headerCPU, int epochNumber) = 0;
|
|
|
|
|
|
|
|
size_t NumProc()
|
|
|
|
{
|
|
|
|
return m_mpi->NumNodesInUse();
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t MyRank()
|
|
|
|
{
|
|
|
|
return m_mpi->CurrentNodeRank();
|
|
|
|
}
|
|
|
|
|
|
|
|
void WaitAll()
|
|
|
|
{
|
|
|
|
m_mpi->WaitAll();
|
|
|
|
}
|
|
|
|
|
|
|
|
protected:
|
|
|
|
MPIWrapper* m_mpi;
|
|
|
|
};
|
|
|
|
|
|
|
|
// Brings the members of the base IDistGradAggregator<ElemType> into scope
// inside a derived class body. This is intended for a derived class template
// whose parameter is named ElemType: members of a dependent base are not found
// by unqualified lookup without these using-declarations.
// Usage: expand inside the class body and terminate with a semicolon, because
// the final using-declaration deliberately omits it.
// NOTE: the expansion switches the access specifier to `protected:` and leaves
// it there. NumProc/MyRank (public in the base) are therefore re-declared as
// protected, and any declarations following the expansion are also protected.
#define UsingIDistGradAggregatorMembers               \
protected:                                            \
    using IDistGradAggregator<ElemType>::MyRank;      \
    using IDistGradAggregator<ElemType>::NumProc;     \
    using IDistGradAggregator<ElemType>::m_mpi
|
|
|
|
} } }
|