[torch/lua] Split ArrayTableHandler and MatrixTableHandler.
This commit is contained in:
Родитель
41f905bf35
Коммит
4934722136
|
@ -0,0 +1,37 @@
|
||||||
|
local ffi = require 'ffi'
|
||||||
|
local util = require('multiverso.util')
|
||||||
|
|
||||||
|
local tbh = torch.class('ArrayTableHanlder')
|
||||||
|
|
||||||
|
ffi.cdef[[
|
||||||
|
void MV_NewArrayTable(int size, TableHandler* out);
|
||||||
|
void MV_GetArrayTable(TableHandler handler, float* data, int size);
|
||||||
|
void MV_AddArrayTable(TableHandler handler, float* data, int size);
|
||||||
|
]]
|
||||||
|
|
||||||
|
function tbh:new(size)
|
||||||
|
tbh = {}
|
||||||
|
size = size or 0
|
||||||
|
setmetatable(tbh, self)
|
||||||
|
self.__index = self
|
||||||
|
tbh._handler = ffi.new("TableHandler[1]")
|
||||||
|
tbh._size = ffi.new("int", size)
|
||||||
|
libmv.MV_NewArrayTable(
|
||||||
|
tbh._size,
|
||||||
|
tbh._handler
|
||||||
|
)
|
||||||
|
return tbh
|
||||||
|
end
|
||||||
|
|
||||||
|
function tbh:get()
|
||||||
|
cdata = ffi.new("float[?]", self._size)
|
||||||
|
libmv.MV_GetArrayTable(self._handler[0], cdata, self._size)
|
||||||
|
return util.cdata2tensor(cdata, tonumber(self._size))
|
||||||
|
end
|
||||||
|
|
||||||
|
function tbh:add(data)
|
||||||
|
cdata = util.tensor2cdata(data)
|
||||||
|
libmv.MV_AddArrayTable(self._handler[0], cdata, self._size)
|
||||||
|
end
|
||||||
|
|
||||||
|
return tbh
|
|
@ -0,0 +1,63 @@
|
||||||
|
local ffi = require 'ffi'
|
||||||
|
local util = require('multiverso.util')
|
||||||
|
|
||||||
|
local tbh = torch.class('MatrixTableHandler')
|
||||||
|
|
||||||
|
ffi.cdef[[
|
||||||
|
void MV_NewMatrixTable(int num_row, int num_col, TableHandler* out);
|
||||||
|
void MV_GetMatrixTableAll(TableHandler handler, float* data, int size);
|
||||||
|
void MV_AddMatrixTableAll(TableHandler handler, float* data, int size);
|
||||||
|
void MV_GetMatrixTableByRows(TableHandler handler, float* data, int size, int row_ids[], int row_ids_n);
|
||||||
|
void MV_AddMatrixTableByRows(TableHandler handler, float* data, int size, int row_ids[], int row_ids_n);
|
||||||
|
]]
|
||||||
|
|
||||||
|
function tbh:new(num_row, num_col)
|
||||||
|
tbh = {}
|
||||||
|
num_row = num_row or 0
|
||||||
|
num_col = num_col or 0
|
||||||
|
setmetatable(tbh, self)
|
||||||
|
self.__index = self
|
||||||
|
tbh._handler = ffi.new("TableHandler[1]")
|
||||||
|
tbh._num_row = ffi.new("int", num_row)
|
||||||
|
tbh._num_col = ffi.new("int", num_col)
|
||||||
|
tbh._size = ffi.new("int", num_row * num_col)
|
||||||
|
libmv.MV_NewMatrixTable(
|
||||||
|
tbh._num_row,
|
||||||
|
tbh._num_col,
|
||||||
|
tbh._handler
|
||||||
|
)
|
||||||
|
return tbh
|
||||||
|
end
|
||||||
|
|
||||||
|
function tbh:get(row_ids)
|
||||||
|
if row_ids == nil then
|
||||||
|
cdata = ffi.new("float[?]", self._size)
|
||||||
|
libmv.MV_GetMatrixTableAll(self._handler[0], cdata, self._size)
|
||||||
|
data = util.cdata2tensor(cdata, tonumber(self._size))
|
||||||
|
return torch.reshape(data, tonumber(self._num_row), tonumber(self._num_col))
|
||||||
|
else
|
||||||
|
cdata = ffi.new("float[?]", #row_ids * self._num_col)
|
||||||
|
crow_ids = util.tensor2cdata(row_ids, 'int')
|
||||||
|
crow_ids_n = ffi.new("int", #row_ids)
|
||||||
|
libmv.MV_GetMatrixTableByRows(self._handler[0], cdata,
|
||||||
|
crow_ids_n * self._num_col,
|
||||||
|
crow_ids, crow_ids_n)
|
||||||
|
data = util.cdata2tensor(cdata, tonumber(#row_ids * self._num_col))
|
||||||
|
return torch.reshape(data, #row_ids, tonumber(self._num_col))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function tbh:add(data, row_ids)
|
||||||
|
cdata = util.tensor2cdata(data)
|
||||||
|
if row_ids == nil then
|
||||||
|
libmv.MV_AddMatrixTableAll(self._handler[0], cdata, self._size)
|
||||||
|
else
|
||||||
|
crow_ids = util.tensor2cdata(row_ids, 'int')
|
||||||
|
crow_ids_n = ffi.new("int", #row_ids)
|
||||||
|
libmv.MV_AddMatrixTableByRows(self._handler[0], cdata,
|
||||||
|
crow_ids_n * self._num_col,
|
||||||
|
crow_ids, crow_ids_n)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return tbh
|
|
@ -13,23 +13,12 @@ ffi.cdef[[
|
||||||
int MV_NumWorkers();
|
int MV_NumWorkers();
|
||||||
int MV_WorkerId();
|
int MV_WorkerId();
|
||||||
int MV_ServerId();
|
int MV_ServerId();
|
||||||
|
|
||||||
// Array Table
|
|
||||||
void MV_NewArrayTable(int size, TableHandler* out);
|
|
||||||
void MV_GetArrayTable(TableHandler handler, float* data, int size);
|
|
||||||
void MV_AddArrayTable(TableHandler handler, float* data, int size);
|
|
||||||
|
|
||||||
// Matrix Table
|
|
||||||
void MV_NewMatrixTable(int num_row, int num_col, TableHandler* out);
|
|
||||||
void MV_GetMatrixTableAll(TableHandler handler, float* data, int size);
|
|
||||||
void MV_AddMatrixTableAll(TableHandler handler, float* data, int size);
|
|
||||||
void MV_GetMatrixTableByRows(TableHandler handler, float* data, int size, int row_ids[], int row_ids_n);
|
|
||||||
void MV_AddMatrixTableByRows(TableHandler handler, float* data, int size, int row_ids[], int row_ids_n);
|
|
||||||
]]
|
]]
|
||||||
|
|
||||||
package.cpath = "../../build/src/?.so;" .. package.cpath
|
libmv = ffi.load('libmultiverso', 'true')
|
||||||
libmv_path = package.searchpath('libmultiverso', package.cpath, '')
|
|
||||||
libmv = ffi.load(libmv_path)
|
mv.ArrayTableHandler = require('multiverso.ArrayTableHandler')
|
||||||
|
mv.MatrixTableHandler = require('multiverso.MatrixTableHandler')
|
||||||
|
|
||||||
function mv.init(args)
|
function mv.init(args)
|
||||||
args = args or {}
|
args = args or {}
|
||||||
|
@ -62,82 +51,4 @@ function mv.server_id()
|
||||||
return libmv.MV_ServerId()
|
return libmv.MV_ServerId()
|
||||||
end
|
end
|
||||||
|
|
||||||
mv.ArrayTableHandler = {}
|
|
||||||
|
|
||||||
function mv.ArrayTableHandler:new(size)
|
|
||||||
tbh = {}
|
|
||||||
size = size or 0
|
|
||||||
setmetatable(tbh, self)
|
|
||||||
self.__index = self
|
|
||||||
tbh._handler = ffi.new("TableHandler[1]")
|
|
||||||
tbh._size = ffi.new("int", size)
|
|
||||||
libmv.MV_NewArrayTable(
|
|
||||||
tbh._size,
|
|
||||||
tbh._handler
|
|
||||||
)
|
|
||||||
return tbh
|
|
||||||
end
|
|
||||||
|
|
||||||
function mv.ArrayTableHandler:get()
|
|
||||||
cdata = ffi.new("float[?]", self._size)
|
|
||||||
libmv.MV_GetArrayTable(self._handler[0], cdata, self._size)
|
|
||||||
return util.cdata2tensor(cdata, tonumber(self._size))
|
|
||||||
end
|
|
||||||
|
|
||||||
function mv.ArrayTableHandler:add(data)
|
|
||||||
cdata = util.tensor2cdata(data)
|
|
||||||
libmv.MV_AddArrayTable(self._handler[0], cdata, self._size)
|
|
||||||
end
|
|
||||||
|
|
||||||
mv.MatrixTableHandler = {}
|
|
||||||
|
|
||||||
function mv.MatrixTableHandler:new(num_row, num_col)
|
|
||||||
tbh = {}
|
|
||||||
num_row = num_row or 0
|
|
||||||
num_col = num_col or 0
|
|
||||||
setmetatable(tbh, self)
|
|
||||||
self.__index = self
|
|
||||||
tbh._handler = ffi.new("TableHandler[1]")
|
|
||||||
tbh._num_row = ffi.new("int", num_row)
|
|
||||||
tbh._num_col = ffi.new("int", num_col)
|
|
||||||
tbh._size = ffi.new("int", num_row * num_col)
|
|
||||||
libmv.MV_NewMatrixTable(
|
|
||||||
tbh._num_row,
|
|
||||||
tbh._num_col,
|
|
||||||
tbh._handler
|
|
||||||
)
|
|
||||||
return tbh
|
|
||||||
end
|
|
||||||
|
|
||||||
function mv.MatrixTableHandler:get(row_ids)
|
|
||||||
if row_ids == nil then
|
|
||||||
cdata = ffi.new("float[?]", self._size)
|
|
||||||
libmv.MV_GetMatrixTableAll(self._handler[0], cdata, self._size)
|
|
||||||
data = util.cdata2tensor(cdata, tonumber(self._size))
|
|
||||||
return torch.reshape(data, tonumber(self._num_row), tonumber(self._num_col))
|
|
||||||
else
|
|
||||||
cdata = ffi.new("float[?]", #row_ids * self._num_col)
|
|
||||||
crow_ids = util.tensor2cdata(row_ids, 'int')
|
|
||||||
crow_ids_n = ffi.new("int", #row_ids)
|
|
||||||
libmv.MV_GetMatrixTableByRows(self._handler[0], cdata,
|
|
||||||
crow_ids_n * self._num_col,
|
|
||||||
crow_ids, crow_ids_n)
|
|
||||||
data = util.cdata2tensor(cdata, tonumber(#row_ids * self._num_col))
|
|
||||||
return torch.reshape(data, #row_ids, tonumber(self._num_col))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function mv.MatrixTableHandler:add(data, row_ids)
|
|
||||||
cdata = util.tensor2cdata(data)
|
|
||||||
if row_ids == nil then
|
|
||||||
libmv.MV_AddMatrixTableAll(self._handler[0], cdata, self._size)
|
|
||||||
else
|
|
||||||
crow_ids = util.tensor2cdata(row_ids, 'int')
|
|
||||||
crow_ids_n = ffi.new("int", #row_ids)
|
|
||||||
libmv.MV_AddMatrixTableByRows(self._handler[0], cdata,
|
|
||||||
crow_ids_n * self._num_col,
|
|
||||||
crow_ids, crow_ids_n)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
return mv
|
return mv
|
||||||
|
|
Загрузка…
Ссылка в новой задаче