
How to Build a TensorFlow Dataset from CSV Files


Given a collection of CSV files, how do we build a TensorFlow dataset from them?

Basic steps

  1. Collect the paths of the CSV files.
  2. Turn that list of filenames into a dataset of filenames => file_dataset
  3. For each filename in file_dataset, read the file's contents to produce a dataset of lines => content_dataset
  4. Splice these content_datasets together into one whole dataset.
  5. Because every record read this way is a string, each record still has to be decoded.

Suppose we have a variable train_filenames like this:

pprint.pprint(train_filenames)
# ['generate_csv\\train_00.csv',
#  'generate_csv\\train_01.csv',
#  'generate_csv\\train_02.csv',
#  'generate_csv\\train_03.csv',
#  'generate_csv\\train_04.csv',
#  'generate_csv\\train_05.csv',
#  'generate_csv\\train_06.csv',
#  'generate_csv\\train_07.csv',
#  'generate_csv\\train_08.csv',
#  'generate_csv\\train_09.csv',
#  'generate_csv\\train_10.csv',
#  'generate_csv\\train_11.csv',
#  'generate_csv\\train_12.csv',
#  'generate_csv\\train_13.csv',
#  'generate_csv\\train_14.csv',
#  'generate_csv\\train_15.csv',
#  'generate_csv\\train_16.csv',
#  'generate_csv\\train_17.csv',
#  'generate_csv\\train_18.csv',
#  'generate_csv\\train_19.csv']
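For reference, a filename list like this can be produced with Python's glob module. This is a hypothetical sketch; how train_filenames was actually built is not shown in this article:

import glob
import pprint

# Hypothetical: collect all training shards under the generate_csv directory.
train_filenames = sorted(glob.glob('generate_csv/train_*.csv'))
pprint.pprint(train_filenames)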

Next, we use the ready-made API to build the filename dataset, file_dataset:

filename_dataset = tf.data.Dataset.list_files(train_filenames)
for filename in filename_dataset:
  print(filename)
#tf.Tensor(b'generate_csv\\train_09.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_19.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_03.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_01.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_14.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_17.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_15.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_06.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_05.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_07.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_11.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_02.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_12.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_13.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_10.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_16.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_18.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_00.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_04.csv', shape=(), dtype=string)
#tf.Tensor(b'generate_csv\\train_08.csv', shape=(), dtype=string)
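Notice that the filenames come back in a different order than the input list: tf.data.Dataset.list_files shuffles the filenames by default. If a deterministic order is needed, it accepts a shuffle flag (and an optional seed):

# Keep the original file order instead of shuffling.
filename_dataset = tf.data.Dataset.list_files(train_filenames, shuffle=False)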

Step 3: read the contents of each file, based on its filename.

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5
)

for line in dataset.take(3):
  print(line)

#tf.Tensor(b'0.46908349737250216,1.8718193706428006,0.13936365871212536,-0.011055733363841472,-0.6349261778219746,-0.036732316700563934,1.0259470089944995,-1.319095600336748,2.171', shape=(), dtype=string)
#tf.Tensor(b'-1.102093775650278,1.313248890578542,-0.7212003024178728,-0.14707856286537277,0.34720121604358517,0.0965085401826684,-0.74698820254838,0.6810563907247876,1.428', shape=(), dtype=string)
#tf.Tensor(b'-0.8901003715328659,0.9142699762469286,-0.1851678950250224,-0.12947457252940406,0.5958187430364827,-0.021255215877779534,0.7914317693724252,-0.45618713536506217,0.75', shape=(), dtype=string)

interleave can be compared to map: it applies a function to every element, but then it also flattens the resulting datasets together into one. So with interleave, steps 3 and 4 are completed in a single call.
We call skip(1) because the first line of each CSV file is a header.
cycle_length controls how many input files interleave reads from concurrently.
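interleave also accepts num_parallel_calls if the files should be read in parallel rather than merely interleaved. A minimal variant, assuming TF 2.4+ where tf.data.AUTOTUNE is available (older versions use tf.data.experimental.AUTOTUNE):

dataset = filename_dataset.interleave(
  lambda filename: tf.data.TextLineDataset(filename).skip(1),
  cycle_length=5,
  num_parallel_calls=tf.data.AUTOTUNE  # let tf.data choose the parallelism level
)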

Now, step 5: parse each record.

def parse_csv_line(line, n_fields=9):
  # One default per column: NaN, which also fixes each column's dtype to float32.
  defaults = [tf.constant(np.nan)] * n_fields
  parsed_fields = tf.io.decode_csv(line, record_defaults=defaults)
  x = tf.stack(parsed_fields[:-1])  # first 8 columns: features
  y = tf.stack(parsed_fields[-1:])  # last column: label
  return x, y

parse_csv_line('1.2286258796252256,-1.0806245954111382,0.4444161407754224,-0.0352172575329119,0.9740347681426992,-0.003516079473801425,-0.8126524696425611,0.865609068204283,2.803',9)

#(<tf.Tensor: shape=(8,), dtype=float32, numpy=array([ 1.2286259, -1.0806246, 0.44441614, -0.03521726, 0.9740348, -0.00351608, -0.81265247, 0.86560905], dtype=float32)>,
# <tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.803], dtype=float32)>)
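Note that record_defaults does double duty: it supplies a fallback for missing fields and fixes each column's dtype. A hedged variant with explicit zero defaults (hypothetical; it assumes all nine columns are float32):

# Nine float32 columns, each defaulting to 0.0 when a field is empty.
defaults = [tf.constant(0.0, dtype=tf.float32)] * 9
fields = tf.io.decode_csv('1,2,3,4,5,6,7,8,9', record_defaults=defaults)
# fields is a list of nine scalar float32 tensors, one per column.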

Finally, map this function over every record, and the dataset is fully built.

dataset = dataset.map(parse_csv_line)
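A quick sanity check that the mapping worked; each element should now be an (x, y) pair:

for x, y in dataset.take(1):
  print(x)  # feature vector, shape (8,), float32
  print(y)  # label, shape (1,), float32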

Complete code

import numpy as np
import tensorflow as tf

def csv_2_dataset(filenames, n_readers_thread=5, batch_size=32,
                  n_parse_thread=5, shuffle_buffer_size=10000):
  dataset = tf.data.Dataset.list_files(filenames)
  dataset = dataset.repeat()  # repeat indefinitely; epochs are bounded by steps_per_epoch
  dataset = dataset.interleave(
    lambda filename: tf.data.TextLineDataset(filename).skip(1),
    cycle_length=n_readers_thread
  )
  # shuffle returns a new dataset, so the result must be assigned back.
  dataset = dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_thread)
  dataset = dataset.batch(batch_size)
  return dataset
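One optional refinement, not part of the original code: adding prefetch just before the return lets tf.data prepare the next batch while the model is busy with the current one (tf.data.AUTOTUNE assumes TF 2.4+):

  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.AUTOTUNE)  # overlap preprocessing with training
  return dataset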

How to use it

train_dataset = csv_2_dataset(train_filenames, batch_size=32)
valid_dataset = csv_2_dataset(valid_filenames, batch_size=32)

model = ...

model.fit(train_dataset, validation_data=valid_dataset,
          steps_per_epoch=11610 // 32,
          validation_steps=3870 // 32,
          epochs=100, callbacks=callbacks)

What are the 11610 and 3870 here?

They are the total number of examples in train_dataset and valid_dataset. Because the dataset repeats indefinitely, Keras cannot tell on its own where an epoch ends, so the number of batches per epoch must be specified manually as total examples // batch size.
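Concretely, with a batch size of 32:

steps_per_epoch = 11610 // 32   # 362 batches per training epoch
validation_steps = 3870 // 32   # 120 batches per validation pass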

model.evaluate(test_set, steps=5160 // 32)

Likewise, when evaluating with a dataset built this way, the step count has to be given manually; 5160 is the total number of examples in the test set.
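For completeness, test_set would be built the same way as the other splits (test_filenames is hypothetical here, mirroring the naming of the training files):

test_set = csv_2_dataset(test_filenames, batch_size=32)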
