
TensorFlow multi-dimensional matrix multiplication: multiplying multi-dimensional tensors

Source: the Internet. Collected by: 自由互联. Published: 2022-07-20


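# Written for the TensorFlow 1.x graph API. Under TensorFlow 2.x, call
# tf.compat.v1.disable_eager_execution() and use tf.compat.v1.Session() instead.
import tensorflow as tf

sess = tf.Session()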

# 2-D x 2-D: an ordinary matrix product, summing over the shared index n
left = tf.ones(shape=[16,20])
right = tf.ones(shape=[20,100])

result = tf.einsum('in,nd->id', left, right)
print(sess.run(tf.shape(result)))

# 3-D x 2-D: the [20,100] matrix multiplies each of the 10 [16,20] slices of left (sum over n)
left = tf.ones(shape=[10,16,20])
right = tf.ones(shape=[20,100])

result = tf.einsum('ibn,nd->ibd', left, right)
print(sess.run(tf.shape(result)))

# 3-D x 3-D: sum over h only; i, b from left and n, d from right all survive in the result
left = tf.ones(shape=[10,16,20])
right = tf.ones(shape=[20,24,100])

result = tf.einsum('ibh,hnd->ibnd', left, right)
print(sess.run(tf.shape(result)))

Printed output:
[ 16 100]
[ 10 16 100]
[ 10 16 24 100]
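tf.einsum is modeled on NumPy's np.einsum and uses the same subscript notation, so the three shape results above can be cross-checked without TensorFlow. The following is a minimal sketch that assumes only NumPy; it uses random inputs instead of tf.ones so the comparisons are not trivially equal, and it checks each einsum against an explicit matrix-product formulation: the @ operator for the first two cases, and a flatten, multiply, reshape sequence (equivalently np.tensordot with axes=1) for the third.

```python
import numpy as np

rng = np.random.default_rng(0)

# Case 1: 'in,nd->id' is an ordinary matrix product over the shared index n.
a = rng.standard_normal((16, 20))
b = rng.standard_normal((20, 100))
r1 = np.einsum('in,nd->id', a, b)
assert np.allclose(r1, a @ b)

# Case 2: 'ibn,nd->ibd' applies the same [20, 100] matrix to every [16, 20]
# slice; @ broadcasts b across the leading batch axis of size 10.
a3 = rng.standard_normal((10, 16, 20))
r2 = np.einsum('ibn,nd->ibd', a3, b)
assert np.allclose(r2, a3 @ b)

# Case 3: 'ibh,hnd->ibnd' contracts only h. Flatten the free axes of each
# operand, do a single 2-D matrix product, then restore the free axes.
c = rng.standard_normal((20, 24, 100))
r3 = np.einsum('ibh,hnd->ibnd', a3, c)
flat = a3.reshape(10 * 16, 20) @ c.reshape(20, 24 * 100)
assert np.allclose(r3, flat.reshape(10, 16, 24, 100))
# tensordot with axes=1 contracts the last axis of a3 with the first axis of c.
assert np.allclose(r3, np.tensordot(a3, c, axes=1))

print(r1.shape, r2.shape, r3.shape)
```

This prints (16, 100) (10, 16, 100) (10, 16, 24, 100), the same shapes as the TensorFlow output above. The flatten-and-reshape route in case 3 shows why einsum is convenient here: it expresses the contraction directly instead of requiring manual bookkeeping of which axes are merged.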

