在其中一个函数中有一个循环,我们乘以3×1向量(我们称之为x) – 一个3×3矩阵(让我们称之为A) – 和x的转置,产生一个标量.代码具有整个逐元素乘法和加法的集合,并且非常麻烦:
val = x(1)*A(1,1)*x(1) + x(1)*A(1,2)*x(2) + x(1)*A(1,3)*x(3) + ... x(2)*A(2,1)*x(1) + x(2)*A(2,2)*x(2) + x(2)*A(2,3)*x(3) + ... x(3)*A(3,1)*x(1) + x(3)*A(3,2)*x(2) + x(3)*A(3,3)*x(3);
我想我只是替换它:
val = x*A*x';
令我惊讶的是,它的运行速度明显变慢(因为慢了4-5倍).仅仅是向量和矩阵是如此之小以至于MATLAB的优化不适用吗?
编辑:我改进了测试,以提供更准确的时间.我还优化了展开版本,现在比我最初的版本要好得多,但是随着你增加尺寸,矩阵乘法仍然会更快.EDIT2:为了确保JIT编译器正在处理展开的函数,我修改了代码以将生成的函数写为M文件.此外,比较现在可以看作是公平的,因为通过传递TIMEIT函数句柄来评估两个方法:timeit(@myfunc)
我不相信你的方法比合理大小的矩阵乘法更快.所以让我们比较两种方法.
我正在使用符号数学工具箱来帮助我获得x’* A * x等式的“展开”形式(尝试手动乘以20×20矩阵和20×1向量!):
function f = buildUnrolledFunction(N) % avoid regenerating files, CCODE below can be really slow! fname = sprintf('f%d',N); if exist([fname '.m'], 'file') f = str2func(fname); return end % construct symbolic vector/matrix of the specified size x = sym('x', [N 1]); A = sym('A', [N N]); % work out the expanded form of the matrix-multiplication % and convert it to a string s = ccode(expand(x.'*A*x)); % instead of char(.) to avoid x^2 % a bit of RegExp to fix the notation of the variable names % also convert indexing into linear indices: A(3,3) into A(9) s = regexprep(regexprep(s, '^.*=\s+', ''), ';$', ''); s = regexprep(regexprep(s, 'x(\d+)', 'x($1)'), 'A(\d+)_(\d+)', ... 'A(${ int2str(sub2ind([N N],str2num($1),str2num($2))) })'); % build an M-function from the string, and write it to file fid = fopen([fname '.m'], 'wt'); fprintf(fid, 'function v = %s(A,x)\nv = %s;\nend\n', fname, s); fclose(fid); % rehash path and return a function handle rehash clear(fname) f = str2func(fname); end
我试图通过避免取幂来优化生成的函数(我们更喜欢x * x到x ^ 2).我还将下标转换为线性索引(A(9)而不是A(3,3)).因此,对于n = 3,我们得到与您相同的等式:
>> s s = A(1)*(x(1)*x(1)) + A(5)*(x(2)*x(2)) + A(9)*(x(3)*x(3)) + A(4)*x(1)*x(2) + A(7)*x(1)*x(3) + A(2)*x(1)*x(2) + A(8)*x(2)*x(3) + A(3)*x(1)*x(3) + A(6)*x(2)*x(3)
给定上面构造M函数的方法,我们现在评估它的各种大小,并将它与matrix-multiplication形式进行比较(我把它放在一个单独的函数中来解释函数调用开销).我使用TIMEIT函数而不是tic / toc来获得更准确的计时.另外,要进行公平比较,每个方法都实现为一个M文件函数,它将所有需要的变量作为输入参数传递.
function results = testMatrixMultVsUnrolled() % vector/matrix size N_vec = 2:50; results = zeros(numel(N_vec),3); for ii = 1:numel(N_vec); % some random data N = N_vec(ii); x = rand(N,1); A = rand(N,N); % matrix multiplication f = @matMult; results(ii,1) = timeit(@() feval(f, A,x)); % unrolled equation f = buildUnrolledFunction(N); results(ii,2) = timeit(@() feval(f, A,x)); % check result results(ii,3) = norm(matMult(A,x) - f(A,x)); end % display results fprintf('N = %2d: mtimes = %.6f ms, unroll = %.6f ms [error = %g]\n', ... [N_vec(:) results(:,1:2)*1e3 results(:,3)]') plot(N_vec, results(:,1:2)*1e3, 'LineWidth',2) xlabel('size (N)'), ylabel('timing [msec]'), grid on legend({'mtimes','unrolled'}) title('Matrix multiplication: $$x^\mathsf{T}Ax$$', ... 'Interpreter','latex', 'FontSize',14) end function v = matMult(A,x) v = x.' * A * x; end
结果:
N = 2: mtimes = 0.008816 ms, unroll = 0.006793 ms [error = 0] N = 3: mtimes = 0.008957 ms, unroll = 0.007554 ms [error = 0] N = 4: mtimes = 0.009025 ms, unroll = 0.008261 ms [error = 4.44089e-16] N = 5: mtimes = 0.009075 ms, unroll = 0.008658 ms [error = 0] N = 6: mtimes = 0.009003 ms, unroll = 0.008689 ms [error = 8.88178e-16] N = 7: mtimes = 0.009234 ms, unroll = 0.009087 ms [error = 1.77636e-15] N = 8: mtimes = 0.008575 ms, unroll = 0.009744 ms [error = 8.88178e-16] N = 9: mtimes = 0.008601 ms, unroll = 0.011948 ms [error = 0] N = 10: mtimes = 0.009077 ms, unroll = 0.014052 ms [error = 0] N = 11: mtimes = 0.009339 ms, unroll = 0.015358 ms [error = 3.55271e-15] N = 12: mtimes = 0.009271 ms, unroll = 0.018494 ms [error = 3.55271e-15] N = 13: mtimes = 0.009166 ms, unroll = 0.020238 ms [error = 0] N = 14: mtimes = 0.009204 ms, unroll = 0.023326 ms [error = 7.10543e-15] N = 15: mtimes = 0.009396 ms, unroll = 0.024767 ms [error = 3.55271e-15] N = 16: mtimes = 0.009193 ms, unroll = 0.027294 ms [error = 2.4869e-14] N = 17: mtimes = 0.009182 ms, unroll = 0.029698 ms [error = 2.13163e-14] N = 18: mtimes = 0.009330 ms, unroll = 0.033295 ms [error = 7.10543e-15] N = 19: mtimes = 0.009411 ms, unroll = 0.152308 ms [error = 7.10543e-15] N = 20: mtimes = 0.009366 ms, unroll = 0.167336 ms [error = 7.10543e-15] N = 21: mtimes = 0.009335 ms, unroll = 0.183371 ms [error = 0] N = 22: mtimes = 0.009349 ms, unroll = 0.200859 ms [error = 7.10543e-14] N = 23: mtimes = 0.009411 ms, unroll = 0.218477 ms [error = 8.52651e-14] N = 24: mtimes = 0.009307 ms, unroll = 0.235668 ms [error = 4.26326e-14] N = 25: mtimes = 0.009425 ms, unroll = 0.256491 ms [error = 1.13687e-13] N = 26: mtimes = 0.009392 ms, unroll = 0.274879 ms [error = 7.10543e-15] N = 27: mtimes = 0.009515 ms, unroll = 0.296795 ms [error = 2.84217e-14] N = 28: mtimes = 0.009567 ms, unroll = 0.319032 ms [error = 5.68434e-14] N = 29: mtimes = 0.009548 ms, unroll = 0.339517 ms [error = 3.12639e-13] N = 30: mtimes = 0.009617 ms, unroll = 0.361897 ms [error = 1.7053e-13] N = 31: mtimes = 0.009672 ms, unroll = 0.387270 ms [error = 0] N = 32: mtimes = 0.009629 ms, unroll = 0.410932 ms [error = 1.42109e-13] N = 33: mtimes = 0.009605 ms, unroll = 0.434452 ms [error = 1.42109e-13] N = 34: mtimes = 0.009534 ms, unroll = 0.462961 ms [error = 0] N = 35: mtimes = 0.009696 ms, unroll = 0.489474 ms [error = 5.68434e-14] N = 36: mtimes = 0.009691 ms, unroll = 0.512198 ms [error = 8.52651e-14] N = 37: mtimes = 0.009671 ms, unroll = 0.544485 ms [error = 5.68434e-14] N = 38: mtimes = 0.009710 ms, unroll = 0.573564 ms [error = 8.52651e-14] N = 39: mtimes = 0.009946 ms, unroll = 0.604567 ms [error = 3.41061e-13] N = 40: mtimes = 0.009735 ms, unroll = 0.636640 ms [error = 3.12639e-13] N = 41: mtimes = 0.009858 ms, unroll = 0.665719 ms [error = 5.40012e-13] N = 42: mtimes = 0.009876 ms, unroll = 0.697364 ms [error = 0] N = 43: mtimes = 0.009956 ms, unroll = 0.730506 ms [error = 2.55795e-13] N = 44: mtimes = 0.009897 ms, unroll = 0.765358 ms [error = 4.26326e-13] N = 45: mtimes = 0.009991 ms, unroll = 0.800424 ms [error = 0] N = 46: mtimes = 0.009956 ms, unroll = 0.829717 ms [error = 2.27374e-13] N = 47: mtimes = 0.010210 ms, unroll = 0.865424 ms [error = 2.84217e-13] N = 48: mtimes = 0.010022 ms, unroll = 0.907974 ms [error = 3.97904e-13] N = 49: mtimes = 0.010098 ms, unroll = 0.944536 ms [error = 5.68434e-13] N = 50: mtimes = 0.010153 ms, unroll = 0.984486 ms [error = 4.54747e-13]
在小尺寸下,这两种方法的表现有些相似.尽管对于N <7,扩展版本击败了mtimes,但差异并不显着.一旦我们移过微小尺寸,矩阵乘法就会快几个数量级. 这并不奇怪;只有N = 20 formula是可怕的长,并涉及增加400个术语.由于MATLAB语言被解释,我怀疑这是非常有效的..
现在我同意调用外部函数与直接嵌入代码的开销有一定的开销,但这种方法的实用性如何.即使是N = 20的小尺寸,生成的线也超过7000个字符!我也注意到MATLAB编辑器由于长线而变得迟钝:)
此外,在N> 10左右后,优势迅速消失.我比较了嵌入式代码/显式写入与矩阵乘法,类似于@DennisJaheruddin建议的. results:
N=3: Elapsed time is 0.062295 seconds. % unroll Elapsed time is 1.117962 seconds. % mtimes N=12: Elapsed time is 1.024837 seconds. % unroll Elapsed time is 1.126147 seconds. % mtimes N=19: Elapsed time is 140.915138 seconds. % unroll Elapsed time is 1.305382 seconds. % mtimes
……对于展开的版本,它只会变得更糟.就像我之前说的那样,MATLAB被解释为解析代码的成本开始显示在如此庞大的文件中.
我看到它的方式,经过一百万次迭代,我们最多只获得1秒,我认为这并不能证明所有的麻烦和黑客,而是使用更加可读和简洁的v = x’* A * x.所以perhaps代码中还有其他地方可以改进,而不是专注于已经优化的操作,例如矩阵乘法.
Matrix multiplication in MATLAB is seriously fast(这是MATLAB最擅长的!).一旦你达到足够大的数据(multithreading开始),它真的很棒:
>> N=5000; x=rand(N,1); A=rand(N,N); >> tic, for i=1e4, v=x.'*A*x; end, toc Elapsed time is 0.021959 seconds.