Keras自動下載的資料集/模型存放位置介紹
阿新 • • 發佈:2020-06-20
Mac
# 資料集
~/.keras/datasets/
# 模型
~/.keras/models/
Linux
# 資料集
~/.keras/datasets/
Windows
# win10
C:\Users\user_name\.keras\datasets
補充知識:Keras_gan生成自己的資料,並儲存模型
我就廢話不多說了,大家還是直接看程式碼吧~
from __future__ import print_function,division from keras.datasets import mnist from keras.layers import Input,Dense,Reshape,Flatten,Dropout from keras.layers import BatchNormalization,Activation,ZeroPadding2D from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D,Conv2D from keras.models import Sequential,Model from keras.optimizers import Adam import os import matplotlib.pyplot as plt import sys import numpy as np class GAN(): def __init__(self): self.img_rows = 3 self.img_cols = 60 self.channels = 1 self.img_shape = (self.img_rows,self.img_cols,self.channels) self.latent_dim = 100 optimizer = Adam(0.0002,0.5) # 構建和編譯判別器 self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy']) # 構建生成器 self.generator = self.build_generator() # 生成器輸入噪音,生成假的圖片 z = Input(shape=(self.latent_dim,)) img = self.generator(z) # 為了組合模型,只訓練生成器 self.discriminator.trainable = False # 判別器將生成的影象作為輸入並確定有效性 validity = self.discriminator(img) # The combined model (stacked generator and discriminator) # 訓練生成器騙過判別器 self.combined = Model(z,validity) self.combined.compile(loss='binary_crossentropy',optimizer=optimizer) def build_generator(self): model = Sequential() model.add(Dense(64,input_dim=self.latent_dim)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(128)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) #np.prod(self.img_shape)=3x60x1 model.add(Dense(np.prod(self.img_shape),activation='tanh')) model.add(Reshape(self.img_shape)) model.summary() noise = Input(shape=(self.latent_dim,)) img = 
model(noise) #輸入噪音,輸出圖片 return Model(noise,img) def build_discriminator(self): model = Sequential() model.add(Flatten(input_shape=self.img_shape)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(128)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(64)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1,activation='sigmoid')) model.summary() img = Input(shape=self.img_shape) validity = model(img) return Model(img,validity) def train(self,epochs,batch_size=128,sample_interval=50): ############################################################ #自己資料集此部分需要更改 # 載入資料集 data = np.load('data/相對大小分叉.npy') data = data[:,:,0:60] # 歸一化到-1到1 data = data * 2 - 1 data = np.expand_dims(data,axis=3) ############################################################ # Adversarial ground truths valid = np.ones((batch_size,1)) fake = np.zeros((batch_size,1)) for epoch in range(epochs): # --------------------- # 訓練判別器 # --------------------- # data.shape[0]為資料集的數量,隨機生成batch_size個數量的隨機數,作為資料的索引 idx = np.random.randint(0,data.shape[0],batch_size) #從資料集隨機挑選batch_size個數據,作為一個批次訓練 imgs = data[idx] #噪音維度(batch_size,100) noise = np.random.normal(0,1,(batch_size,self.latent_dim)) # 由生成器根據噪音生成假的圖片 gen_imgs = self.generator.predict(noise) # 訓練判別器,判別器希望真實圖片,打上標籤1,假的圖片打上標籤0 d_loss_real = self.discriminator.train_on_batch(imgs,valid) d_loss_fake = self.discriminator.train_on_batch(gen_imgs,fake) d_loss = 0.5 * np.add(d_loss_real,d_loss_fake) # --------------------- # 訓練生成器 # --------------------- noise = np.random.normal(0,self.latent_dim)) # Train the generator (to have the discriminator label samples as valid) g_loss = self.combined.train_on_batch(noise,valid) # 列印loss值 print ("%d [D loss: %f,acc.: %.2f%%] [G loss: %f]" % (epoch,d_loss[0],100*d_loss[1],g_loss)) # 沒sample_interval個epoch儲存一次生成圖片 if epoch % sample_interval == 0: self.sample_images(epoch) if not 
os.path.exists("keras_model"): os.makedirs("keras_model") self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch,True) self.discriminator.save_weights("keras_model/D_model%d.hdf5" %epoch,True) def sample_images(self,epoch): r,c = 10,10 # 重新生成一批噪音,維度為(100,100) noise = np.random.normal(0,(r * c,self.latent_dim)) gen_imgs = self.generator.predict(noise) # 將生成的圖片重新歸整到0-1之間 gen = 0.5 * gen_imgs + 0.5 gen = gen.reshape(-1,3,60) fig,axs = plt.subplots(r,c) cnt = 0 for i in range(r): for j in range(c): xy = gen[cnt] for k in range(len(xy)): x = xy[k][0:30] y = xy[k][30:60] if k == 0: axs[i,j].plot(x,y,color='blue') if k == 1: axs[i,color='red') if k == 2: axs[i,color='green') plt.xlim(0.,1.) plt.ylim(0.,1.) plt.xticks(np.arange(0,0.1)) plt.xticks(np.arange(0,0.1)) axs[i,j].axis('off') cnt += 1 if not os.path.exists("keras_imgs"): os.makedirs("keras_imgs") fig.savefig("keras_imgs/%d.png" % epoch) plt.close() def test(self,gen_nums=100,save=False): self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True) self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True) noise = np.random.normal(0,(gen_nums,self.latent_dim)) gen = self.generator.predict(noise) gen = 0.5 * gen + 0.5 gen = gen.reshape(-1,60) print(gen.shape) ############################################################### #直接視覺化生成圖片 if save: for i in range(0,len(gen)): plt.figure(figsize=(128,128),dpi=1) plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300) plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300) plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300) plt.axis('off') plt.xlim(0.,1.) plt.ylim(0.,1.) 
plt.xticks(np.arange(0,0.1)) plt.yticks(np.arange(0,0.1)) if not os.path.exists("keras_gen"): os.makedirs("keras_gen") plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1) plt.close() ################################################################## #重整圖片到0-1 else: for i in range(len(gen)): plt.plot(gen[i][0][0:30],color='blue') plt.plot(gen[i][1][0:30],color='red') plt.plot(gen[i][2][0:30],color='green') plt.xlim(0.,0.1)) plt.xticks(np.arange(0,0.1)) plt.show() if __name__ == '__main__': gan = GAN() gan.train(epochs=300000,batch_size=32,sample_interval=2000) # gan.test(save=True)
以上這篇Keras自動下載的資料集/模型存放位置介紹就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。