Let's write a test case for attention_decoder so we can step through it in a debugger and see how the attention mechanism is implemented.
import tensorflow as tf
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
with tf.Session() as sess:
    batch_size = 16
    step1 = 20        # number of encoder time steps
    step2 = 10        # number of decoder time steps
    input_size = 50   # encoder input feature size
    output_size = 40  # decoder input feature size
    gru_hidden = 30   # GRU hidden size
    cell_fn = lambda: rnn_cell.GRUCell(gru_hidden)
    cell = cell_fn()

    # Encoder: run a static RNN over step1 identical dummy inputs.
    inp = [tf.constant(0.8, shape=[batch_size, input_size])] * step1
    enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=tf.float32)

    # Stack the per-step encoder outputs into [batch_size, step1, gru_hidden];
    # these are the states the decoder attends over.
    attn_states = tf.concat([
        tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
    ], 1)

    # Decoder: step2 dummy inputs, initial state taken from the final encoder
    # state, attending over attn_states; outputs are projected to size 7.
    dec_inp = [tf.constant(0.3, shape=[batch_size, output_size])] * step2
    dec, mem = seq2seq_lib.attention_decoder(
        dec_inp, enc_state, attn_states, cell_fn(), output_size=7)

    sess.run(tf.global_variables_initializer())
    res = sess.run(dec)
    print(len(res))       # number of decoder outputs
    print(res[0].shape)   # shape of one decoder output
    res = sess.run([mem])
    print(len(res))
    print(res[0].shape)   # shape of the final decoder state
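If the graph builds correctly, the first two prints should show 10 decoder outputs, each of shape (16, 7) (one per decoder step, projected to output_size=7), and the last two should show a single final state of shape (16, 30), i.e. the GRU hidden size.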
Adapted from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
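For reference while stepping through attention_decoder, the attention step it performs at each decoder time step is Bahdanau-style: score each encoder output h_i against the current decoder state q with v·tanh(W1·h_i + W2·q), softmax the scores, and take the weighted sum of encoder outputs as the context vector. Below is a minimal NumPy sketch of that computation for a single batch element; the parameter names W1, W2, v and the function name are illustrative, not the library's internal variable names.

import numpy as np

def bahdanau_attention(enc_outputs, dec_state, W1, W2, v):
    # enc_outputs: [steps, hidden], dec_state: [hidden]
    # score_i = v . tanh(W1 @ h_i + W2 @ q)
    scores = np.array([v @ np.tanh(W1 @ h + W2 @ dec_state) for h in enc_outputs])
    # softmax over encoder steps
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # context = sum_i weights_i * h_i
    context = (weights[:, None] * enc_outputs).sum(axis=0)
    return weights, context

# Illustrative shapes matching the test above: hidden size 30, 20 encoder steps.
rng = np.random.RandomState(0)
H, T = 30, 20
enc_outputs = rng.randn(T, H)
dec_state = rng.randn(H)
W1, W2, v = rng.randn(H, H), rng.randn(H, H), rng.randn(H)
weights, context = bahdanau_attention(enc_outputs, dec_state, W1, W2, v)
print(weights.shape, context.shape)  # (20,) (30,)

When you hit the attention computation inside attention_decoder in the debugger, the attention weights play the role of `weights` here (one value per encoder step), and the context vector is concatenated with the decoder input before the cell call.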