This commit is contained in:
feiga 2016-01-05 11:48:04 +08:00
Родитель 7fa03f9a1b
Коммит 15ee80712f
8 изменённых файлов: 265 добавлений и 227 удалений

Просмотреть файл

@ -1,10 +1,13 @@
#include "alias_table.h"
#include "common.h"
#include "model.h"
#include "util.h"
#include "meta.h"
#include <multiverso/lock.h>
#include <multiverso/log.h>
#include <multiverso/row.h>
#include <multiverso/row_iter.h>
namespace multiverso { namespace lightlda

Просмотреть файл

@ -28,7 +28,6 @@ namespace multiverso { namespace lightlda
for (int32_t i = 0; i < Config::num_local_workers; ++i)
{
Trainer* trainer = new Trainer(alias_table, barrier, &meta);
trainer->Init();
trainers.push_back(trainer);
}

Просмотреть файл

@ -1,165 +1,224 @@
#include "model.h"
#include "meta.h"
#ifdef _MSC_VER
// TODO
#elif
#include <dirent.h>
#include <regex.h>
#endif
#include <algorithm>
#include <fstream>
#include <sstream>
#include <regex.h>
#include <algorithm>
#include "meta.h"
#include "trainer.h"
#include <multiverso/log.h>
#include <multiverso/multiverso.h>
namespace multiverso { namespace lightlda
{
/*!
 * \brief Build a local model around the given meta information
 * \param meta pointer kept in meta_ for later use by the loaders
 */
LocalModel::LocalModel(Meta* meta)
    : word_topic_table_(nullptr)
    , summary_table_(nullptr)
    , meta_(meta)
{
    // Both table pointers start out null; CreateTable() fills them in.
    CreateTable();
}
/*!
 * \brief Construct the local model and allocate its tables
 * \param meta pointer stored in meta_
 */
LocalModel::LocalModel(Meta* meta)
    : word_topic_table_(nullptr),
      summary_table_(nullptr),
      meta_(meta)
{
    // The table members are null until this call replaces them.
    CreateTable();
}
void LocalModel::Init()
{
LoadTable();
}
void LocalModel::Init()
{
LoadTable();
}
void LocalModel::CreateTable()
{
int32_t num_vocabs = Config::num_vocabs;
int32_t num_topics = Config::num_topics;
multiverso::Format dense_format = multiverso::Format::Dense;
multiverso::Format sparse_format = multiverso::Format::Sparse;
Type int_type = Type::Int;
Type longlong_type = Type::LongLong;
void LocalModel::CreateTable()
{
int32_t num_vocabs = Config::num_vocabs;
int32_t num_topics = Config::num_topics;
multiverso::Format dense_format = multiverso::Format::Dense;
multiverso::Format sparse_format = multiverso::Format::Sparse;
Type int_type = Type::Int;
Type longlong_type = Type::LongLong;
word_topic_table_.reset(new Table(kWordTopicTable, num_vocabs, num_topics,
word_topic_table_.reset(new Table(kWordTopicTable, num_vocabs, num_topics,
int_type, dense_format));
summary_table_.reset(new Table(kSummaryRow, 1, num_topics,
summary_table_.reset(new Table(kSummaryRow, 1, num_topics,
longlong_type, dense_format));
}
/*!
 * \brief Scan Config::input_dir and load every model file found there
 *
 * A directory entry whose name matches
 * "server_[[:digit:]]+_table_<kWordTopicTable>.model" is handed to
 * LoadWordTopicTable; otherwise, an entry matching the same pattern built
 * with kSummaryRow is handed to LoadSummaryTable. Entries matching neither
 * pattern are ignored. A directory that cannot be opened, or a pattern
 * that fails to compile, is reported through Log::Fatal.
 */
void LocalModel::LoadTable()
{
    Log::Info("loading model\n");

    // set regex for model files
    const std::string prefix = "server_[[:digit:]]+_table_";
    const std::string suffix = ".model";
    std::ostringstream wordtopic_regstr;
    wordtopic_regstr << prefix << kWordTopicTable << suffix;
    std::ostringstream summary_regstr;
    summary_regstr << prefix << kSummaryRow << suffix;
    const std::string wordtopic_pattern = wordtopic_regstr.str();
    const std::string summary_pattern = summary_regstr.str();

    // regcomp leaves the regex_t unusable on failure, so neither regexec
    // nor regfree may be called on it; bail out before that can happen.
    regex_t model_wordtopic_regex;
    regex_t model_summary_regex;
    if (regcomp(&model_wordtopic_regex, wordtopic_pattern.c_str(),
        REG_EXTENDED) != 0)
    {
        Log::Fatal("failed to compile model file regex : %s\n",
            wordtopic_pattern.c_str());
        return; // only reached if Log::Fatal returns control
    }
    if (regcomp(&model_summary_regex, summary_pattern.c_str(),
        REG_EXTENDED) != 0)
    {
        regfree(&model_wordtopic_regex);
        Log::Fatal("failed to compile model file regex : %s\n",
            summary_pattern.c_str());
        return; // only reached if Log::Fatal returns control
    }

    // look for model files & load them
    DIR* dir = opendir(Config::input_dir.c_str());
    if (dir != nullptr)
    {
        struct dirent* ent;
        while ((ent = readdir(dir)) != nullptr)
        {
            // regexec returns 0 when the name matches the pattern
            if (!regexec(&model_wordtopic_regex, ent->d_name, 0, nullptr, 0))
            {
                Log::Info("loading word topic table[%s]\n", ent->d_name);
                LoadWordTopicTable(Config::input_dir + "/" + ent->d_name);
            }
            else if (!regexec(&model_summary_regex, ent->d_name, 0, nullptr, 0))
            {
                Log::Info("loading summary table[%s]\n", ent->d_name);
                LoadSummaryTable(Config::input_dir + "/" + ent->d_name);
            }
        }
        closedir(dir);
    }
    else
    {
        Log::Fatal("model dir does not exist : %s\n", Config::input_dir.c_str());
    }

    regfree(&model_wordtopic_regex);
    regfree(&model_summary_regex);
}
void LocalModel::LoadWordTopicTable(const std::string& model_fname)
{
multiverso::Format dense_format = multiverso::Format::Dense;
multiverso::Format sparse_format = multiverso::Format::Sparse;
std::ifstream model_file(model_fname, std::ios::in);
std::string line;
while(getline(model_file, line))
void LocalModel::LoadTable()
{
std::stringstream ss(line);
std::string word;
std::string fea;
std::vector<std::string> feas;
int32_t word_id, topic_id, freq;
//assign word id
ss >> word;
word_id = std::stoi(word);
if(meta_->tf(word_id) > 0)
{
//set row
if (meta_->tf(word_id) * kLoadFactor > Config::num_topics)
#ifdef _MSC_VER
Log::Fatal("Not implementent yet on Windows\n");
#elif
Log::Info("loading model\n");
//set regex for model files
regex_t model_wordtopic_regex;
regex_t model_summary_regex;
std::string prefix = "server_[[:digit:]]+_table_";
std::string suffix = ".model";
std::ostringstream wordtopic_regstr;
wordtopic_regstr << prefix << kWordTopicTable << suffix;
std::ostringstream summary_regstr;
summary_regstr << prefix << kSummaryRow << suffix;
regcomp(&model_wordtopic_regex, wordtopic_regstr.str().c_str(), REG_EXTENDED);
regcomp(&model_summary_regex, summary_regstr.str().c_str(), REG_EXTENDED);
//look for model files & load them
DIR *dir;
struct dirent *ent;
if ((dir = opendir(Config::input_dir.c_str())) != NULL)
{
word_topic_table_->SetRow(word_id, dense_format, Config::num_topics);
while ((ent = readdir(dir)) != NULL)
{
if (!regexec(&model_wordtopic_regex, ent->d_name, 0, NULL, 0))
{
Log::Info("loading word topic table[%s]\n", ent->d_name);
LoadWordTopicTable(Config::input_dir + "/" + ent->d_name);
}
else if (!regexec(&model_summary_regex, ent->d_name, 0, NULL, 0))
{
Log::Info("loading summary table[%s]\n", ent->d_name);
LoadSummaryTable(Config::input_dir + "/" + ent->d_name);
}
}
closedir(dir);
}
else
{
word_topic_table_->SetRow(word_id, sparse_format, meta_->tf(word_id) * kLoadFactor);
Log::Fatal("model dir does not exist : %s\n", Config::input_dir.c_str());
}
//get row
Row<int32_t> * row = static_cast<Row<int32_t>*>
(word_topic_table_->GetRow(word_id));
//add features to row
while (ss >> fea)
{
size_t pos = fea.find_last_of(':');
if(pos != std::string::npos)
{
topic_id = std::stoi(fea.substr(0, pos));
freq = std::stoi(fea.substr(pos + 1));
row->Add(topic_id, freq);
}
else
{
Log::Fatal("bad format of model: %s\n", line.c_str());
}
}
}
regfree(&model_wordtopic_regex);
regfree(&model_summary_regex);
#endif
}
model_file.close();
}
void LocalModel::LoadSummaryTable(const std::string& model_fname)
{
Row<int64_t> * row = static_cast<Row<int64_t>*>
(summary_table_->GetRow(0));
std::ifstream model_file(model_fname, std::ios::in);
std::string line;
if(getline(model_file, line))
void LocalModel::LoadWordTopicTable(const std::string& model_fname)
{
std::stringstream ss(line);
std::string fea;
std::vector<std::string> feas;
int32_t topic_id, freq;
//skip word id
ss >> fea;
//add features to row
while (ss >> fea)
{
size_t pos = fea.find_last_of(':');
if(pos != std::string::npos)
multiverso::Format dense_format = multiverso::Format::Dense;
multiverso::Format sparse_format = multiverso::Format::Sparse;
std::ifstream model_file(model_fname, std::ios::in);
std::string line;
while (getline(model_file, line))
{
topic_id = std::stoi(fea.substr(0, pos));
freq = std::stoi(fea.substr(pos + 1));
row->Add(topic_id, freq);
std::stringstream ss(line);
std::string word;
std::string fea;
std::vector<std::string> feas;
int32_t word_id, topic_id, freq;
//assign word id
ss >> word;
word_id = std::stoi(word);
if (meta_->tf(word_id) > 0)
{
//set row
if (meta_->tf(word_id) * kLoadFactor > Config::num_topics)
{
word_topic_table_->SetRow(word_id, dense_format,
Config::num_topics);
}
else
{
word_topic_table_->SetRow(word_id, sparse_format,
meta_->tf(word_id) * kLoadFactor);
}
//get row
Row<int32_t> * row = static_cast<Row<int32_t>*>
(word_topic_table_->GetRow(word_id));
//add features to row
while (ss >> fea)
{
size_t pos = fea.find_last_of(':');
if (pos != std::string::npos)
{
topic_id = std::stoi(fea.substr(0, pos));
freq = std::stoi(fea.substr(pos + 1));
row->Add(topic_id, freq);
}
else
{
Log::Fatal("bad format of model: %s\n", line.c_str());
}
}
}
}
else
{
Log::Fatal("bad format of model: %s\n", line.c_str());
}
}
model_file.close();
}
model_file.close();
}
/*!
 * \brief Fill the summary row from a model file
 * \param model_fname path of the model file to read
 *
 * Only the first line of the file is read. Its first whitespace-separated
 * token is discarded (the word id slot); every following token must have
 * the form "<topic_id>:<count>" and is added to row 0 of the summary
 * table. A token without ':' is reported through Log::Fatal. Nothing is
 * added when the file cannot be read.
 */
void LocalModel::LoadSummaryTable(const std::string& model_fname)
{
    Row<int64_t>* row = static_cast<Row<int64_t>*>
        (summary_table_->GetRow(0));
    std::ifstream model_file(model_fname, std::ios::in);
    std::string line;
    if (getline(model_file, line))
    {
        std::stringstream ss(line);
        std::string fea;
        // skip word id
        ss >> fea;
        // add features to row
        while (ss >> fea)
        {
            const size_t pos = fea.find_last_of(':');
            if (pos != std::string::npos)
            {
                const int32_t topic_id = std::stoi(fea.substr(0, pos));
                // The row holds 64-bit counts, so parse at full width:
                // std::stoi would throw std::out_of_range once a count
                // exceeds the int range.
                const int64_t freq = std::stoll(fea.substr(pos + 1));
                row->Add(topic_id, freq);
            }
            else
            {
                Log::Fatal("bad format of model: %s\n", line.c_str());
            }
        }
    }
    model_file.close();
}
/*!
 * \brief Not supported by the local model
 *
 * The arguments are ignored; the call only reports through Log::Fatal.
 */
void LocalModel::AddWordTopicRow(integer_t word_id,
                                 integer_t topic_id,
                                 int32_t delta)
{
    Log::Fatal("Not implemented yet\n");
}
/*!
 * \brief Not supported by the local model
 *
 * The arguments are ignored; the call only reports through Log::Fatal.
 */
void LocalModel::AddSummaryRow(
    integer_t topic_id, int64_t delta)
{
    Log::Fatal("Not implemented yet\n");
}
/*!
 * \brief Access the word-topic row of a word in the local table
 * \param word row id looked up in word_topic_table_
 * \return reference to that row, viewed as Row<int32_t>
 */
Row<int32_t>& LocalModel::GetWordTopicRow(integer_t word)
{
    auto* typed_row =
        static_cast<Row<int32_t>*>(word_topic_table_->GetRow(word));
    return *typed_row;
}
/*!
 * \brief Access the summary row of the local model
 * \return reference to row 0 of summary_table_, viewed as Row<int64_t>
 */
Row<int64_t>& LocalModel::GetSummaryRow()
{
    auto* typed_row =
        static_cast<Row<int64_t>*>(summary_table_->GetRow(0));
    return *typed_row;
}
/*!
 * \brief Fetch the word-topic row of a word through the trainer
 * \param word_id row id requested from table kWordTopicTable
 * \return the row reference handed back by Trainer::GetRow
 */
Row<int32_t>& PSModel::GetWordTopicRow(integer_t word_id)
{
    Row<int32_t>& word_topic_row =
        trainer_->GetRow<int32_t>(kWordTopicTable, word_id);
    return word_topic_row;
}
/*!
 * \brief Fetch the summary row through the trainer
 * \return the reference handed back by Trainer::GetRow for row 0 of
 *         table kSummaryRow
 */
Row<int64_t>& PSModel::GetSummaryRow()
{
    Row<int64_t>& summary_row = trainer_->GetRow<int64_t>(kSummaryRow, 0);
    return summary_row;
}
/*!
 * \brief Forward a word-topic update to the trainer
 * \param word_id row id in table kWordTopicTable
 * \param topic_id column id passed on to Trainer::Add
 * \param delta value passed on to Trainer::Add
 */
void PSModel::AddWordTopicRow(integer_t word_id,
                              integer_t topic_id,
                              int32_t delta)
{
    trainer_->Add<int32_t>(
        kWordTopicTable, word_id, topic_id, delta);
}
/*!
 * \brief Forward a summary-row update to the trainer
 * \param topic_id column id passed on to Trainer::Add
 * \param delta value passed on to Trainer::Add
 */
void PSModel::AddSummaryRow(
    integer_t topic_id, int64_t delta)
{
    // The summary table has a single row, hence the fixed row id 0.
    trainer_->Add<int64_t>(
        kSummaryRow, 0, topic_id, delta);
}
} // namespace lightlda
} // namespace multiverso

Просмотреть файл

@ -3,105 +3,82 @@
* \brief define local model reader
*/
#ifndef LIGHTLDA_MODEL_H_
#define LIGHTLDA_MODEL_H_
#include "trainer.h"
#include "common.h"
#include <multiverso/multiverso.h>
#include <memory>
#include <string>
namespace multiverso { namespace lightlda
#include "common.h"
#include <multiverso/meta.h>
namespace multiverso
{
class Meta;
class Trainer;
template<typename T> class Row;
class Table;
/*! \brief interface for access to model */
class ModelBase
{
public:
    virtual ~ModelBase() = default;

    /*! \brief Returns the word-topic row of word_id */
    virtual Row<int32_t>& GetWordTopicRow(integer_t word_id) = 0;

    /*! \brief Returns the summary row */
    virtual Row<int64_t>& GetSummaryRow() = 0;

    /*! \brief Applies delta for (word_id, topic_id) to the word-topic table */
    virtual void AddWordTopicRow(
        integer_t word_id, integer_t topic_id, int32_t delta) = 0;

    /*! \brief Applies delta for topic_id to the summary row */
    virtual void AddSummaryRow(
        integer_t topic_id, int64_t delta) = 0;
};
namespace lightlda
{
class Meta;
class Trainer;
/*! \brief model based on local buffer */
class LocalModel: public ModelBase
{
public:
LocalModel(Meta * meta);
void Init();
/*! \brief interface for access to model */
class ModelBase
{
public:
virtual ~ModelBase() {}
virtual Row<int32_t>& GetWordTopicRow(integer_t word_id) = 0;
virtual Row<int64_t>& GetSummaryRow() = 0;
virtual void AddWordTopicRow(integer_t word_id, integer_t topic_id,
int32_t delta) = 0;
virtual void AddSummaryRow(integer_t topic_id, int64_t delta) = 0;
};
public:
Row<int32_t>& GetWordTopicRow(integer_t word_id) override;
Row<int64_t>& GetSummaryRow() override;
void AddWordTopicRow(integer_t word_id, integer_t topic_id, int32_t delta) override;
void AddSummaryRow(integer_t topic_id, int64_t delta) override;
/*! \brief model based on local buffer */
class LocalModel : public ModelBase
{
public:
explicit LocalModel(Meta * meta);
void Init();
private:
void CreateTable();
void LoadTable();
void LoadWordTopicTable(const std::string& model_fname);
void LoadSummaryTable(const std::string& model_fname);
Row<int32_t>& GetWordTopicRow(integer_t word_id) override;
Row<int64_t>& GetSummaryRow() override;
void AddWordTopicRow(integer_t word_id, integer_t topic_id,
int32_t delta) override;
void AddSummaryRow(integer_t topic_id, int64_t delta) override;
private:
std::unique_ptr<Table> word_topic_table_;
std::unique_ptr<Table> summary_table_;
Meta* meta_;
};
private:
void CreateTable();
void LoadTable();
void LoadWordTopicTable(const std::string& model_fname);
void LoadSummaryTable(const std::string& model_fname);
/*! \brief model based on parameter server */
class PSModel: public ModelBase
{
public:
PSModel(Trainer* trainer): trainer_(trainer) {}
public:
Row<int32_t>& GetWordTopicRow(integer_t word_id) override;
Row<int64_t>& GetSummaryRow() override;
void AddWordTopicRow(integer_t word_id, integer_t topic_id, int32_t delta) override;
void AddSummaryRow(integer_t topic_id, int64_t delta) override;
private:
Trainer* trainer_;
};
std::unique_ptr<Table> word_topic_table_;
std::unique_ptr<Table> summary_table_;
Meta* meta_;
// -- inline functions definition area --------------------------------- //
inline Row<int32_t>& LocalModel::GetWordTopicRow(integer_t word_id)
{
return *(static_cast<Row<int32_t>*>(word_topic_table_->GetRow(word_id)));
}
LocalModel(const LocalModel&) = delete;
void operator=(const LocalModel&) = delete;
};
inline Row<int64_t>& LocalModel::GetSummaryRow()
{
return *(static_cast<Row<int64_t>*>(summary_table_->GetRow(0)));
}
/*! \brief model based on parameter server */
class PSModel : public ModelBase
{
public:
explicit PSModel(Trainer* trainer) : trainer_(trainer) {}
inline void LocalModel::AddWordTopicRow(integer_t word_id, integer_t topic_id, int32_t delta) {}
Row<int32_t>& GetWordTopicRow(integer_t word_id) override;
Row<int64_t>& GetSummaryRow() override;
void AddWordTopicRow(integer_t word_id, integer_t topic_id,
int32_t delta) override;
void AddSummaryRow(integer_t topic_id, int64_t delta) override;
inline void LocalModel::AddSummaryRow(integer_t topic_id, int64_t delta) {}
private:
Trainer* trainer_;
inline Row<int32_t>& PSModel::GetWordTopicRow(integer_t word_id)
{
return trainer_->GetRow<int32_t>(kWordTopicTable, word_id);
}
inline Row<int64_t>& PSModel::GetSummaryRow()
{
return trainer_->GetRow<int64_t>(kSummaryRow, 0);
}
inline void PSModel::AddWordTopicRow(integer_t word_id, integer_t topic_id, int32_t delta)
{
trainer_->Add<int32_t>(kWordTopicTable, word_id, topic_id, delta);
}
inline void PSModel::AddSummaryRow(integer_t topic_id, int64_t delta)
{
trainer_->Add<int64_t>(kSummaryRow, 0, topic_id, delta);
}
// -- inline functions definition area --------------------------------- //
PSModel(const PSModel&) = delete;
void operator=(const PSModel&) = delete;
};
} // namespace lightlda
} // namespace multiverso

Просмотреть файл

@ -5,6 +5,7 @@
#include "document.h"
#include "model.h"
#include <multiverso/log.h>
#include <multiverso/row.h>
namespace multiverso { namespace lightlda

Просмотреть файл

@ -24,17 +24,13 @@ namespace multiverso { namespace lightlda
model_(nullptr)
{
sampler_ = new LightDocSampler();
}
/*!
 * \brief Create the parameter-server model used by this trainer
 *
 * Stores a newly allocated PSModel, constructed with this trainer, in
 * model_.
 */
void Trainer::Init()
{
    model_ = new PSModel(this);
}
/*! \brief Release the sampler and the model allocated by this trainer */
Trainer::~Trainer()
{
    delete sampler_;
    // Deleting a null pointer is a no-op, so no guard is needed; the
    // pointer must be deleted exactly once.
    delete model_;
}
void Trainer::TrainIteration(DataBlockBase* data_block)

Просмотреть файл

@ -25,7 +25,6 @@ namespace multiverso { namespace lightlda
public:
Trainer(AliasTable* alias, Barrier* barrier, Meta* meta);
~Trainer();
void Init();
/*!
* \brief Defines the training method for a data_block in one iteration
* \param data_block pointer to data block base

Просмотреть файл

@ -77,7 +77,8 @@
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<IncludePath>$(VC_IncludePath);$(WindowsSDK_IncludePath);</IncludePath>
<IncludePath>$(SolutionDir)/../../multiverso/include;$(VC_IncludePath);$(WindowsSDK_IncludePath);</IncludePath>
<LibraryPath>$(SolutionDir)/../../multiverso/windows/x64/Release/;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);</LibraryPath>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
@ -137,6 +138,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>multiverso.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
@ -148,6 +150,7 @@
<ClCompile Include="..\..\src\eval.cpp" />
<ClCompile Include="..\..\src\lightlda.cpp" />
<ClCompile Include="..\..\src\meta.cpp" />
<ClCompile Include="..\..\src\model.cpp" />
<ClCompile Include="..\..\src\sampler.cpp" />
<ClCompile Include="..\..\src\trainer.cpp" />
</ItemGroup>
@ -159,6 +162,7 @@
<ClInclude Include="..\..\src\document.h" />
<ClInclude Include="..\..\src\eval.h" />
<ClInclude Include="..\..\src\meta.h" />
<ClInclude Include="..\..\src\model.h" />
<ClInclude Include="..\..\src\sampler.h" />
<ClInclude Include="..\..\src\trainer.h" />
<ClInclude Include="..\..\src\util.h" />