我正在使用自定义图像集来使用Tensorflow API训练神经网络.在成功的训练过程之后,我得到这些检查点文件,其中包含不同训练var的值.我现在想从这些检查点文件中获得一个推理模型,我发现
import tensorflow as tf model_fn = 'export' graph = tf.Graph() sess = tf.InteractiveSession(graph=graph) with tf.gfile.FastGFile(model_fn, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) t_input = tf.placeholder(np.float32, name='input') imagenet_mean = 117.0 t_preprocessed = tf.expand_dims(t_input-imagenet_mean, 0) tf.import_graph_def(graph_def, {'input':t_preprocessed})
我收到此错误:
graph_def.ParseFromString(f.read())
self.MergeFromString(serialized)
raise message_mod.DecodeError(‘Unexpected end-group tag.’)
google.protobuf.message.DecodeError: Unexpected end-group tag.
该脚本需要一个协议缓冲区文件,我不确定我用于生成推理模型的script是否给了我原型缓冲区文件.
有人可以建议我做错了什么,或者有更好的方法来实现这一目标.我只想将张量生成的检查点文件转换为proto缓冲区.
谢谢
您运行的脚本的链接已损坏,但无论如何建议不要尝试从检查点生成推理模型,而是在训练程序结束时嵌入代码,这将导致“SavedModel”导出(这与检查点不同).请参见[1],特别是标题“建立已保存的模型”.请注意,保存模型构成多个文件,其中一个文件确实是一个协议缓冲区(它直接回答了我希望的问题);其他是可变文件和(可选)资产文件.
[1] https://www.tensorflow.org/programmers_guide/saved_model