Tensorflow 資料讀取 tf.data.Dataset API 相關介紹
阿新 • • 發佈:2019-02-17
介紹
tf.1.4及以後新出的tf.data.Dataset API 中,使用的資料讀取方式有點類似於pytorch中的Dataloader,大大簡化了資料讀取。下面是程式碼例項。
# coding=utf-8 import os import numpy as np import glob import tensorflow as tf import tensorflow.contrib.eager as tfe """資料讀取: Dataset API的介紹""" """ 1. Dataset API 支援tensorflow新出的Eager模式 Eager模式:迭代時可直接取值,而不是tensor。但在tf 1.4的標準版中,沒有eager模式,而是在nightly version 2. 通過Dataset類可以例項化出一個Iterator 3. Dataset 可以看成是相同型別元素的有序列表。這裡的元素可以是向量,字串,圖片,或者tuple,dict等 4. 從Dataset中取出元素: 需要例項化一個Interator,然後對Iterator進行迭代 5. Dataset支援一類特殊的操作: Transformation. 一個Dataset通過Transformation變成一個新的Dataset。 我們可以通過Transformation完成 資料變換, 打亂, 組成batch, 生成epoch 等操作 常用的Transformation: (1) map (2) batch (3) shuffle (4) repeat 6. dataset的建立方法: (1) tf.data.Dataset.from_tensor_slices (2) tf.data.TextLineDataset(): 輸入是一個檔案列表,輸出是一個dataset。dataset中的每一個元素就對應了檔案中的一行。 可以用這個函式來讀取csv檔案 (3) tf.data.FixedLengthRecordDataset(): 通常用來讀取以二進位制形式儲存的檔案,如CIFAR10資料集 (4) tf.data.TFRecordDataset(): 用來讀取tfrecord檔案,dataset中的每一個元素就是一個TFExample """ def eager_dataset(): """ 以eager模式讀取資料集 :return: """ dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) iterator = tfe.Iterator(dataset) for one_element in iterator: print(one_element) def non_eager_dataset(): """ 以非eager的方式讀取資料集 :return: """ # from_tensor_slices: 切分傳入Tensor的第一個維度,生成相應的dataset dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) """非eager模式""" # 建立一個iterator,且是一個one shot iterator,即只能從頭到尾讀取一次 iterator = dataset.make_one_shot_iterator() # 非Eager模式:one_element是一個tensor,而不是個實際的值 one_element = iterator.get_next() # with tf.Session() as sess: # for i in range(5): # # 如果一個dataset中的元素被讀取完了,再嘗試執行sess.run(one_element),會報tf.errors.OutOfRangeError的異常 # print(sess.run(one_element)) with tf.Session() as sess: try: while True: print(sess.run(one_element)) except tf.errors.OutOfRangeError: print('End') def non_eager_dataset_v2(): dataset = tf.data.Dataset.from_tensor_slices(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() with tf.Session() as sess: try: while True: print(sess.run(one_element)) except tf.errors.OutOfRangeError: print('End') def non_eager_dataset_dict_classical(): """ 經典的影象處理類問題中,image 和 label 的組織形式: {'image': image_tensor, 'label': label_tensor} :return: """ # from_tensor_slices 會分別切分'a','b'中的數值,最終dataset中的一個元素類似於{'a': 1.0, 'b': dog}的形式 dataset = tf.data.Dataset.from_tensor_slices( {'a': np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 'b': ['dog', 'cat', 'pig', 'monkey', 'bear']}) iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() with tf.Session() as sess: try: while True: print(sess.run(one_element)) except tf.errors.OutOfRangeError: print('End') """Transformation 相關操作""" def map_fun(): dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) dataset = dataset.map(lambda x: x + 1) iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() with tf.Session() as sess: try: while True: print(sess.run(one_element)) except tf.errors.OutOfRangeError: print('End') def batch_fun(): dataset = tf.data.Dataset.from_tensor_slices(np.array(range(32))) # 注: batch 也支援不整除的操作 dataset = dataset.batch(5) dataset = dataset.shuffle(1000) iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() cnt = 0 with tf.Session() as sess: try: while True: print('batch: {}, {}'.format(cnt, sess.run(one_element))) cnt += 1 except tf.errors.OutOfRangeError: print('End') def repeat_fun(): dataset = tf.data.Dataset.from_tensor_slices(np.array(range(10))) dataset = dataset.shuffle(1000) # repeat 的功能就是將整個資料集重複多次,主要用來處理機器學習中的epoch. dataset = dataset.repeat(3) iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() with tf.Session() as sess: try: while True: print(sess.run(one_element)) except tf.errors.OutOfRangeError: print('End') """一個經典的讀取image和label的列子""" def parse_function(filename, label): image_string = tf.read_file(filename) # image_decoded = tf.image.decode_image(image_string, channels=3) image_decoded = tf.image.decode_jpeg(image_string) image_resized = tf.image.resize_images(image_decoded, size=(100, 100)) return image_resized, label def dataset_classical_example(): batch_size = 4 filenames_tmp = glob.glob(os.path.join('./data_samples', '*.{}'.format('jpg'))) filenames = tf.constant(filenames_tmp) labels = tf.constant(range(len(filenames_tmp))) dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map(parse_function) dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).repeat(3) iterator = dataset.make_one_shot_iterator() one_batch = iterator.get_next() with tf.Session() as sess: try: while True: batch_images, batch_labels = sess.run(one_batch) except tf.errors.OutOfRangeError: print('End') if __name__ == '__main__': # non_eager_dataset_dict_classical() # map_fun() # batch_fun() # repeat_fun() dataset_classical_example()