当前位置 : 主页 > 编程语言 > python >

python tensorflow框架

来源:互联网 收集:自由互联 发布时间:2022-06-30
摘要:二进制读取案例 —— 使用 TensorFlow 读取 CIFAR-10 二进制数据集(代码以 `import tensorflow as tf`、`import os`、设置 `os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'` 以及 `class Cifar(object)` 的定义开头,完整代码见下文)。

二进制读取案例

import tensorflow as tf
import os
# Reduce TensorFlow's C++ log output; per TensorFlow's documented semantics of
# TF_CPP_MIN_LOG_LEVEL, '2' filters INFO and WARNING messages.
# NOTE(review): this is set after `import tensorflow`, so messages emitted
# during the import itself may not be filtered — consider setting it before
# the import.
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

class Cifar(object):
    """Read CIFAR-10 samples from fixed-length binary record files (TF1 queue API).

    Each record is ``label_bytes`` label byte(s) followed by ``image_bytes``
    uint8 pixel bytes stored channel-major, i.e. (channels, height, width);
    ``read_and_decoded`` reshapes and transposes them to
    (height, width, channels).
    """

    def __init__(self):
        self.height = 32
        self.width = 32
        # Backward-compatible alias: this attribute was originally (mis)named
        # "weight" although it holds the image width.
        self.weight = self.width
        self.channels = 3

        # Bytes of pixel data per record (one uint8 per pixel per channel).
        self.image_bytes = self.height * self.width * self.channels
        # Bytes of label data per record.
        self.label_bytes = 1
        # Size of one complete record: label followed by image.
        self.all_bytes = self.image_bytes + self.label_bytes

    def read_and_decoded(self, file_list, batch_size=100):
        """Build the read -> decode -> batch pipeline, run it once, return one batch.

        Intermediate tensors and the fetched values are printed for
        illustration, as in the original tutorial.

        Args:
            file_list: paths of binary files made of ``self.all_bytes``-byte
                records.
            batch_size: number of samples per batch (default 100, the value
                previously hard-coded).

        Returns:
            ``(image_batch, label_batch)`` as evaluated by ``sess.run``:
            images are float32 with shape (batch_size, height, width, channels),
            labels are uint8 with shape (batch_size, label_bytes).
        """
        # 1. File-name queue over the input files.
        file_queue = tf.train.string_input_producer(file_list)

        # 2. Read one fixed-length record and decode it.
        reader = tf.FixedLengthRecordReader(self.all_bytes)
        key, value = reader.read(file_queue)
        print(key)
        print(value)

        # Interpret the raw record string as a flat vector of uint8 bytes.
        decoded_value = tf.decode_raw(value, tf.uint8)
        print(decoded_value)

        # The first label_bytes bytes are the label, the rest is the image.
        label = tf.slice(decoded_value, [0], [self.label_bytes])
        image = tf.slice(decoded_value, [self.label_bytes], [self.image_bytes])
        print(label)
        print(image)

        # Restore the stored layout: channels, height, width.
        image_reshape = tf.reshape(
            image, shape=[self.channels, self.height, self.width])
        # Transpose to height, width, channels — the layout TensorFlow supports.
        image_transpose = tf.transpose(image_reshape, [1, 2, 0])
        print(image_transpose)

        # Cast to float32 for subsequent matrix arithmetic.
        image_casted = tf.cast(image_transpose, dtype=tf.float32)
        print(image_casted)

        # 3. Batch samples.
        label_batch, image_batch = tf.train.batch(
            [label, image_casted],
            batch_size=batch_size,
            num_threads=1,
            capacity=batch_size,
        )
        print(label_batch)
        print(image_batch)

        fetches = [key, value, decoded_value, label, image, image_reshape,
                   image_transpose, image_casted, label_batch, image_batch]
        with tf.Session() as sess:
            # Start the queue-runner threads under a coordinator.
            coords = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coords)
            try:
                (new_key, new_value, new_decoded_value, new_label, new_image,
                 new_image_reshape, new_image_transpose, new_image_casted,
                 new_label_batch, new_image_batch) = sess.run(fetches)
            finally:
                # Stop and join the queue-runner threads even if sess.run
                # raises, so they are not left running.
                coords.request_stop()
                coords.join(threads)

        print(new_key)
        print(new_value)
        print(new_decoded_value)
        print(new_label)
        print(new_image)
        print(new_image_reshape)
        print(new_image_transpose)
        print(new_image_casted)
        print(new_label_batch)
        print(new_image_batch)
        # Return the evaluated batch so callers (e.g. TFRecords export) can use it.
        return new_image_batch, new_label_batch

if __name__ == '__main__':
    # FixedLengthRecordReader reads raw 3073-byte records (1 label byte +
    # 32*32*3 pixel bytes), which matches the *binary* CIFAR-10 distribution
    # ("cifar-10-batches-bin"); the "cifar-10-batches-py" files are Python
    # pickles. Keep only the *.bin record files so readme.html /
    # batches.meta.txt are not decoded as image data.
    data_dir = "./datasources/datasets/cifar-10-batches-bin"
    file_list = sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.endswith(".bin")
    )
    if not file_list:
        raise FileNotFoundError(
            "no CIFAR-10 *.bin record files found in {}".format(data_dir))

    cifar = Cifar()
    cifar.read_and_decoded(file_list)

存储——TFRecords

  • 是一种二进制文件,能够更好地利用内存,更方便复制和移动,且不需要单独的标签文件
  • 使用步骤:
  • 获取数据
  • 将数据填入到Example协议内存块(protocol buffer)
  • 将协议内存块序列化为字符串,通过tf.python_io.TFRecordWriter写入到TFRecords文件
    • 文件格式*.tfrecords

    • Example内部结构
    • Feature 的具体类型(BytesList、Int64List 或 FloatList)取决于所存值的类型
    • 例子:
    """
    序列化数据,使用TFRecords文件存储
    """
    def save_to_tfrecord(self, image_batch, label_batch ):
    with tf.python_io.TFRecordWriter("cifar.tfrecords") as wirter:
    #因为有100个样本
    for i in range(100):
    image = image_batch[i].tostring()
    label = label_batch[i][0]
    example = tf.train.Example(features = tf.train.Features(feature = {
    "image":tf.train.Feature(bytes_list = tf.train.BytesList(value=[image])),
    "label":tf.train.Feature(int64_list = tf.train.Int64List(value=[label])),
    }))
    # 将序列化后的example写入文件
    wirter.write(example.SerializeToString())
    return None

    if __name__ == '__main__':
        # Use the *binary* CIFAR-10 distribution: FixedLengthRecordReader
        # expects raw fixed-length records, while "cifar-10-batches-py" holds
        # Python pickles. Keep only the *.bin record files.
        data_dir = "./datasources/datasets/cifar-10-batches-bin"
        file_list = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith(".bin")
        )
        if not file_list:
            raise FileNotFoundError(
                "no CIFAR-10 *.bin record files found in {}".format(data_dir))

        cifar = Cifar()
        # NOTE: requires read_and_decoded to return (image_batch, label_batch);
        # this unpacking fails if it returns None.
        image, label = cifar.read_and_decoded(file_list)
        cifar.save_to_tfrecord(image, label)


    上一篇:python神经网络理论
    下一篇:没有了
    网友评论