
Learning Flax 01: Basic Usage

Installing jax and jaxlib

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Installing flax

pip install flax
pip install --upgrade git+https://github.com/google/flax.git  # installing the latest version from GitHub did not work for me

Documentation

The documentation lives at https://flax.readthedocs.io/en/latest/index.html.
To see how Flax handles model parameters and initialization, look at the code of the two models below.

import math

import jax.numpy as jnp
import flax.linen as nn


class TokenLearnerModule(nn.Module):
  """TokenLearner module.

  This is the module used for the experiments in the paper.

  Attributes:
    num_tokens: Number of tokens.
  """
  num_tokens: int
  use_sum_pooling: bool = True

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies learnable tokenization to the 2D inputs.

    Args:
      inputs: Inputs of shape `[bs, h, w, c]` or `[bs, hw, c]`.

    Returns:
      Output of shape `[bs, n_token, c]`.
    """
    if inputs.ndim == 3:
      n, hw, c = inputs.shape
      h = int(math.sqrt(hw))
      if h * h != hw:
        raise ValueError('Only square inputs supported.')
      inputs = jnp.reshape(inputs, [n, h, h, c])  # Reshape so the input has the [bs, h, w, c] layout.

    feature_shape = inputs.shape

    selected = inputs
    selected = nn.LayerNorm()(selected)

    for _ in range(3):  # Forward pass: with @nn.compact the conv layers are defined inline here.
      selected = nn.Conv(
          self.num_tokens,
          kernel_size=(3, 3),
          strides=(1, 1),
          padding='SAME',
          use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

      selected = nn.gelu(selected)

    selected = nn.Conv(
        self.num_tokens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='SAME',
        use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

    selected = jnp.reshape(
        selected, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
                  ])  # Shape: [bs, h*w, n_token].
    selected = jnp.transpose(selected, [0, 2, 1])  # Shape: [bs, n_token, h*w].
    selected = nn.sigmoid(selected)[..., None]  # Shape: [bs, n_token, h*w, 1].

    feat = inputs
    feat = jnp.reshape(
        feat, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
              ])[:, None, ...]  # Shape: [bs, 1, h*w, c].

    if self.use_sum_pooling:
      inputs = jnp.sum(feat * selected, axis=2)
    else:
      inputs = jnp.mean(feat * selected, axis=2)

    return inputs
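
To try the module, here is a minimal usage sketch (not from the original post; the batch size, grid size and channel count are arbitrary, chosen only to match the docstring's [bs, h, w, c] layout):

import jax

tokenizer = TokenLearnerModule(num_tokens=8)
dummy = jnp.ones((2, 8, 8, 32))                           # [bs, h, w, c]
variables = tokenizer.init(jax.random.PRNGKey(0), dummy)  # initialize parameters
tokens = tokenizer.apply(variables, dummy)                # run the forward pass
print(tokens.shape)                                       # (2, 8, 32), i.e. [bs, n_token, c]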

The second model is the small MLP example from the Flax documentation:

from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

This snippet involves a few JAX concepts that I am not very familiar with either, so I went back to the JAX documentation at https://jax.readthedocs.io/en/latest/index.html.
In the random module, PRNGKey stands for pseudorandom number generator key. It returns a pair of values; for example, key = random.PRNGKey(0) returns [0, 0]. Unlike libraries that hand you random numbers from hidden global state, every random function in JAX requires an explicit key generated this way; with it, random.uniform(key) draws a number from a uniform distribution.
All random numbers need such a key, but you do not have to call random.PRNGKey over and over. Instead, jax.random.split(key, num=2) splits a key into subkeys, each of which can be used exactly like the original; the num argument sets how many subkeys you get, and they unpack like a tuple: k1, k2, k3 = jax.random.split(key, num=3).
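
A minimal sketch of this pattern (the printed values and shapes are only indicative):

import jax
from jax import random

key = random.PRNGKey(0)                # a pair of uint32 values, e.g. [0 0]
u = random.uniform(key)                # one draw from a uniform distribution
k1, k2, k3 = random.split(key, num=3)  # three independent subkeys
x = random.normal(k1, (2, 2))          # use each subkey for one draw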

Parameters

Parameters have to be initialized explicitly. Coming from PyTorch, where the layers are defined in __init__ and then used in the forward pass, this is confusing at first. The documentation states it plainly: "Parameters are not stored with the models themselves. You need to initialize parameters by calling the init function, using a PRNGKey and a dummy input parameter."
The shapes of the parameter matrices are inferred by the model itself; you do not compute them. You only supply a dummy input, and the model works out the shape of every weight matrix (init_with_output, mentioned below, additionally returns the model's output).

from jax import random

model = nn.Dense(features=5)  # As in the Flax guide; the MLP above would not match the shapes used below.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # Dummy input.
params = model.init(key2, x)  # Initialization call; parameter shapes are inferred automatically.
jax.tree_map(lambda x: x.shape, params)  # Check the shapes; like Python's built-in map, but applied over the pytree.

model.init_with_output works the same way but also returns the model's output on the dummy input together with the initialized parameters.
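
A short sketch of the difference, reusing key2 and x from above:

y, params = model.init_with_output(key2, x)  # y is the forward-pass output, params the initialized variables
print(y.shape)  # (5,) for nn.Dense(features=5)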

Forward pass

The forward pass is also quite different from PyTorch: in Flax you run it with model.apply(params, x), passing the parameters in explicitly.
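
For example, with the params and x initialized above:

y = model.apply(params, x)  # Forward pass with explicit parameters.
print(y.shape)              # (5,)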

Backward pass

Given samples \(\{(x_i,y_i),\ i\in \{1,\ldots, k\},\ x_i\in\mathbb{R}^n,\ y_i\in\mathbb{R}^m\}\), the goal is to find parameters \(W\in \mathcal{M}_{m,n}(\mathbb{R})\) and \(b\in\mathbb{R}^m\) that minimize the least-squares loss of the model output.
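
Written out, the objective (matching the mse function below: half of the squared error per sample, averaged over the \(k\) samples) is

\[
\mathcal{L}(W,b) \;=\; \frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\,\lVert W x_i + b - y_i \rVert_2^2 .
\]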

Preparing the data

from flax.core import freeze

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

Using the forward pass to define the loss

# Same as JAX version but using model.apply().
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
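
A quick side note on jax.vmap, which is used above to vectorize the per-sample loss: it maps a function written for a single example over a leading batch axis. A tiny sketch (not from the original post):

def squared_norm(v):  # written for a single vector
  return jnp.inner(v, v)

batch = jnp.arange(6.).reshape(3, 2)  # three vectors of length 2
print(jax.vmap(squared_norm)(batch))  # [ 1. 13. 41.]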

Gradient descent

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)
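
As a final sanity check (my addition, not in the original post), the learned parameters can be compared with the ground-truth W and b that generated the data; both gaps should be on the order of the injected noise:

from flax.core import unfreeze

learned = unfreeze(params)['params']
print('max |kernel - W|:', jnp.abs(learned['kernel'] - W).max())
print('max |bias - b|:  ', jnp.abs(learned['bias'] - b).max())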