
cs231 Convolutional Networks: Group Normalization (GN)

Group normalization (GN) splits the channels of each sample into G contiguous groups and normalizes each group independently over its C // G * H * W entries; a per-channel scale (gamma) and shift (beta) are then applied, exactly as in batch normalization and layer normalization.
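For intuition, the same computation can be written with an explicit (N, G, C // G, H, W) view. The sketch below is illustrative only; the name groupnorm_forward_naive is not part of the assignment API:

import numpy as np

def groupnorm_forward_naive(x, gamma, beta, G, eps=1e-5):
    # View the channels as G groups and normalize each (sample, group) block independently.
    N, C, H, W = x.shape
    xg = x.reshape(N, G, C // G, H, W)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    x_hat = ((xg - mean) / np.sqrt(var + eps)).reshape(N, C, H, W)
    # gamma and beta are per-channel, shape (1, C, 1, 1), and broadcast over (N, H, W).
    return gamma * x_hat + beta

The assignment version below reaches the same result by flattening each group into one row of a 2-D matrix, so the bulk of the code can follow the layer-norm implementation.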

 

 


import numpy as np


def spatial_groupnorm_forward(x, gamma, beta, G, gn_param):
    """
    Computes the forward pass for spatial group normalization.
    In contrast to layer normalization, group normalization splits the channels of each
    entry in the data into G contiguous groups, which it then normalizes independently.
    Per-feature shifting and scaling are then applied to the data, in a manner identical
    to that of batch normalization and layer normalization.

    Inputs:
    - x: Input data of shape (N, C, H, W)
    - gamma: Scale parameter, of shape (1, C, 1, 1)
    - beta: Shift parameter, of shape (1, C, 1, 1)
    - G: Integer number of groups to split into, should be a divisor of C
    - gn_param: Dictionary with the following keys:
      - eps: Constant for numeric stability

    Returns a tuple of:
    - out: Output data, of shape (N, C, H, W)
    - cache: Values needed for the backward pass
    """
    out, cache = None, None
    eps = gn_param.get('eps',1e-5)
    ###########################################################################
    # TODO: Implement the forward pass for spatial group normalization.       #
    # This will be extremely similar to the layer norm implementation.        #
    # In particular, think about how you could transform the matrix so that   #
    # the bulk of the code is similar to both train-time batch normalization  #
    # and layer normalization!                                                # 
    ###########################################################################
    N, C, H, W = x.shape
    # Split the (N, C, H, W) volume into G groups of C // G channels per sample.
    # Example: N, C, H, W = 2, 6, 4, 5 and G = 2 splits the 6 channels into 2 groups of 3.
    x = x.reshape((N * G, C // G * H * W))  # (N, C, H, W) ---> (N * G, C // G * H * W)
    # Each row now holds one group of one sample and can be normalized like a layer-norm module.
    x = x.T  # (N * G, C // G * H * W) ---> (C // G * H * W, N * G)
    # Per-group mean and variance, computed over the C // G * H * W entries of each group.
    mean_x = np.mean(x, axis=0)
    var_x = np.var(x, axis=0)
    inv_var_x = 1 / np.sqrt(var_x + eps)

    x_hat = (x - mean_x) * inv_var_x     # (C // G * H * W, N * G)
    x_hat = x_hat.T                      # (C // G * H * W, N * G) ---> (N * G, C // G * H * W)
    x_hat = x_hat.reshape((N, C, H, W))  # (N * G, C // G * H * W) ---> (N, C, H, W)
   

    
    # Per-channel scale and shift; gamma and beta broadcast over (N, H, W).
    out = gamma * x_hat + beta
    cache = (x_hat, gamma, mean_x, inv_var_x, G)

    ###########################################################################
    #                             END OF YOUR CODE                            #
    ###########################################################################
    return out, cache
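A quick sanity check of the forward pass (the shapes mirror the assignment's test cell; the data here is random and purely illustrative). With gamma = 1 and beta = 0, every (sample, group) block of the output should be roughly zero-mean with unit variance:

np.random.seed(0)
N, C, H, W, G = 2, 6, 4, 5, 2
x = 4 * np.random.randn(N, C, H, W) + 10
gamma = np.ones((1, C, 1, 1))
beta = np.zeros((1, C, 1, 1))

out, cache = spatial_groupnorm_forward(x, gamma, beta, G, {'eps': 1e-5})
# Rows of this reshape are exactly the (sample, group) blocks used inside the forward pass.
print(out.reshape(N * G, -1).mean(axis=1))  # all entries ~ 0
print(out.reshape(N * G, -1).std(axis=1))   # all entries ~ 1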


def spatial_groupnorm_backward(dout, cache):
    """
    Computes the backward pass for spatial group normalization.

    Inputs:
    - dout: Upstream derivatives, of shape (N, C, H, W)
    - cache: Values from the forward pass

    Returns a tuple of:
    - dx: Gradient with respect to inputs, of shape (N, C, H, W)
    - dgamma: Gradient with respect to scale parameter, of shape (1, C, 1, 1)
    - dbeta: Gradient with respect to shift parameter, of shape (1, C, 1, 1)
    """
    dx, dgamma, dbeta = None, None, None

    ###########################################################################
    # TODO: Implement the backward pass for spatial group normalization.      #
    # This will be extremely similar to the layer norm implementation.        #
    ###########################################################################
    x_hat, gamma, mean_x, inv_var_x, G = cache
   
    # x_hat: (N, C, H, W)
    N, C, H, W = x_hat.shape
    # Sum over the (N, H, W) dimensions; keepdims=True gives gradients of shape (1, C, 1, 1).
    dgamma = np.sum(dout * x_hat, axis=(0, 2, 3), keepdims=True)
    dbeta = np.sum(dout, axis=(0, 2, 3), keepdims=True)
    # The forward pass split the data into per-sample groups, so the backward pass
    # must apply the same grouping before propagating through the normalization.
   
    # dout: (N, C, H, W) ---> (N * G, C // G * H * W) ---> (C // G * H * W, N * G)
    dxhat = (dout * gamma).reshape((N * G, C // G * H * W)).T

    # x_hat: (N, C, H, W) ---> (N * G, C // G * H * W) ---> (C // G * H * W, N * G)
    x_hat = x_hat.reshape((N * G, C // G * H * W)).T

    # Each group (one column of d = C // G * H * W entries) is handled like a layer-norm
    # backward pass, using the standard simplified formula
    #   dx = (1/d) * inv_std * (d * dxhat - sum(dxhat) - x_hat * sum(dxhat * x_hat)),
    # where the sums run over the d entries of the group.
    d = x_hat.shape[0]
    dx = (1. / d) * inv_var_x * (d * dxhat - np.sum(dxhat, axis=0) -
                                 x_hat * np.sum(dxhat * x_hat, axis=0))
    
    
    dx = dx.T  # (C // G * H * W, N * G) ---> (N * G, C // G * H * W)
    # Reassemble the groups back into the original (N, C, H, W) volume.
    dx = dx.reshape((N, C, H, W))  # (N * G, C // G * H * W) ---> (N, C, H, W)
    
   

    ###########################################################################
    #                             END OF YOUR CODE                            #
    ###########################################################################
    return dx, dgamma, dbeta
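To check the backward pass, the analytic gradients can be compared against centered-difference numerical gradients. The helper num_grad below is a small utility written just for this check (the cs231n assignment itself uses eval_numerical_gradient_array for the same purpose):

def num_grad(f, x, df, h=1e-5):
    # Centered-difference gradient of sum(f(x) * df) with respect to x.
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + h
        pos = f(x).copy()
        x[idx] = old - h
        neg = f(x).copy()
        x[idx] = old
        grad[idx] = np.sum((pos - neg) * df) / (2 * h)
        it.iternext()
    return grad

np.random.seed(1)
N, C, H, W, G = 2, 6, 4, 5, 2
x = 5 * np.random.randn(N, C, H, W) + 12
gamma = np.random.randn(1, C, 1, 1)
beta = np.random.randn(1, C, 1, 1)
dout = np.random.randn(N, C, H, W)
gn_param = {'eps': 1e-5}

out, cache = spatial_groupnorm_forward(x, gamma, beta, G, gn_param)
dx, dgamma, dbeta = spatial_groupnorm_backward(dout, cache)

fx = lambda x: spatial_groupnorm_forward(x, gamma, beta, G, gn_param)[0]
fg = lambda gamma: spatial_groupnorm_forward(x, gamma, beta, G, gn_param)[0]
print(np.max(np.abs(dx - num_grad(fx, x, dout))))          # expect a very small error
print(np.max(np.abs(dgamma - num_grad(fg, gamma, dout))))  # expect a very small error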

 

https://github.com/duanzhihua