1. 程式人生 > >tfgan折騰筆記(二):核心函式詳述——gan_model族

tfgan折騰筆記(二):核心函式詳述——gan_model族

定義model的函式有:

1.gan_model

函式原型:

def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True)

引數:

generator_fn:預先定義好的生成器網路的函式名稱。預先定義好的生成器函式的輸入引數應該是接下來要說明的第四個引數generator_inputs,生成網路的返回值是網路的輸出(因為是GAN,所以生成器的輸出一般是一幅機器生成的影象)。

discriminator_fn:預先定義好的判別器網路的函式名稱。預先定義好的判別器函式的輸入引數有兩個:第一個是“真實資料(影象)”/“機器生成的影象(generator_fn的返回值)”;第二個是生成器的輸入,即此函式的第四個引數(在普通的gan當中,判別器只需要第一個引數。即使不需要第二個引數,也必須顯式地定義出第二個引數,只不過定義了之後在判別器函式中可以不使用)。判別器的返回值必須在負無窮到正無窮之間([-inf, +inf])。

real_data:真實影象。一般傳入真實影象batch化後的引用。

generator_inputs:生成器的輸入。對於vallina gan,是tensor型別的噪聲。除此之外,如果是c-gan,還可以傳入一個list或tuple作為引數(在下方的“其他說明“裡詳細說明c-gan(conditional-gan)的情況)。

generator_scope:傳入這個引數可以定義生成器內參數的變數名稱空間(variable_scope)。預設為"Generator"。

discriminator_scope:傳入這個引數可以定義判別器內參數的變數名稱空間(variable_scope)。預設為"Discriminator"。

check_shapes:如果為真,將檢查生成器生成的資料與真實資料是否有相同的shape。如果為假,則跳過檢查。

返回值:

返回一個“GANModel 命名管道”。實際上就是一個由生成器函式、判別器函式、生成的資料、變數空間等東西組成的一個List。這個返回值不需要我們寫程式的時候用,就不過多解釋了(具體用法見本系列上一篇文件:傳送門)。

函式內部實現:

generator_fn和discriminator_fn在gan_model函式裡這樣呼叫:

# 由機器生成資料
generated_data = generator_fn(generator_inputs)

# 判別器判斷機器生成圖片的真實性
discriminator_gen_outputs = discriminator_fn(generated_data, generator_inputs)

# 判別器判斷真實圖片的真實性
discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

 

其他說明:

  • gan_model支援conditional-gan。若需要訓練c-gan,要通過generator_inputs額外傳入標籤資訊。如:generator_inputs=(noise, one_hot_label)。同時,判別器網路與生成器網路應該按照c-gan論文中的模型重新定義。
  • real_data一般為一個next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

2.infogan_model

函式原型:

def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator')

引數:

generator_fn:預先定義好的生成器網路的函式名稱。預先定義好的生成器函式的輸入引數應該是接下來要說明的unstructrued_generator_inputs與structured_generator_inputs共同組成的列表,列表中的每一項是一個Tensor,生成網路的返回值是生成器的輸出。

discriminator_fn:預先定義好的判別器網路的函式名稱。預先定義好的判別器函式的輸入引數應該有兩個:第一個是“真實資料(影象)”/“機器生成的影象(generator_fn的返回值)”;第二個是生成器的輸入,即(unstructrued_generator_inputs與structured_generator_inputs共同組成的列表)。預先定義好的判別器函式的輸出應是一個二維Tuple。Tuple的第一維是生成器網路輸出層的logits,範圍在[-inf, +inf]。Tuple的第二維是分佈的列表:此分佈的第i個列表元素代表的是第i個structure noise 的分佈。

real_data:真實影象。一般傳入真實影象batch化後的引用。

unstructured_generator_inputs:Tensor的列表。表示非結構化的noise或條件。

structured_generator_inputs:Tensor的列表。這些Tensor必須與識別器具有較高的相互資訊。

generator_scope:傳入這個引數可以定義生成器內參數的變數名稱空間(variable_scope)。預設為"Generator"。

discriminator_scope:傳入這個引數可以定義判別器內參數的變數名稱空間(variable_scope)。預設為"Discriminator"。

返回值:

返回一個“InfoGANModel 命名管道”。同“GANModel 命名管道”一樣,我們無需關心管道中的具體內容。

函式內部實現:

生成器的輸入這樣定義:

generator_inputs = (unstructured_generator_inputs + structured_generator_inputs)

 

生成器和判別器這樣呼叫:

# 由機器生成資料
generated_data = generator_fn(generator_inputs)

# 判別器判斷機器生成圖片的真實性
dis_gen_outputs, predicted_distributions = discriminator_fn(generated_data, generator_inputs)

# 判別器判斷真實圖片的真實性
dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

 

其他說明:

  • 關於生成器和判別器網路模型的搭建,請參照Info-GAN的論文。
  • real_data一般為一個next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

3.acgan_model:

函式原型:

def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True)

 

引數:

與gan_model中的引數基本一致,除了:

discriminator_fn:預定義的判別器函式應當返回一個二維Tuple。第一維是網路輸出層的real或者fake的logits;第二維是分類器的logits。他們兩個的範圍都應該是[-inf, +inf]。

one_hot_labels:對應於一個batch影象的one_hot_label。

返回值:

返回“AcGANModel 命名管道”。同“GANModel 命名管道”一樣,我們無需關心管道中的具體內容。

函式內部實現:

生成器和判別器這樣呼叫:

# 由機器生成資料
generated_data = generator_fn(generator_inputs)

# 判別器判斷機器生成圖片的真實性
(discriminator_gen_outputs, discriminator_gen_classification_logits) = _validate_acgan_discriminator_outputs(discriminator_fn(generated_data, generator_inputs))

# 判別器判斷真實圖片的真實性
(discriminator_real_outputs, discriminator_real_classification_logits) = _validate_acgan_discriminator_outputs(discriminator_fn(real_data, generator_inputs))

 

其他說明:

  • one_hot_labels在此函式內部沒有被使用,而是直接通過命名管道(返回值)傳遞給gan_loss函式(下一篇詳細說明)。
  • one_hot_labels與real_data均為batch。

4.cyclegan_model:

函式原型:

def cyclegan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # data X and Y.
    data_x,
    data_y,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    model_x2y_scope='ModelX2Y',
    model_y2x_scope='ModelY2X',
    # Options.
    check_shapes=True)

 

引數:

generator_fn:預先定義好的生成器函式。此生成器的輸入有一個引數,與gan_model的generator_fn一樣。返回值為生成器網路的輸出。

discriminator_fn:預先定義好的判別器函式。與gan_model的discriminator_fn定義一樣。

data_x:x域的真實資料。

data_y:y域的真實資料。

generator_scope:與gan_model的generator_scope意義一樣。

discriminator_scope:與gan_model的discriminator_scope意義一樣。

model_x2y_scope:x->y轉換過程的variable_scope。

model_y2x_scope:y->x轉換過程的variable_scope。

check_shapes:如果為真,將檢查生成器生成的資料與真實資料是否有相同的shape。如果為假,則跳過檢查。

返回值:

返回“CycleGANModel 名稱空間”。

函式內部實現:

此函式實際上呼叫了gan_model函式,如下所示:

# Create models.
  def _define_partial_model(input_data, output_data):    # 內部函式定義
    return gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=output_data,
        generator_inputs=input_data,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

  with tf.compat.v1.variable_scope(model_x2y_scope):
    model_x2y = _define_partial_model(data_x, data_y)
  with tf.compat.v1.variable_scope(model_y2x_scope):
    model_y2x = _define_partial_model(data_y, data_x)

  with tf.compat.v1.variable_scope(model_y2x.generator_scope, reuse=True):
    reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
  with tf.compat.v1.variable_scope(model_x2y.generator_scope, reuse=True):
    reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

  return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                   reconstructed_y)

 

其他說明:

5.stargan_model

函式原型:

def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator')

 

引數:

generator_fn:預先定義好的函式的函式名稱。函式的輸入有兩個,應分別為:input、target,返回值是根據inputs和targets由機器生成的影象。inputs的形狀應該是(batch, height, width, channel),targets的形狀是(batch, num_domain)。返回值有和inputs相同的形狀。

discriminator_fn:預先定義好的函式的函式名稱。此函式的輸入有兩個,分別為input和num_domain。返回值是一個Tuple:(`source_prediction`, `domain_prediction`)。`source_prediction`表示預測的影象(真實或生成的)真實度,“ domain_prediction”代表判別器對域分類的預測(真實度)。 `source_prediction`的形狀是(batch), `domain_prediction`具有形狀(batch,num_domains)。

input_data:Tensor或Tensor組成的列表。代表真實輸入的圖片。形狀是(batch, height, width, channel)。

input_data_domain_label:Tensor或Tensor組成的列表。形狀為(batch, num_domains)。表示真實資料相對應的代表域的標籤。

generator_scope:與gan_model的此引數意義相同。

discriminator_scope:與gan_model的此引數意義相同。

返回值:

返回“StarGANModel 名稱空間”。

函式內部實現:

 函式內部重要程式碼如下:

  # Transform input_data to random target domains.
  with tf.compat.v1.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target = generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, generated_data_domain_target)

  # Transform generated_data back to the original input_data domain.
  with tf.compat.v1.variable_scope(generator_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Predict source and domain for the generated_data using the discriminator.
  with tf.compat.v1.variable_scope(discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
        generated_data, num_domains)

  # Predict source and domain for the input_data using the discriminator.
  with tf.compat.v1.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
        input_data, num_domains)

 

其他說明: