1. 程式人生 > >批量讀取資料next_batch()的理解

批量讀取資料next_batch()的理解

批量讀取資料

# 隨機取batch_size個訓練樣本  
import numpy as np
#train_data訓練集特徵,train_target訓練集對應的標籤,batch_size
def next_batch(train_data, train_target, batch_size):  
    #打亂資料集
    index = [ i for i in range(0,len(train_target)) ]  
    np.random.shuffle(index);  
    #建立batch_data與batch_target的空列表
    batch_data = []; 
    batch_target = [];  
    #向空列表加入訓練集及標籤
    for i in range(0,batch_size):  
        batch_data.append(train_data[index[i]]);  
        batch_target.append(train_target[index[i]])  
    return batch_data, batch_target #返回