How to build a TensorFlow dataset from CSV files
阿新 · Published: 2020-09-22
Building a TensorFlow dataset from CSV files
Given a set of CSV files, how do we build a TensorFlow dataset from them?
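Basic steps
- Get the paths of a set of CSV files.
- Turn this list of file names into a dataset of file names => filename_dataset
- For each file name in filename_dataset, read the file's contents into a dataset of lines => content_dataset
- Concatenate these content_datasets into a single dataset.
- Every record read this way is a string, so each record still has to be decoded (parsed) into numbers.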
Suppose we have a variable train_filenames that holds the file paths:
import pprint

pprint.pprint(train_filenames)
# ['generate_csv\\train_00.csv',
#  'generate_csv\\train_01.csv',
#  'generate_csv\\train_02.csv',
#  'generate_csv\\train_03.csv',
#  'generate_csv\\train_04.csv',
#  'generate_csv\\train_05.csv',
#  'generate_csv\\train_06.csv',
#  'generate_csv\\train_07.csv',
#  'generate_csv\\train_08.csv',
#  'generate_csv\\train_09.csv',
#  'generate_csv\\train_10.csv',
#  'generate_csv\\train_11.csv',
#  'generate_csv\\train_12.csv',
#  'generate_csv\\train_13.csv',
#  'generate_csv\\train_14.csv',
#  'generate_csv\\train_15.csv',
#  'generate_csv\\train_16.csv',
#  'generate_csv\\train_17.csv',
#  'generate_csv\\train_18.csv',
#  'generate_csv\\train_19.csv']
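The article does not show how train_filenames was produced. As a hedged sketch, assuming each shard is a CSV file with one header line followed by 8 feature columns and 1 label column (the layout the parsing step expects), the stdlib-only code below writes a few hypothetical shards into a temporary directory and collects their paths with glob. The helper write_sample_shards and the column names are illustrative, not part of the original code.

```python
import csv
import glob
import os
import random
import tempfile


def write_sample_shards(out_dir, prefix="train", n_shards=3, rows_per_shard=4, n_features=8):
    """Write CSV shards: a header row, then rows of n_features features plus one label."""
    header = [f"f{i}" for i in range(n_features)] + ["label"]  # hypothetical column names
    paths = []
    for shard in range(n_shards):
        path = os.path.join(out_dir, f"{prefix}_{shard:02d}.csv")
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(header)
            for _ in range(rows_per_shard):
                writer.writerow([random.gauss(0.0, 1.0) for _ in range(n_features + 1)])
        paths.append(path)
    return paths


out_dir = tempfile.mkdtemp()
write_sample_shards(out_dir)
# glob returns paths in arbitrary order, so sort them for a stable, name-ordered list.
train_filenames = sorted(glob.glob(os.path.join(out_dir, "train_*.csv")))
print([os.path.basename(p) for p in train_filenames])  # ['train_00.csv', 'train_01.csv', 'train_02.csv']
```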
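Next, we build the file-name dataset filename_dataset with the ready-made API tf.data.Dataset.list_files. list_files shuffles the file names by default, which is why the printed order below is random: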
filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
    print(filename)
# tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
# tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)
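If you need a reproducible file order, for example while debugging, list_files accepts shuffle=False. A minimal sketch, assuming TensorFlow 2.x in eager mode; it first creates four hypothetical empty files in a temporary directory, because list_files only yields paths that actually exist.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical setup: create four empty files, because list_files only yields paths that exist.
tmp_dir = tempfile.mkdtemp()
paths = []
for i in range(4):
    path = os.path.join(tmp_dir, f"train_{i:02d}.csv")
    open(path, "w").close()
    paths.append(path)

# shuffle=False: the same order on every pass over the dataset.
ordered = tf.data.Dataset.list_files(paths, shuffle=False)
first_pass = [p.numpy().decode() for p in ordered]
second_pass = [p.numpy().decode() for p in ordered]
print(first_pass == second_pass)  # True

# Default behaviour: the file names are shuffled (and reshuffled on each pass).
shuffled = tf.data.Dataset.list_files(paths)
print(sorted(p.numpy().decode() for p in shuffled) == sorted(first_pass))  # True
```

Shuffling the file order is usually what you want for training; the deterministic variant is mainly useful for inspecting the pipeline.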
Step 3: read the contents of each file according to its file name.
dataset = filename_dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=5)
for line in dataset.take(3):
    print(line)
# tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
# tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
# tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)
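interleave is similar to map in that it applies a function to each element. The difference is that the function returns a dataset (here, the lines of one file), and interleave merges the elements of all these datasets into a single dataset.
So with interleave, steps 3 and 4 are done together.
We call skip(1) because the first line of each CSV file is a header.
cycle_length is the number of files read concurrently: interleave takes lines from cycle_length files in turn. It does not by itself set a number of threads; parallel reading is enabled separately through interleave's num_parallel_calls argument.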
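To make the interleaving order concrete, here is a pure-Python simulation (not TensorFlow's implementation) of how interleave consumes its inputs in its default deterministic mode: it keeps cycle_length sub-iterators open, takes block_length items (default 1) from each in round-robin order, and opens the next input element when a sub-iterator runs out. The function simulate_interleave and the toy "files" are hypothetical.

```python
_DONE = object()  # sentinel marking an exhausted iterator


def simulate_interleave(elements, fn, cycle_length, block_length=1):
    """Yield items in the order Dataset.interleave produces them (deterministic mode)."""
    source = iter(elements)
    slots = [None] * cycle_length  # open sub-iterators; None means an empty slot
    source_done = False
    while True:
        for i in range(cycle_length):
            # Refill an empty slot with the next input element, if any remain.
            if slots[i] is None and not source_done:
                element = next(source, _DONE)
                if element is _DONE:
                    source_done = True
                else:
                    slots[i] = iter(fn(element))
            if slots[i] is None:
                continue
            # Take up to block_length items from this slot, then move to the next slot.
            for _ in range(block_length):
                item = next(slots[i], _DONE)
                if item is _DONE:
                    slots[i] = None
                    break
                yield item
        if source_done and all(slot is None for slot in slots):
            return


# Toy "files": a header line followed by data rows; [1:] plays the role of .skip(1).
files = {
    "a.csv": ["header", "a1", "a2", "a3"],
    "b.csv": ["header", "b1", "b2"],
    "c.csv": ["header", "c1"],
}
lines = list(simulate_interleave(files, lambda name: files[name][1:], cycle_length=2))
print(lines)  # ['a1', 'b1', 'a2', 'b2', 'a3', 'c1']
```

With cycle_length=1 the files are simply read one after another; larger values mix records from more files, which also helps the later shuffle.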
Step 5: parse each record.
import numpy as np
import tensorflow as tf

def parse_csv_line(line, n_fields=9):
    # One float32 default per column; an empty field is filled with NaN
    defaults = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
    x = tf.stack(parsed_fields[:-1])  # all columns but the last are features
    y = tf.stack(parsed_fields[-1:])  # the last column is the label
    return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803', 9)
# (<tf.Tensor: shape=(8,), dtype=float32, numpy=
#  array([ 1.2286259, -1.0806246, 0.44441614, -0.03521726, 0.9740348, -0.00351608, -0.81265247, 0.86560905], dtype=float32)>,
#  <tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)
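record_defaults does two jobs: its dtype (float32 here, from tf.constant(np.nan)) decides how each field is parsed, and its value fills in empty fields. A minimal check, assuming TensorFlow 2.x in eager mode and a hypothetical 3-column record whose middle field is empty; it also shows that a record with the wrong number of fields raises an error instead of being padded.

```python
import numpy as np
import tensorflow as tf

# Three float32 columns; NaN is used wherever a field is empty.
defaults = [tf.constant(np.nan)] * 3
fields = tf.io.decode_csv("1.5,,2.0", record_defaults=defaults)
values = [float(f.numpy()) for f in fields]
print(values)  # [1.5, nan, 2.0]

# A record with the wrong number of fields is rejected rather than padded.
try:
    tf.io.decode_csv("1.5,2.0", record_defaults=defaults)
    rejected = False
except tf.errors.InvalidArgumentError:
    rejected = True
print(rejected)  # True
```

Using NaN as the default makes missing values easy to spot later, but the model will propagate NaN unless such rows are filtered or imputed.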
Finally, apply this function to every record, and the dataset is complete.
dataset = dataset.map(parse_csv_line)
Complete code
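import tensorflow as tf

# parse_csv_line is the function defined in step 5
def csv_2_dataset(filenames, n_readers_thread=5, batch_size=32,
                  n_parse_thread=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    # Repeat forever; fit/evaluate therefore need an explicit number of steps
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),  # skip the header line
        cycle_length=n_readers_thread)
    # shuffle returns a new dataset, so the result must be assigned back
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_thread)
    dataset = dataset.batch(batch_size)
    return dataset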
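A quick smoke test of the pipeline with the shuffle result assigned back, assuming TensorFlow 2.x: it writes three hypothetical shards (a header plus 10 rows of 8 features and 1 label each) to a temporary directory and checks the shapes of one batch. parse_csv_line and csv_2_dataset are repeated here only so the snippet runs on its own; the small batch and buffer sizes are arbitrary choices for the test.

```python
import csv
import os
import random
import tempfile

import numpy as np
import tensorflow as tf


def parse_csv_line(line, n_fields=9):
    defaults = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
    return tf.stack(parsed_fields[:-1]), tf.stack(parsed_fields[-1:])


def csv_2_dataset(filenames, n_readers_thread=5, batch_size=32,
                  n_parse_thread=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers_thread)
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_thread)
    return dataset.batch(batch_size)


# Hypothetical data: 3 shards, each a header row plus 10 rows of 8 features + 1 label.
tmp_dir = tempfile.mkdtemp()
filenames = []
for shard in range(3):
    path = os.path.join(tmp_dir, f"train_{shard:02d}.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([f"f{i}" for i in range(8)] + ["label"])
        writer.writerows([[random.random() for _ in range(9)] for _ in range(10)])
    filenames.append(path)

x_batch, y_batch = next(iter(csv_2_dataset(filenames, batch_size=4, shuffle_buffer_size=16)))
print(x_batch.shape, y_batch.shape)  # (4, 8) (4, 1)
```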
Usage
train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)
model = ...  # build and compile the model here
model.fit(train_dataset,
          validation_data=valid_dataset,
          steps_per_epoch=11610 // 32,
          validation_steps=3870 // 32,
          epochs=100,
          callbacks=callbacks)  # callbacks: a list of Keras callbacks defined elsewhere
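What are 11610 and 3870 here?
They are the numbers of samples in the training and validation sets. Because csv_2_dataset calls repeat(), both datasets are infinite, so Keras cannot tell where an epoch ends. We therefore tell it how many batches make up one epoch (number of samples // batch size) through steps_per_epoch and validation_steps.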
# test_dataset: built with csv_2_dataset from the test CSV files, like train_dataset
model.evaluate(test_dataset, steps=5160 // 32)
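Likewise, when evaluating on a dataset built this way, the number of steps must be given explicitly.
Here 5160 is the total number of samples in the test set.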
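The step counts are plain integer division. For example, 11610 // 32 = 362 training steps cover 362 × 32 = 11584 samples per epoch, 26 fewer than the full training set; since the dataset repeats, the leftover samples roll over into the next epoch. A small helper (the name steps_for and its round_up flag are hypothetical) for rounding down or up:

```python
import math


def steps_for(n_samples, batch_size, round_up=False):
    """Batches for one pass over n_samples: floor skips the last partial batch, ceil covers it."""
    if round_up:
        return math.ceil(n_samples / batch_size)
    return n_samples // batch_size


for name, n_samples in [("train", 11610), ("valid", 3870), ("test", 5160)]:
    print(name, steps_for(n_samples, 32), steps_for(n_samples, 32, round_up=True))
# train 362 363
# valid 120 121
# test 161 162
```

With the repeating pipeline, rounding down skips up to batch_size - 1 samples per pass, while rounding up makes the last batch spill into the next repetition (for the test set, 162 × 32 = 5184, so 24 extra samples are drawn from the repeated data). For datasets of this size the difference is small either way.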