TensorFlow 2：使用 generator 訓練多輸入、多輸出模型
阿新 • 發佈：2022-05-09
np.random.seed(1)


def _mask_modality(batch):
    """Build the random binary mask for one modality batch and apply it.

    Two masks are drawn with ``get_mask`` (span length 1, then span length 2,
    each with ``mask_rate=0.1``) and combined with a logical OR.

    Args:
        batch: numpy array holding one batch of a single modality.

    Returns:
        ``(mask, masked)``:
        - ``mask`` is the 0/1 mask cast to float32.
        - ``masked`` is ``mask * batch`` cast to float32.
        The mask shape is whatever ``get_mask`` returns. It presumably is
        ``(batch, seq_len, 1)`` to match the mask Input layers below (TODO
        confirm), so it broadcasts over the feature axis.
    """
    # Call order (span 1, then span 2) matches the original script, so the
    # random stream consumed by get_mask is unchanged.
    combined = (get_mask(batch, mask_rate=0.1, mask_lenth=1)
                + get_mask(batch, mask_rate=0.1, mask_lenth=2))
    # Logical OR of the two masks: any position selected by either becomes 1.
    mask = np.where(combined >= 1, 1, 0).astype(np.float32)
    masked = (mask * batch).astype(np.float32)
    return mask, masked


def combineGenerator(x_l, x_a, x_v, batch_size):
    """Endlessly yield masked-reconstruction batches for three modalities.

    Rows are taken cyclically. When a batch would run past the end of the
    data, it wraps around to the start. Every batch therefore has exactly
    ``batch_size`` rows, and all samples keep being visited in order.

    Args:
        x_l, x_a, x_v: numpy arrays for the l / a / v modalities. They must
            share the same number of samples along axis 0.
        batch_size: positive number of rows per batch.

    Yields:
        ``(inputs, targets)``:
        - ``inputs`` is ``[bat_l, bat_a, bat_v, l_mask, a_mask, v_mask]``.
        - ``targets`` maps the output names ``'mult__model'``,
          ``'mult__model_1'`` and ``'mult__model_2'`` to the masked l / a / v
          batches.

    Raises:
        ValueError: if ``batch_size`` is not positive, the data is empty, or
            the three modalities have different sample counts.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    n_samples = x_l.shape[0]
    if n_samples == 0:
        raise ValueError('combineGenerator received empty input arrays')
    if x_a.shape[0] != n_samples or x_v.shape[0] != n_samples:
        raise ValueError(
            'modalities have different sample counts: l=%d, a=%d, v=%d'
            % (n_samples, x_a.shape[0], x_v.shape[0]))

    index = 0
    while True:
        # Wrap-around row indices keep every batch full-size. Plain slicing
        # would yield a short batch at the end of the data and then drift
        # out of alignment with steps_per_epoch.
        rows = np.arange(index, index + batch_size) % n_samples
        index = (index + batch_size) % n_samples

        bat_l = x_l[rows]
        bat_a = x_a[rows]
        bat_v = x_v[rows]

        # Same modality order as before (l, a, v), so the RNG sequence is kept.
        l_mask, l_output = _mask_modality(bat_l)
        a_mask, a_output = _mask_modality(bat_a)
        v_mask, v_output = _mask_modality(bat_v)

        yield ([bat_l, bat_a, bat_v, l_mask, a_mask, v_mask],
               {'mult__model': l_output,
                'mult__model_1': a_output,
                'mult__model_2': v_output})


train_generator = combineGenerator(l_train, a_train, v_train, batch_size=bat)
test_generator = combineGenerator(l_val, a_val, v_val, batch_size=bat)

# Per-modality feature inputs and their masks (shapes as in the original script).
v_input = tf.keras.layers.Input(shape=(500, 35))
a_input = tf.keras.layers.Input(shape=(500, 74))
l_input = tf.keras.layers.Input(shape=(50, 300))
v_mask_input = tf.keras.layers.Input(shape=(500, 1))
a_mask_input = tf.keras.layers.Input(shape=(500, 1))
l_mask_input = tf.keras.layers.Input(shape=(50, 1))

out_l, out_a, out_v, out = model(l_input, a_input, v_input,
                                 l_mask=l_mask_input,
                                 a_mask=a_mask_input,
                                 v_mask=v_mask_input)
# Use tf.keras consistently with the Input layers above.
pre_model = tf.keras.Model(
    inputs=[l_input, a_input, v_input, l_mask_input, a_mask_input, v_mask_input],
    outputs=[out_l, out_a, out_v])

# `learning_rate` replaces the deprecated `lr` keyword (rejected by Keras 3).
opt = tf.keras.optimizers.Adam(learning_rate=lr_rate, clipvalue=1.)
history2 = LossHistory_early_stop(which_test, epochs, bat, lr_rate,)
# compile() registers no metrics, so 'val_weighted_accuracy' would never be
# logged and early stopping would silently never fire. Monitor val_loss instead.
early_stoping = EarlyStopping(monitor='val_loss', patience=patience,
                              restore_best_weights=True, mode='min')
pre_model.compile(
    loss={'mult__model': tf.losses.MSE,
          'mult__model_1': tf.losses.MSE,
          'mult__model_2': tf.losses.MSE},
    optimizer=opt,
    # Keyed by output name (same keys as `loss`) instead of relying on output order.
    loss_weights={'mult__model': 100,
                  'mult__model_1': 0.5,
                  'mult__model_2': 30})
# The generator already batches, so fit() must not receive batch_size.
# max(1, ...) keeps the step counts valid when a split is smaller than one
# batch; the wrap-around generator still yields full batches in that case.
pre_model.fit(train_generator,
              validation_data=test_generator,
              steps_per_epoch=max(1, v_train.shape[0] // bat),
              validation_steps=max(1, v_val.shape[0] // bat),
              epochs=epochs,
              callbacks=[early_stoping, history2])
搜尋
複製