# file_name: placeholder that will be fed the names of the tfrecord files.
def load_sewa_data(file_name, batch_size):
    """Build a batched input pipeline over the SEWA tfrecords.

    Args:
        file_name: placeholder (or value) naming the tfrecord file(s) to read.
        batch_size: number of parsed examples per batch.

    Returns:
        An 8-tuple ``(names, detected, arousal, valence, liking, istalkings,
        images, iterator)``: the seven tensors of the next batch followed by
        the initializable iterator that produces them.
    """
    with tf.name_scope('sewa_tf_records'):
        records = tf.data.TFRecordDataset(file_name)
        parsed = records.map(_parse_sewa_example)
        batched = parsed.batch(batch_size)
        iterator = batched.make_initializable_iterator(shared_name='sewa_iterator')

        # The parse function is expected to yield exactly seven fields.
        names, detected, arousal, valence, liking, istalkings, images = iterator.get_next()
        batch_tensors = (names, detected, arousal, valence, liking, istalkings, images)

        # Debug output: shows the symbolic tensors, not their values.
        print(*batch_tensors)

        return batch_tensors + (iterator,)
在会话中使用sess.run()对names求值后,我发现前68个示例都取自Train_DE_01.tfrecords;后续的示例仍然来自同一个tfrecord文件,直到Train_DE_01.tfrecords中的所有示例都被消耗完为止.
我尝试使用Dataset API的zip()函数和可重新初始化(reinitializable)的迭代器,如下所示:
def load_devel_sewa_tfrecords(filenames_dev, test_batch_size):
    """Zip one batched dataset per development tfrecord file.

    Args:
        filenames_dev: iterable of development tfrecord file names.
        test_batch_size: number of examples taken from each file per step.

    Returns:
        A dataset whose elements are tuples with one entry per file; each
        entry is the batched output of ``_parse_devel_function``.
    """
    with tf.name_scope('TFRecordsDevel'):
        datasets_dev = [
            tf.data.TFRecordDataset(file_name).map(_parse_devel_function).batch(test_batch_size)
            for file_name in filenames_dev
        ]
        dataset_dev_all = tf.data.Dataset.zip(tuple(datasets_dev))
    return dataset_dev_all


def load_train_sewa_tfrecords(filenames_train, train_batch_size):
    """Zip one batched dataset per training tfrecord file.

    Args:
        filenames_train: iterable of training tfrecord file names.
        train_batch_size: number of examples taken from each file per step.

    Returns:
        A dataset whose elements are tuples with one entry per file; each
        entry is the batched output of ``_parse_train_function``.
    """
    with tf.name_scope('TFRecordsTrain'):
        datasets_train = [
            tf.data.TFRecordDataset(file_name).map(_parse_train_function).batch(train_batch_size)
            for file_name in filenames_train
        ]
        dataset_train_all = tf.data.Dataset.zip(tuple(datasets_train))
    return dataset_train_all


def _merge_per_file_batches(*per_file_batches):
    """Concatenate the per-file batches of one zipped element, field by field.

    Each positional argument is the tuple of fields produced for one file.
    The result is a flat tuple with one tensor per field, so its structure
    no longer depends on how many files were zipped together.
    """
    return tuple(
        tf.concat(list(field_parts), axis=0)
        for field_parts in zip(*per_file_batches)
    )


def load_sewa_dataset(filenames_train, train_batch_size, filenames_dev, test_batch_size):
    """Build a reinitializable iterator shared by the train and devel sets.

    The per-file batches are concatenated inside each dataset, before the
    iterator is created. A reinitializable iterator requires every dataset
    it is initialized from to have the same structure; the raw zipped
    datasets have one component per file, so a different number of training
    and development files made ``make_initializer`` raise
    "The two structures don't have the same number of elements".

    Args:
        filenames_train: iterable of training tfrecord file names.
        train_batch_size: examples taken from each training file per step.
        filenames_dev: iterable of development tfrecord file names.
        test_batch_size: examples taken from each development file per step.

    Returns:
        A 9-tuple ``(names, detected, arousal, valence, liking, istalkings,
        images, training_init_op, validation_init_op)``. Run one of the two
        init ops to choose which dataset feeds the seven batch tensors.
    """
    dataset_train_all = load_train_sewa_tfrecords(
        filenames_train, train_batch_size).map(_merge_per_file_batches)
    dataset_dev_all = load_devel_sewa_tfrecords(
        filenames_dev, test_batch_size).map(_merge_per_file_batches)

    iterator = tf.data.Iterator.from_structure(dataset_train_all.output_types,
                                               dataset_train_all.output_shapes)
    training_init_op = iterator.make_initializer(dataset_train_all)
    validation_init_op = iterator.make_initializer(dataset_dev_all)

    with tf.name_scope('inputs'):
        next_batch = iterator.get_next(name='next_batch')
        # next_batch holds seven tensors: name, detected, arousal, valence,
        # liking, istalking and images. Each already contains the examples
        # of every file, concatenated along the batch axis.
        (raw_names, raw_detected, raw_arousal, raw_valence,
         raw_liking, raw_istalkings, raw_images) = next_batch

        # tf.identity keeps the op names the concat ops used to carry.
        names = tf.identity(raw_names, name='names')
        detected = tf.identity(raw_detected, name='detected')
        arousal = tf.identity(raw_arousal, name='arousal')
        valence = tf.identity(raw_valence, name='valence')
        liking = tf.identity(raw_liking, name='liking')
        istalkings = tf.identity(raw_istalkings, name='istalkings')
        images = tf.identity(raw_images, name='images')

    return names, detected, arousal, valence, liking, istalkings, images, training_init_op, validation_init_op
现在,如果我运行以下代码:
# Open a session, point the shared iterator at the training dataset, then
# evaluate and print one batch of `names`.
sess = tf.Session()
sess.run(training_init_op)
print(sess.run(names))
我收到以下错误:
ValueError: The two structures don't have the same number of elements.
这是有道理的,因为训练文件的数量是34,而验证文件的数量是14,所以zip之后两个数据集的结构所含的元素个数不同.
我想知道怎样才能实现我的目标,即让每个批次都包含来自每个tfrecord文件的示例?
非常感谢任何帮助!
这是我使用tf.cond找到的一个变通办法.为了从每个tfrecord中取出2个示例,我使用了tf.data.Dataset API的zip方法,如下所示:
def load_train_sewa_tfrecords(filenames_train, train_batch_size):
    """Read a batch from every training tfrecord and merge them per field.

    Args:
        filenames_train: iterable of training tfrecord file names.
        train_batch_size: number of examples taken from each file per step.

    Returns:
        An 8-tuple ``(names, detected, arousal, valence, liking, istalkings,
        images, iterator_train_all)``. The first seven are tensors holding
        the examples of all files concatenated along axis 0; the last is the
        initializable iterator that must be initialized before fetching.
    """
    with tf.name_scope('TFRecordsTrain'):
        per_file_datasets = [
            tf.data.TFRecordDataset(path).map(_parse_train_function).batch(train_batch_size)
            for path in filenames_train
        ]
        dataset_train_all = tf.data.Dataset.zip(tuple(per_file_datasets))
        iterator_train_all = dataset_train_all.make_initializable_iterator()

    # NOTE(review): the collapsed source does not show where this scope ends;
    # the concat ops are assumed to live inside it - confirm.
    with tf.name_scope('inputs_train'):
        next_batch = iterator_train_all.get_next(name='next_batch')

        # next_batch has one entry per file (34 for the training set). Each
        # entry holds 7 tensors - name, detected, arousal, valence, liking,
        # istalking and images - whose leading dimension is the batch size.
        field_names = ('names', 'detected', 'arousal', 'valence',
                       'liking', 'istalkings', 'images')
        merged = tuple(
            tf.concat([per_file[position] for per_file in next_batch],
                      axis=0, name=field)
            for position, field in enumerate(field_names)
        )

    return merged + (iterator_train_all,)
对于开发(验证)集,我会写一个类似的方法;或者也可以修改该方法的传入参数,这样同一个方法就可以调用两次…(这不是问题所在).
然后:
import numpy as np

# Build both pipelines up front. Each call returns the seven batch tensors
# plus the initializable iterator that feeds them.
names_dev, detected_dev, arousal_dev, valence_dev, liking_dev, istalkings_dev, images_dev, iterator_dev_all = \
    load_devel_sewa_tfrecords(filenames_dev, sewa_batch_size)

names_train, detected_train, arousal_train, valence_train, liking_train, istalkings_train, images_train, iterator_train_all = \
    load_train_sewa_tfrecords(filenames_train, sewa_batch_size)

images_train = pre_process_sewa_images(images_train)
images_dev = pre_process_sewa_images(images_dev)


def return_train_sewa():
    """Return the training batch tensors (true branch of the tf.cond)."""
    return names_train, detected_train, arousal_train, valence_train, liking_train, istalkings_train, images_train


def return_dev_sewa():
    """Return the development batch tensors (false branch of the tf.cond)."""
    return names_dev, detected_dev, arousal_dev, valence_dev, liking_dev, istalkings_dev, images_dev


# phase_train selects which set of tensors is exposed downstream.
names, detected, arousal, valence, liking, istalkings, images_sewa = tf.cond(phase_train, return_train_sewa, return_dev_sewa)

sewa_inputs = []

sess = tf.Session()

for e in range(epochs):
    # Both iterators are initialized before every pass. The batch tensors
    # are created outside the cond branches, so evaluating the cond output
    # appears to pull from both pipelines - see the accompanying text.
    sess.run(iterator_train_all.initializer)
    sess.run(iterator_dev_all.initializer)

    i = 0
    total = 0
    try:
        while True:
            i += 1
            names_np, detected_np, arousal_np, valence_np, liking_np, istalkings_np = \
                sess.run([names, detected, arousal, valence, liking, istalkings], feed_dict={phase_train: True})
            total += np.shape(names_np)[0]
            print("total =", total, " | i =", i)
    except tf.errors.OutOfRangeError:
        # Only end-of-data ends the pass; any other error now propagates
        # instead of being reported as a normal end of training.
        print("end of train...")

    i_d = 0
    total_d = 0
    sess.run(iterator_train_all.initializer)
    sess.run(iterator_dev_all.initializer)
    try:
        while True:
            i_d += 1
            names_np, detected_np, arousal_np, valence_np, liking_np, istalkings_np = \
                sess.run([names, detected, arousal, valence, liking, istalkings], feed_dict={phase_train: False})
            total_d += np.shape(names_np)[0]
            print("total_d =", total_d, " | i_d =", i_d)
            print(names_np)
    except tf.errors.OutOfRangeError:
        print("End of devel")
请注意,必须在sess.run([names, …])之前同时运行sess.run(iterator_train_all.initializer)和sess.run(iterator_dev_all.initializer).我的猜测是:训练和验证两边的批次张量都是在tf.cond的分支函数之外创建的,所以每次求值时训练示例和验证示例都会被取出,而tf.cond只是根据phase_train占位符返回其中之一;该占位符决定我们处于训练模式还是测试模式.
证明:当我在load_devel_sewa_tfrecords中,于return语句之前插入 names = tf.Print(input_=[names], data=[names], message='dev names') 时,得到如下输出:
dev names['Devel_01' 'Devel_01' 'Devel_02'...]
也就是说,在对训练数据集求值时,上述内容也会在控制台中打印出来;即tensorflow同时对devel数据集进行了求值,但tf.cond只输出了与训练数据集相关的tfrecords中的示例.
希望这个答案对你有帮助!