
Debugging TensorFlow's seq2seq attention_decoder

Published: 2022-07-20


Let's write a test case for attention_decoder so we can step through it in a debugger and see how the attention mechanism is implemented.

import tensorflow as tf
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib

with tf.Session() as sess:
    batch_size = 16
    step1 = 20        # encoder time steps
    step2 = 10        # decoder time steps
    input_size = 50
    output_size = 40
    gru_hidden = 30
    cell_fn = lambda: rnn_cell.GRUCell(gru_hidden)
    cell = cell_fn()

    # Encoder: step1 identical dummy inputs of shape [batch_size, input_size]
    inp = [tf.constant(0.8, shape=[batch_size, input_size])] * step1
    enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=tf.float32)

    # Stack the per-step encoder outputs into the attention states tensor
    # of shape [batch_size, step1, gru_hidden]
    attn_states = tf.concat([
        tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
    ], 1)

    # Decoder inputs: step2 dummy tensors of shape [batch_size, output_size]
    dec_inp = [tf.constant(0.3, shape=[batch_size, output_size])] * step2

    # Use a fresh GRU cell for the decoder; project its outputs to size 7
    dec, mem = seq2seq_lib.attention_decoder(
        dec_inp, enc_state, attn_states, cell_fn(), output_size=7)

    sess.run([tf.global_variables_initializer()])
    res = sess.run(dec)
    print(len(res))        # step2 = 10 decoder outputs
    print(res[0].shape)    # (16, 7): [batch_size, output_size of the projection]
    res = sess.run([mem])
    print(len(res))        # 1
    print(res[0].shape)    # (16, 30): final decoder state [batch_size, gru_hidden]
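What you will see when stepping into attention_decoder is additive (Bahdanau-style) attention: at each decoder step it scores every encoder output against the current decoder state, softmaxes the scores into weights, and takes a weighted sum of the encoder outputs as the context vector. A minimal NumPy sketch of that computation for a single batch item (the shapes and names here are illustrative, not copied from the TF source):

```python
import numpy as np

def additive_attention(enc_outputs, dec_state, W1, W2, v):
    """One step of Bahdanau-style additive attention.

    enc_outputs: (steps, hidden)  encoder outputs (one row of attn_states)
    dec_state:   (hidden,)        current decoder state
    W1, W2:      (hidden, attn)   learned projection matrices
    v:           (attn,)          learned scoring vector
    """
    # Score each encoder step: v . tanh(W1 h_i + W2 d)
    scores = np.tanh(enc_outputs @ W1 + dec_state @ W2) @ v   # (steps,)
    # Softmax over encoder steps (stabilized by subtracting the max)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # Context vector: attention-weighted sum of encoder outputs
    context = weights @ enc_outputs                            # (hidden,)
    return context, weights

rng = np.random.default_rng(0)
steps, hidden, attn = 20, 30, 30   # mirrors step1 and gru_hidden above
enc = rng.standard_normal((steps, hidden))
dec = rng.standard_normal(hidden)
ctx, w = additive_attention(
    enc, dec,
    rng.standard_normal((hidden, attn)),
    rng.standard_normal((hidden, attn)),
    rng.standard_normal(attn))
print(ctx.shape)   # context has the encoder hidden size
print(w.sum())     # attention weights sum to 1
```

In attention_decoder itself the W1 projection of the encoder states is implemented once up front as a 1x1 convolution over attn_states, and the resulting context is concatenated with the cell output before the final output projection, but the scoring-softmax-weighted-sum core is the same.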

Adapted from ​​https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py​​

