我需要在torch7中初始化具有索引相关函数的3D张量,即 func = function(i,j,k) --i, j is the index of an element in the tensor return i*j*k --do operations within func which're dependent of i, jend 然后我初始化一个像这
func = function(i,j,k) --i, j is the index of an element in the tensor return i*j*k --do operations within func which're dependent of i, j end
然后我初始化一个像这样的3D张量A:
for i=1,A:size(1) do for j=1,A:size(2) do for k=1,A:size(3) do A[{i,j,k}] = func(i,j,k) end end end
但是这段代码运行得非常慢,我发现它占总运行时间的92%.有没有更有效的方法来初始化火炬7中的3D张量?
请参阅Tensor:apply
的文档
These functions apply a function to each element of the tensor on
which the method is called (self). These methods are much faster than
using a for loop in Lua.
docs中的示例基于其索引i(在内存中)初始化2D数组.下面是3维的扩展示例,低于N-D张量的扩展示例.在我的机器上使用apply方法要快得多:
require 'torch' A = torch.Tensor(100, 100, 1000) B = torch.Tensor(100, 100, 1000) function func(i,j,k) return i*j*k end t = os.clock() for i=1,A:size(1) do for j=1,A:size(2) do for k=1,A:size(3) do A[{i, j, k}] = i * j * k end end end print("Original time:", os.difftime(os.clock(), t)) t = os.clock() function forindices(A, func) local i = 1 local j = 1 local k = 0 local d3 = A:size(3) local d2 = A:size(2) return function() k = k + 1 if k > d3 then k = 1 j = j + 1 if j > d2 then j = 1 i = i + 1 end end return func(i, j, k) end end B:apply(forindices(A, func)) print("Apply method:", os.difftime(os.clock(), t))
编辑
这适用于任何Tensor对象:
function tabulate(A, f) local idx = {} local ndims = A:dim() local dim = A:size() idx[ndims] = 0 for i=1, (ndims - 1) do idx[i] = 1 end return A:apply(function() for i=ndims, 0, -1 do idx[i] = idx[i] + 1 if idx[i] <= dim[i] then break end idx[i] = 1 end return f(unpack(idx)) end) end -- usage for 3D case. tabulate(A, function(i, j, k) return i * j * k end)