swin-transformer PyTorch Implementation: Code Walkthrough
Published: 2021-11-18
I've needed this recently for a Kaggle competition, so here is the full swin-transformer PyTorch implementation with notes.
model.py
""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Code/weights from https://github.com/microsoft/Swin-Transformer """ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint import numpy as np from typing import Optional def drop_path_f(x, drop_prob: float = 0., training: bool = False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path_f(x, self.drop_prob, self.training) def window_partition(x, window_size: int): """ 將feature map按照window_size劃分成一個個沒有重疊的window Args: x: (B, H, W, C) window_size (int): window size(M) Returns: windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C] # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C] windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows def window_reverse(windows, window_size: int, H: int, W: int): """ 將一個個window還原成一個feature map Args: windows: (num_windows*B, window_size, window_size, C) window_size (int): Window size(M) H (int): Height of image W (int): Width of image Returns: x: (B, H, W, C) """ B = int(windows.shape[0] / (H * W / window_size / window_size)) # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C] x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C] # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C] x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None): super().__init__() patch_size = (patch_size, patch_size) self.patch_size = patch_size self.in_chans = in_c self.embed_dim = embed_dim self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): _, _, H, W = x.shape # padding # 如果輸入圖片的H,W不是patch_size的整數倍,需要進行padding pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0) if pad_input: # to pad the last 3 dimensions, # (W_left, W_right, H_top,H_bottom, C_front, C_back) x = F.pad(x, (0, 
self.patch_size[1] - W % self.patch_size[1], 0, self.patch_size[0] - H % self.patch_size[0], 0, 0)) # 下采樣patch_size倍 x = self.proj(x) _, _, H, W = x.shape # flatten: [B, C, H, W] -> [B, C, HW] # transpose: [B, C, HW] -> [B, HW, C] x = x.flatten(2).transpose(1, 2) x = self.norm(x) return x, H, W class PatchMerging(nn.Module): r""" Patch Merging Layer. Args: dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) def forward(self, x, H, W): """ x: B, H*W, C """ B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) # padding # 如果輸入feature map的H,W不是2的整數倍,需要進行padding pad_input = (H % 2 == 1) or (W % 2 == 1) if pad_input: # to pad the last 3 dimensions, starting from the last dimension and moving forward. # (C_front, C_back, W_left, W_right, H_top, H_bottom) # 注意這裡的Tensor通道是[B, H, W, C],所以會和官方文件有些不同 x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C] x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C] x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C] x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C] x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C] x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C] x = self.norm(x) x = self.reduction(x) # [B, H/2*W/2, 2*C] return x class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.drop1 = nn.Dropout(drop) self.fc2 = nn.Linear(hidden_features, out_features) self.drop2 = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class WindowAttention(nn.Module): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # [Mh, Mw] self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH] # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw] coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw] # [2, Mh*Mw, 1] - [2, 1, Mh*Mw] relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw] relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2] relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw] self.register_buffer("relative_position_index", relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) self.softmax = nn.Softmax(dim=-1) def forward(self, x, mask: Optional[torch.Tensor] = None): """ Args: x: input features with shape of (num_windows*B, Mh*Mw, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ # [batch_size*num_windows, Mh*Mw, total_embed_dim] B_, N, C = x.shape # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim] # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head] # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw] # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw] q = q * self.scale attn = (q @ k.transpose(-2, -1)) # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH] relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw] attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: # mask: [nW, Mh*Mw, Mh*Mw] nW = mask.shape[0] # num_windows # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw] # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw] attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head] # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim] x = 
(attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) return x class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x, attn_mask): H, W = self.H, self.W B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) x = x.view(B, H, W, C) # pad feature maps to multiples of window size # 把feature map給pad到window size的整數倍 pad_l = pad_t = 0 pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = x.shape # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x attn_mask = None # partition windows x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C] x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # [nW*B, Mh*Mw, C] # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C] # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # [nW*B, Mh, Mw, C] shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H', W', C] # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x if pad_r > 0 or pad_b > 0: # 把前面pad的資料移除掉 x = x[:, :H, :W, :].contiguous() x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. depth (int): Number of blocks. num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True drop (float, optional): Dropout rate. 
Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__(self, dim, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.dim = dim self.depth = depth self.window_size = window_size self.use_checkpoint = use_checkpoint self.shift_size = window_size // 2 # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock( dim=dim, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else self.shift_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer) else: self.downsample = None def create_mask(self, x, H, W): # calculate attention mask for SW-MSA # 保證Hp和Wp是window_size的整數倍 Hp = int(np.ceil(H / self.window_size)) * self.window_size Wp = int(np.ceil(W / self.window_size)) * self.window_size # 擁有和feature map一樣的通道排列順序,方便後續window_partition img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1] h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1] mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw] attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1] # [nW, Mh*Mw, Mh*Mw] attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) return attn_mask def forward(self, x, H, W): attn_mask = self.create_mask(x, H, W) # [nW, Mh*Mw, Mh*Mw] for blk in self.blocks: blk.H, blk.W = H, W if not torch.jit.is_scripting() and self.use_checkpoint: x = checkpoint.checkpoint(blk, x, attn_mask) else: x = blk(x, attn_mask) if self.downsample is not None: x = self.downsample(x, H, W) H, W = (H + 1) // 2, (W + 1) // 2 return x, H, W class SwinTransformer(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - https://arxiv.org/pdf/2103.14030 Args: patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. 
Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. patch_norm (bool): If True, add normalization after patch embedding. Default: True use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False """ def __init__(self, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, patch_norm=True, use_checkpoint=False, **kwargs): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) self.embed_dim = embed_dim self.patch_norm = patch_norm # stage4輸出特徵矩陣的channels self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) self.mlp_ratio = mlp_ratio # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): # 注意這裡構建的stage和論文圖中有些差異 # 這裡的stage不包含該stage的patch_merging層,包含的是下個stage的 layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, use_checkpoint=use_checkpoint) self.layers.append(layers) self.norm = norm_layer(self.num_features) self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, x): # x: [B, L, C] x, H, W = self.patch_embed(x) x = self.pos_drop(x) for layer in self.layers: x, H, W = layer(x, H, W) x = self.norm(x) # [B, L, C] x = self.avgpool(x.transpose(1, 2)) # [B, C, 1] x = torch.flatten(x, 1) x = self.head(x) return x def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs): # trained ImageNet-1K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), num_classes=num_classes, **kwargs) return model def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs): # trained ImageNet-1K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), num_classes=num_classes, **kwargs) return model def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs): # trained ImageNet-1K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=7, embed_dim=128, 
depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), num_classes=num_classes, **kwargs) return model def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs): # trained ImageNet-1K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), num_classes=num_classes, **kwargs) return model def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs): # trained ImageNet-22K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), num_classes=num_classes, **kwargs) return model def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs): # trained ImageNet-22K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), num_classes=num_classes, **kwargs) return model def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs): # trained ImageNet-22K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), num_classes=num_classes, **kwargs) return model def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs): # trained ImageNet-22K # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth model = SwinTransformer(in_chans=3, patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), num_classes=num_classes, **kwargs) return model
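A quick way to sanity-check the model and the window helpers is a dummy forward pass. This is only a minimal sketch of mine (not part of the original repo); it assumes model.py sits in the same directory and just verifies tensor shapes and that window_partition/window_reverse are exact inverses:

# sanity_check.py -- minimal sketch, not part of the original repo
import torch
from model import swin_tiny_patch4_window7_224, window_partition, window_reverse

model = swin_tiny_patch4_window7_224(num_classes=5)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 5])

# swin-tiny: 224/4 = 56 tokens per side after patch embedding, embed_dim = 96
x = torch.randn(1, 56, 56, 96)              # [B, H, W, C] feature map
windows = window_partition(x, 7)            # 56/7 = 8 windows per side -> 64 windows
print(windows.shape)                        # torch.Size([64, 7, 7, 96])
restored = window_reverse(windows, 7, 56, 56)
print(torch.equal(x, restored))             # True: partition and reverse are inverses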
predict.py
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import swin_tiny_patch4_window7_224 as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    json_file = open(json_path, "r")
    class_indict = json.load(json_file)

    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-9.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()
train.py
import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # remove the weights belonging to the classification head
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze all weights except the head
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0001)

    # root directory of the dataset
    # http://download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str, default="/data/flower_photos")

    # path to pretrained weights; set to an empty string if you do not want to load them
    parser.add_argument('--weights', type=str, default='./swin_tiny_patch4_window7_224.pth',
                        help='initial weights path')
    # whether to freeze the weights
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
utils.py
import os
import sys
import json
import pickle
import random
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # walk through the folders; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort to keep the class order consistent
    flower_class.sort()
    # build class names and their corresponding numeric indices
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class indices of the training images
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class indices of the validation images
    every_class_num = []     # number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all supported files
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # numeric index of this class
        image_class = class_indices[cla]
        # record the number of samples of this class
        every_class_num.append(len(images))
        # randomly sample validation images according to val_rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # if the path was sampled, put it into the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # otherwise put it into the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,3,4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add value labels on top of the bars
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # title of the bar chart
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove the x-axis ticks
            plt.yticks([])  # remove the y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB means a color image, L means a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
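For reference, MyDataSet plugs into a DataLoader through its static collate_fn, the same way train.py wires it up. Below is a minimal sketch of mine, assuming the flower_photos dataset sits at the default path used by train.py:

# dataset smoke test -- minimal sketch, assuming /data/flower_photos exists as in train.py
import torch
from torchvision import transforms
from my_dataset import MyDataSet
from utils import read_split_data

train_paths, train_labels, val_paths, val_labels = read_split_data("/data/flower_photos")

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = MyDataSet(images_path=train_paths, images_class=train_labels, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,
                                     collate_fn=MyDataSet.collate_fn)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8])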
This post is from cnblogs. Author: 甫生. Please cite the original link when reposting: https://www.cnblogs.com/fusheng-rextimmy/p/15570094.html