当前位置 : 主页 > 编程语言 > java >

tf.argmax函数

来源:互联网 收集:自由互联 发布时间:2022-07-19
tf.argmax(A,axis=0) axis=0 求每列对大值的索引 axis=1 求每行最大值索引 似乎与通常的0为行,1为列正好相反!!! 代码示例 import tensorflow as tf import numpy as np A = np . zeros ([ 3 ]) B = np . zeros ([


tf.argmax(A,axis=0)

axis=0 求每列对大值的索引
axis=1 求每行最大值索引
似乎与通常的0为行,1为列正好相反!!!

代码示例

import tensorflow as tf
import numpy as np
A=np.zeros([3])
B=np.zeros([3,1])
C=np.array([[1,11,22,3],
[22,1,2,1],
[11,12,2,121]] )
init=tf.global_variables_initializer()
with tf.Session() as sess:
print(A)
print(B)
print("A的维度",tf.rank(A).eval())
print("B的维度",tf.rank(B).eval())
print("B的shape",tf.shape(B).eval())
print("C的每列最大值索引",tf.argmax(C,axis=0).eval())
print("C的每行最大值索引",tf.argmax(C,axis=1).eval())

output

[0. 0. 0.]
[[0.]
[0.]
[0.]]
A的维度 1
B的维度 2
B的shape [3 1]
C的每列最大值索引 [1 2 0 2]
C的每行最大值索引 [2 0 3]


网友评论