当前位置 : 主页 > 大数据 > 区块链 >

protocol-buffers – 是否有一个关于如何生成保存训练有素的Tensorflow图的protobuf文件

来源:互联网 收集:自由互联 发布时间:2021-06-22
我正在查看Google的示例,说明如何在Android上部署和使用预先训练的Tensorflow图(模型),网址为: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android 此示例在以下位置使用.pb文
我正在查看Google的示例,说明如何在Android上部署和使用预先训练的Tensorflow图(模型),网址为:

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android

此示例在以下位置使用.pb文件:[这是指向自动下载的文件的链接]
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

该示例显示如何将.pb文件加载到Tensorflow会话,并使用它来执行分类,但它没有(?)提到如何在图训练(例如,Python)后生成这样的.pb文件。

有没有什么例子如何做到这一点?

编辑: freeze_graph.py脚本,它是TensorFlow仓库的一部分,现在作为一个工具,从现有的TensorFlow GraphDef和保存的检查点生成表示“冻结”训练模型的协议缓冲区。它使用下面描述的相同步骤,但它更容易使用。

目前这个过程没有很好的记录(并且需要细化),但是大致的步骤如下:

>构建和训练你的模型为tf.Graph称为g_1。
>获取每个变量的最终值,并将它们存储为numpy数组(使用Session.run())。
>在一个名为g_2的新tf.Graph中,使用在步骤2中获取的相应numpy数组的值为每个变量创建tf.constant()张量。
>使用tf.import_graph_def()将节点从g_1复制到g_2,并使用input_map参数将g_1中的每个变量替换为在步骤3中创建的对应tf.constant()张量。您还可以使用input_map指定新的输入张量例如用tf.placeholder()替换input pipeline)。使用return_elements参数指定预测输出张量的名称。
>调用g_2.as_graph_def()获取图的协议缓冲区表示。

(注意:生成的图形在图形中将有额外的节点用于训练,虽然它不是公共API的一部分,但您可能希望使用内部graph_util.extract_sub_graph()函数从图形中去除这些节点)。

网友评论