91 строка
3.4 KiB
Lua
91 строка
3.4 KiB
Lua
local ffi = require 'ffi'
|
|
local util = require('multiverso.util')
|
|
|
|
local tbh = torch.class('MatrixTableHandler')
|
|
|
|
ffi.cdef[[
|
|
void MV_NewMatrixTable(int num_row, int num_col, TableHandler* out);
|
|
void MV_GetMatrixTableAll(TableHandler handler, float* data, int size);
|
|
void MV_AddMatrixTableAll(TableHandler handler, float* data, int size);
|
|
void MV_AddAsyncMatrixTableAll(TableHandler handler, float* data, int size);
|
|
void MV_GetMatrixTableByRows(TableHandler handler, float* data, int size, int row_ids[], int row_ids_n);
|
|
void MV_AddMatrixTableByRows(TableHandler handler, float* data, int size, int row_ids[], int row_ids_n);
|
|
void MV_AddAsyncMatrixTableByRows(TableHandler handler, float* data, int size, int row_ids[], int row_ids_n);
|
|
]]
|
|
|
|
function tbh:new(num_row, num_col, init_value)
|
|
tbh = {}
|
|
num_row = num_row or 0
|
|
num_col = num_col or 0
|
|
setmetatable(tbh, self)
|
|
self.__index = self
|
|
tbh._handler = ffi.new("TableHandler[1]")
|
|
tbh._num_row = ffi.new("int", num_row)
|
|
tbh._num_col = ffi.new("int", num_col)
|
|
tbh._size = ffi.new("int", num_row * num_col)
|
|
libmv.MV_NewMatrixTable(
|
|
tbh._num_row,
|
|
tbh._num_col,
|
|
tbh._handler
|
|
)
|
|
local init = require 'multiverso.init'
|
|
if init_value ~= nil then
|
|
init_value = init_value:float()
|
|
-- sync add is used because we want to make sure that the initial value
|
|
-- has taken effect when the call returns. No matter whether it is
|
|
-- master worker, we should call add to make sure it works in sync
|
|
-- mode
|
|
|
|
if init.worker_id() == 0 then
|
|
self.add(tbh, init_value, nil, true)
|
|
else
|
|
self.add(tbh, init_value:clone():zero(), nil, true)
|
|
end
|
|
end
|
|
return tbh
|
|
end
|
|
|
|
function tbh:get(row_ids)
|
|
if row_ids == nil then
|
|
cdata = ffi.new("float[?]", self._size)
|
|
libmv.MV_GetMatrixTableAll(self._handler[0], cdata, self._size)
|
|
data = util.cdata2tensor(cdata, tonumber(self._size))
|
|
return torch.reshape(data, tonumber(self._num_row), tonumber(self._num_col))
|
|
else
|
|
cdata = ffi.new("float[?]", #row_ids * self._num_col)
|
|
crow_ids = util.tensor2cdata(row_ids, 'int')
|
|
crow_ids_n = ffi.new("int", #row_ids)
|
|
libmv.MV_GetMatrixTableByRows(self._handler[0], cdata,
|
|
crow_ids_n * self._num_col,
|
|
crow_ids, crow_ids_n)
|
|
data = util.cdata2tensor(cdata, tonumber(#row_ids * self._num_col))
|
|
return torch.reshape(data, #row_ids, tonumber(self._num_col))
|
|
end
|
|
end
|
|
|
|
function tbh:add(data, row_ids, sync)
|
|
sync = sync or false
|
|
cdata = util.tensor2cdata(data)
|
|
if row_ids == nil then
|
|
if sync then
|
|
libmv.MV_AddMatrixTableAll(self._handler[0], cdata, self._size)
|
|
else
|
|
libmv.MV_AddAsyncMatrixTableAll(self._handler[0], cdata, self._size)
|
|
end
|
|
else
|
|
crow_ids = util.tensor2cdata(row_ids, 'int')
|
|
crow_ids_n = ffi.new("int", #row_ids)
|
|
if sync then
|
|
libmv.MV_AddMatrixTableByRows(self._handler[0], cdata,
|
|
crow_ids_n * self._num_col,
|
|
crow_ids, crow_ids_n)
|
|
else
|
|
libmv.MV_AddAsyncMatrixTableByRows(self._handler[0], cdata,
|
|
crow_ids_n * self._num_col,
|
|
crow_ids, crow_ids_n)
|
|
end
|
|
end
|
|
end
|
|
|
|
return tbh
|