Learning Flax 01: basic usage
Install jax and jaxlib
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Install flax
pip install flax
pip install --upgrade git+https://github.com/google/flax.git  # this one did not work for me
Documentation
The documentation lives at https://flax.readthedocs.io/en/latest/index.html
For Flax model parameters and initialization, look at the code of these two models.
import math

import jax.numpy as jnp
import flax.linen as nn


class TokenLearnerModule(nn.Module):
  """TokenLearner module.

  This is the module used for the experiments in the paper.

  Attributes:
    num_tokens: Number of tokens.
  """
  num_tokens: int
  use_sum_pooling: bool = True

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    """Applies learnable tokenization to the 2D inputs.

    Args:
      inputs: Inputs of shape `[bs, h, w, c]` or `[bs, hw, c]`.

    Returns:
      Output of shape `[bs, n_token, c]`.
    """
    if inputs.ndim == 3:
      n, hw, c = inputs.shape
      h = int(math.sqrt(hw))
      inputs = jnp.reshape(inputs, [n, h, h, c])  # make sure the input has this [n, h, h, c] shape

      if h * h != hw:
        raise ValueError('Only square inputs supported.')

    feature_shape = inputs.shape

    selected = inputs
    selected = nn.LayerNorm()(selected)

    for _ in range(3):  # this is the forward computation
      selected = nn.Conv(
          self.num_tokens,
          kernel_size=(3, 3),
          strides=(1, 1),
          padding='SAME',
          use_bias=False)(selected)  # Shape: [bs, h, w, n_token].
      selected = nn.gelu(selected)

    selected = nn.Conv(
        self.num_tokens,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='SAME',
        use_bias=False)(selected)  # Shape: [bs, h, w, n_token].

    selected = jnp.reshape(
        selected,
        [feature_shape[0], feature_shape[1] * feature_shape[2], -1])  # Shape: [bs, h*w, n_token].
    selected = jnp.transpose(selected, [0, 2, 1])  # Shape: [bs, n_token, h*w].
    selected = nn.sigmoid(selected)[..., None]  # Shape: [bs, n_token, h*w, 1].

    feat = inputs
    feat = jnp.reshape(
        feat,
        [feature_shape[0], feature_shape[1] * feature_shape[2], -1])[:, None, ...]  # Shape: [bs, 1, h*w, c].

    if self.use_sum_pooling:
      inputs = jnp.sum(feat * selected, axis=2)
    else:
      inputs = jnp.mean(feat * selected, axis=2)

    return inputs
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x


model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)
These snippets involve several JAX concepts, and since I am not very familiar with JAX either, I also went through the JAX documentation: https://jax.readthedocs.io/en/latest/index.html
Find the section on random. PRNGKey is short for "pseudorandom number generator key". You can think of it as producing a pair of random numbers; note that it returns two values at once, e.g. key = random.PRNGKey(0) returns [0, 0]. This is a random key. JAX does not have the usual kind of random API that simply hands you a number: every random function in JAX requires a key, and that key is created with this method. With a key in hand you can call random.uniform(key) to get a number drawn from a uniform distribution.
Likewise, every random draw needs such a key, but you do not have to call random.PRNGKey over and over: jax.random.split(key, num=2) splits the key (let's call it that for now) into subkeys, each of which can be used just like the original. The number of subkeys is given by the num argument, and the result is unpacked like a tuple: k1, k2, k3 = jax.random.split(key, num=3)
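A minimal sketch of this workflow (the concrete values are just what PRNGKey(0) happens to produce):

from jax import random

key = random.PRNGKey(0)                # a key is an array of two uint32s, here [0, 0]
u = random.uniform(key)                # one draw from the uniform distribution on [0, 1)
k1, k2, k3 = random.split(key, num=3)  # three independent subkeys, each used like the original
u1, u2 = random.uniform(k1), random.uniform(k2)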
Parameters
Parameters have to be initialized explicitly. For someone used to PyTorch, where the layers are defined in __init__ and the forward pass is written separately, this is confusing at first, but the documentation states it plainly: Parameters are not stored with the models themselves. You need to initialize parameters by calling the init function, using a PRNGKey and a dummy input parameter.
The shapes of the parameter matrices are inferred by the model itself; you do not work them out by hand. You only provide a dummy input, and the model deduces the shape of every parameter matrix.
from jax import random

key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))           # Dummy input: a made-up input of the right shape
params = model.init(key2, x)             # Initialization call: parameter shapes are inferred automatically
jax.tree_map(lambda x: x.shape, params)  # Inspect the shapes; works like Python's built-in map, applied over the pytree
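With the MLP([12, 8, 4]) defined earlier and a 10-dimensional dummy input, the shape pytree should look roughly like this (a sketch of the expected structure, not captured output):

# {'params': {'Dense_0': {'bias': (12,), 'kernel': (10, 12)},
#             'Dense_1': {'bias': (8,),  'kernel': (12, 8)},
#             'Dense_2': {'bias': (4,),  'kernel': (8, 4)}}}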
model.init_with_output does the same initialization but additionally returns the model's output together with the parameters.
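A minimal sketch, reusing key2 and x from above:

y, params = model.init_with_output(key2, x)  # returns (output, variables)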
Forward pass
The forward pass is also quite different from torch: model.apply(params, x) is the forward-pass call in Flax.
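A minimal usage sketch, reusing params and x from the initialization example above:

y = model.apply(params, x)  # forward pass; with MLP([12, 8, 4]) and a 10-dim input, y.shape == (4,)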
Backward pass
For samples \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\), the goal is to find the optimal parameters \(W\in \mathcal{M}_{m,n}(\mathbb{R})\) and \(b\in\mathbb{R}^m\) that minimize the least-squares loss of the output.
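Spelled out, the objective is the standard least-squares one (note that the code below stores W as an (x_dim, y_dim) "kernel", i.e. the transpose of the W written here):
\[
\min_{W,\,b}\ \frac{1}{k}\sum_{i=1}^{k}\frac{1}{2}\,\lVert W x_i + b - y_i\rVert_2^2 .
\]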
Preparing the data
from flax.core import freeze

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5
# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
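Note: for mse(true_params, ...) further down to work, the trained model must have the same single-layer {'kernel', 'bias'} parameter structure as true_params, so the model used in this example is a single Dense layer (as in the Flax guide), not the MLP above. A sketch of the missing definition, reusing key2 from the initialization example:

model = nn.Dense(features=y_dim)      # y = x @ kernel + bias, matching the pytree structure of true_params
params = model.init(key2, x_samples)  # parameters that will be trained below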
Using the forward pass
# Same as JAX version but using model.apply().
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x, y).
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y - pred, y - pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
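As a quick aside, this is what jax.vmap does there (a toy sketch, independent of the model):

def sq(a, b):
  d = a - b
  return jnp.inner(d, d) / 2.0

A = jnp.ones((4, 3))
B = jnp.zeros((4, 3))
print(jax.vmap(sq)(A, B))  # maps sq over the leading axis, one squared error per row: [1.5 1.5 1.5 1.5]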
Gradient descent
learning_rate = 0.3 # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)