diff --git a/binding/lua/multiverso.lua b/binding/lua/multiverso.lua index e90c448..869ea36 100644 --- a/binding/lua/multiverso.lua +++ b/binding/lua/multiverso.lua @@ -3,6 +3,8 @@ mv = {} ffi = require('ffi') +util = require('util') + ffi.cdef[[ typedef void* TableHandler; void MV_Init(int* argc, char* argv[]); @@ -72,18 +74,11 @@ end function mv.ArrayTableHandler:get() cdata = ffi.new("float[?]", self._size) libmv.MV_GetArrayTable(self._handler[0], cdata, self._size) - data = {} - for i=1, tonumber(self._size) do - data[i] = cdata[i - 1] - end - return torch.Tensor(data) + return util.cdata2tensor(cdata, tonumber(self._size)) end function mv.ArrayTableHandler:add(data) - cdata = ffi.new("float[?]", self._size) - for i=1, tonumber(self._size) do - cdata[i - 1] = data[i] - end + cdata = util.tensor2cdata(data, tonumber(self._size)) libmv.MV_AddArrayTable(self._handler[0], cdata, self._size) end diff --git a/binding/lua/util.lua b/binding/lua/util.lua new file mode 100644 index 0000000..033ff40 --- /dev/null +++ b/binding/lua/util.lua @@ -0,0 +1,24 @@ +#!/usr/bin/env lua + +util = {} + +ffi = require('ffi') + +function util.cdata2tensor(cdata, size) + data = {} + for i=1, size do + data[i] = cdata[i - 1] + end + return torch.Tensor(data) +end + +function util.tensor2cdata(data, size, cdata_type) + cdata_type = cdata_type or "float[?]" + cdata = ffi.new(cdata_type, size) + for i=1, size do + cdata[i - 1] = data[i] + end + return cdata +end + +return util