1. 程式人生 > 劃分資料集，匯入 Keras

劃分資料集，匯入 Keras

第一種：用 os.walk() 函式，以遙感資料集 UCMerced 為例

from contextlib import suppress
import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import os
import warnings
from zipfile import ZipFile

from skimage.io import imread, imsave

from keras import applications
from keras import optimizers
from keras.models import Sequential
from keras.layers import Activation, Dense, Dropout, Flatten
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator

from sklearn.metrics import accuracy_score, classification_report

# # Data Preparation

# ### Download the [UC Merced Land Use dataset](http://vision.ucmerced.edu/datasets/landuse.html)

# In[2]: create the output folders

# Source directory: one sub-directory per class, each holding that class's images.
source_dir = r'D:\tang\jiangxia\sun\jinzhou\data\erjifenlei'

# Set random seed for reproducibility (so the random train/test split is repeatable).
np.random.seed(8)

# ### Randomly assign each image to the train or test folder, segregated by class name

# In[5]:

# Create an image directory hierarchy that looks like this:
# <flow_base>/
#     train/
#         agriculture/
#         airplane/
#         ...
#     test/
#         agriculture/
#         airplane/
#         ...

# Collect class names from the sub-directory names of source_dir.
# Plain files (e.g. desktop.ini, Thumbs.db) are skipped so they do not turn into
# bogus, empty class folders.  Such folders would still be counted as classes by
# flow_from_directory.  The names are sorted so their order does not depend on
# the file system's listing order.
class_names = sorted(
    name for name in os.listdir(source_dir)
    if os.path.isdir(os.path.join(source_dir, name))
)
print(class_names)

# Base directory of the image "flow" hierarchy.
flow_base = os.path.join(r'D:\tang\jiangxia\sun\jinzhou\data', 'flow19')

# Paths of the train/test subdirectories, keyed by split name.
target_dirs = {target: os.path.join(flow_base, target) for target in ['train', 'test']}

# 如果不存在資料夾 建立資料夾
# Build the split only once: if flow_base already exists, it is assumed to be
# populated and nothing below runs.
if not os.path.isdir(flow_base):

    # Fraction of each class's images copied to train/; the remainder goes to test/.
    train_fraction = 0.9

    # Only real sub-directories of source_dir are treated as classes.
    # A stray file in source_dir must not become an empty class folder.
    class_dirs = [name for name in class_names
                  if os.path.isdir(os.path.join(source_dir, name))]

    # Make new directories: <flow_base>/<train|test>/<class_name>/
    os.mkdir(flow_base)
    for target in ['train', 'test']:
        target_dir = os.path.join(flow_base, target)
        os.mkdir(target_dir)
        for class_name in class_dirs:
            os.mkdir(os.path.join(target_dir, class_name))

    # Suppress skimage.io.imsave's low-contrast UserWarning only while copying.
    # catch_warnings() restores the previous warning filters on exit.
    # warnings.resetwarnings() would instead discard every filter, not just this one.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', UserWarning)

        # Copy images from source_dir/<class_name>/ to flow_base/<train|test>/<class_name>/
        for root, _, filenames in os.walk(source_dir):
            class_name = os.path.basename(root)
            # Skip source_dir itself and any folder nested below a class folder.
            # Neither has a matching target folder to copy into.
            if not filenames or os.path.relpath(root, source_dir) not in class_dirs:
                continue

            # Sort before shuffling, so the seeded permutation yields the same split
            # regardless of the order in which the file system lists the files.
            filenames = np.random.permutation(sorted(filenames))
            n_train = int(train_fraction * len(filenames))
            print(len(filenames))
            print(n_train)

            # train/ gets the first n_train shuffled files and test/ gets ALL the rest.
            # Computing the test count separately as int(0.1 * n) drops files
            # whenever the two truncated counts do not add up to n (e.g. n = 7 gives 6 + 0).
            for target, split_files in [('train', filenames[:n_train]),
                                        ('test', filenames[n_train:])]:
                target_dir = os.path.join(flow_base, target, class_name)
                for filename in split_files:
                    image = imread(os.path.join(root, filename))
                    basename, _ = os.path.splitext(filename)
                    # Convert TIF to PNG to work with Keras ImageDataGenerator.flow_from_directory
                    imsave(os.path.join(target_dir, basename + '.png'), image)

第二種：用 pandas（以 NWPU-RESISC45 資料集為例）

# Build a DataFrame with one row per image.  Each row holds the class label (the
# name of the parent folder), the file name and the full path.
# Glob pattern: <dataset_root>/<label>/<image file>.
file_names = glob.glob('NWPU-RESISC45/*/*')
# Take the last two path components: label folder and file name.
# Splitting on a literal '\\' only works with Windows separators.
# normpath + os.sep works on any OS and handles mixed '/' and '\\'.
file_names_df = [os.path.normpath(i).split(os.sep)[-2:] for i in file_names]
file_names_df = pd.DataFrame(file_names_df, columns=['label', 'img_name'])
file_names_df['full_img_path'] = file_names


def _read_image(path):
    """Read one image with OpenCV, raising ``OSError`` if it cannot be decoded.

    cv2.imread returns None instead of raising for unreadable or non-image
    files.  Failing here names the offending path, rather than failing later
    with an obscure numpy conversion error.
    """
    image = cv2.imread(path)
    if image is None:
        raise OSError('cannot read image: ' + path)
    return image


# Load all images into one float32 array.
# NOTE(review): np.array(...) only gives a 4-D array if every image has the same
# height/width/channels.  Presumably true for this dataset; confirm.
# NOTE(review): cv2.imread returns channels in BGR order.  If
# resnet_preprocess_input is keras' resnet50.preprocess_input (caffe mode, which
# itself converts RGB -> BGR), convert to RGB first (e.g. image[..., ::-1]).
# TODO: confirm.
images = np.array(file_names_df['full_img_path'].apply(_read_image).tolist()).astype(np.float32)
# preprocess the images
preprocessed_imgs = resnet_preprocess_input(images)
labels = np.array(file_names_df['label'])
# Convert the string labels to integer ids, then to one-hot encoded rows.
le = LabelEncoder().fit(labels)
encoded_labels = le.transform(labels).reshape(len(labels), 1)
# scikit-learn 1.2 renamed OneHotEncoder's `sparse` argument to `sparse_output`,
# and 1.4 removed `sparse`.  Try the new name first, then fall back for older versions.
try:
    ohe = OneHotEncoder(sparse_output=False)
except TypeError:
    ohe = OneHotEncoder(sparse=False)
ohe.fit(encoded_labels)

one_hot_labels = ohe.transform(encoded_labels)