https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android
此示例在以下位置使用.pb文件:[这是指向自动下载的文件的链接]
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
该示例显示如何将.pb文件加载到Tensorflow会话,并使用它来执行分类,但它没有(?)提到如何在图训练(例如,Python)后生成这样的.pb文件。
有没有什么例子如何做到这一点?
编辑:freeze_graph.py
脚本,它是TensorFlow仓库的一部分,现在作为一个工具,从现有的TensorFlow GraphDef和保存的检查点生成表示“冻结”训练模型的协议缓冲区。它使用下面描述的相同步骤,但它更容易使用。
目前这个过程没有很好的记录(并且需要细化),但是大致的步骤如下:
>构建和训练你的模型为tf.Graph称为g_1。
>获取每个变量的最终值,并将它们存储为numpy数组(使用Session.run()
)。
>在一个名为g_2的新tf.Graph中,使用在步骤2中获取的相应numpy数组的值为每个变量创建tf.constant()
张量。
>使用tf.import_graph_def()
将节点从g_1复制到g_2,并使用input_map参数将g_1中的每个变量替换为在步骤3中创建的对应tf.constant()张量。您还可以使用input_map指定新的输入张量例如用tf.placeholder()
替换input pipeline)。使用return_elements参数指定预测输出张量的名称。
>调用g_2.as_graph_def()获取图的协议缓冲区表示。
(注意:生成的图形在图形中将有额外的节点用于训练,虽然它不是公共API的一部分,但您可能希望使用内部graph_util.extract_sub_graph()
函数从图形中去除这些节点)。