Make AlignedMemory the means of passing in memory (#86)

This commit is contained in:
Kenneth Heafield 2021-04-06 13:23:55 +01:00 коммит произвёл GitHub
Родитель f654ab0f71
Коммит 27a3a3253f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 11 добавлений и 38 удалений

Просмотреть файл

@ -19,11 +19,8 @@ int main(int argc, char **argv) {
auto options = configParser.parseOptions(argc, argv, true);
std::string config = options->asYamlString();
// Prepare model byte array
marian::bergamot::AlignedMemory modelBytes = marian::bergamot::getModelMemoryFromConfig(options);
// Route the config string to construct marian model through TranslationModel
TranslationModel model(config, modelBytes.begin());
TranslationModel model(config, marian::bergamot::getModelMemoryFromConfig(options));
TranslationRequest translationRequest;
std::vector<std::string> texts;

Просмотреть файл

@ -17,6 +17,7 @@
// All local project includes
#include "TranslationRequest.h"
#include "TranslationResult.h"
#include "translator/definitions.h"
#include "translator/service.h"
/* A Translation model that translates a plain (without any markups and emojis)
@ -34,7 +35,8 @@ public:
* the bytes of a model.bin.
*/
TranslationModel(const std::string &config,
const void *model_memory = nullptr);
marian::bergamot::AlignedMemory modelMemory = marian::bergamot::AlignedMemory(),
marian::bergamot::AlignedMemory shortlistMemory = marian::bergamot::AlignedMemory());
~TranslationModel();

Просмотреть файл

@ -12,8 +12,9 @@
#include "translator/service.h"
TranslationModel::TranslationModel(const std::string &config,
const void *model_memory)
: service_(config, model_memory) {}
marian::bergamot::AlignedMemory model_memory,
marian::bergamot::AlignedMemory lexical_memory)
: service_(config, std::move(model_memory), std::move(lexical_memory)) {}
TranslationModel::~TranslationModel() {}

Просмотреть файл

@ -18,33 +18,6 @@
namespace marian {
namespace bergamot {
// Hack code to construct AlignedMemory* from void*
inline AlignedMemory hackModel(const void* modelMemory) {
if(modelMemory != nullptr){
// Here is a hack to make TranslationModel works
size_t modelMemorySize = 73837568; // Hack: model memory size should be changed to actual model size
AlignedMemory alignedMemory(modelMemorySize);
memcpy(alignedMemory.begin(), modelMemory, modelMemorySize);
return alignedMemory;
} else {
return AlignedMemory();
}
}
inline AlignedMemory hackShortLis(const void* shortlistMemory) {
if(shortlistMemory!= nullptr) {
// Hacks to obtain shortlist memory size as this will be checked during construction
size_t shortlistMemorySize = sizeof(uint64_t) * (6 + *((uint64_t*)shortlistMemory+4))
+ sizeof(uint32_t) * *((uint64_t*)shortlistMemory+5);
// Here is a hack to make TranslationModel works
AlignedMemory alignedMemory(shortlistMemorySize);
memcpy(alignedMemory.begin(), shortlistMemory, shortlistMemorySize);
return alignedMemory;
}else {
return AlignedMemory();
}
}
/// Service exposes methods to translate an incoming blob of text to the
/// Consumer of bergamot API.
///
@ -72,15 +45,15 @@ public:
explicit Service(Ptr<Options> options) : Service(options, AlignedMemory(), AlignedMemory()){}
/// Construct Service from a string configuration.
/// Construct Service from a string configuration.
/// @param [in] config string parsable as YAML expected to adhere with marian
/// config
/// @param [in] model_memory byte array (aligned to 256!!!) that contains the
/// bytes of a model.bin. Optional, defaults to nullptr when not used
/// bytes of a model.bin. Optional.
/// @param [in] shortlistMemory byte array of shortlist (aligned to 64)
explicit Service(const std::string &config,
const void* modelMemory = nullptr, const void* shortlistMemory = nullptr)
: Service(parseOptions(config), hackModel(modelMemory), hackShortLis(shortlistMemory)) {}
AlignedMemory modelMemory = AlignedMemory(), AlignedMemory shortlistMemory = AlignedMemory())
: Service(parseOptions(config), std::move(modelMemory), std::move(shortlistMemory)) {}
/// Explicit destructor to clean up after any threads initialized in
/// asynchronous operation mode.