MindSpore 初探:使用 LeNet 訓練 MNIST 資料集
阿新 • 發佈:2021-07-05
本文依照 MindSpore 官方快速入門教學實作,官網地址:
https://www.mindspore.cn/tutorial/zh-CN/r1.2/quick_start.html
資料集下載:
mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte
tree ./datasets/MNIST_Data
個人整合後的程式碼:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Train and evaluate LeNet-5 on MNIST with MindSpore."""

# Command-line / hyper-parameter handling
import os
import argparse

# Execution context (graph mode, target device)
from mindspore import context

# Dataset loading and preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Network definition
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Checkpointing during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Training / evaluation utilities
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU",
                    choices=['Ascend', 'GPU', 'CPU'])
# parse_known_args returns (namespace, leftovers); unrecognized flags are
# collected instead of raising, so extra flags (e.g. from a notebook kernel)
# do not abort the script.
args = parser.parse_known_args()[0]

# Run in static-graph mode on the selected device.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1, shuffle=True, drop_remainder=True):
    """Build the preprocessed MNIST input pipeline.

    Labels are cast to int32. Images are resized to 32x32 (bilinear), scaled
    to [0, 1], normalized as (x - 0.1307) / 0.3081 (the commonly used MNIST
    mean/std), and converted from HWC to CHW layout.

    Args:
        data_path: directory holding the MNIST idx files for one split.
        batch_size: number of samples per batch.
        repeat_size: how many times the dataset is repeated.
        num_parallel_workers: worker count for each ``map`` step.
        shuffle: shuffle samples (buffer of 10000) before batching.
        drop_remainder: drop the final batch if it is smaller than
            ``batch_size``. The default True keeps every batch the same
            shape; pass False to cover every sample (e.g. for evaluation).

    Returns:
        The shuffled (optionally), batched and repeated dataset.
    """
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Rescale computes x * scale + shift, so this is (x - mean) / std.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    # Order matters: resize -> scale to [0, 1] -> normalize -> HWC to CHW.
    for image_op in (resize_op, rescale_op, rescale_nml_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    if shuffle:
        mnist_ds = mnist_ds.shuffle(buffer_size=10000)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=drop_remainder)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet-5: two conv/ReLU/max-pool stages followed by three dense layers.

    Args:
        num_class: number of output logits.
        num_channel: number of input image channels.
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # 16 * 5 * 5 is the flattened conv2 output size for a 32x32 input.
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass; returns unnormalized logits."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Network, loss (sparse=True: labels are class indices) and optimizer.
net = LeNet5()
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Save a checkpoint every 125 steps and keep at most 15 checkpoint files.
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train ``model`` on ``<data_path>/train`` with batch size 32.

    ``args`` is kept for call compatibility and is not used. Loss is printed
    every 125 steps by ``LossMonitor``; ``ckpoint_cb`` handles checkpointing.
    """
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)],
                dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Evaluate ``model`` on ``<data_path>/test`` and print the metrics dict.

    ``network`` is kept for call compatibility and is not used.
    """
    # Order is irrelevant for accuracy, so skip shuffling. Keep the final
    # partial batch: the t10k split has 10000 = 312 * 32 + 16 images, and
    # dropping the remainder would silently leave 16 images unscored.
    ds_eval = create_dataset(os.path.join(data_path, "test"),
                             shuffle=False, drop_remainder=False)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print(acc)
if __name__ == "__main__":
    mnist_path = "./datasets/MNIST_Data"
    # Fail early with a readable message instead of a dataset-loader error.
    if not os.path.isdir(mnist_path):
        raise SystemExit(f"MNIST data not found at {mnist_path!r}; download it first.")
    train_epoch = 1
    # Number of times the training set is repeated (passed to repeat()).
    repeat_size = 1
    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    # Train with dataset sink mode disabled, then evaluate on the test split.
    train_net(args, model, train_epoch, mnist_path, repeat_size, ckpoint, False)
    test_net(net, model, mnist_path)
訓練結果:
epoch: 1 step: 125, loss is 2.2982173
epoch: 1 step: 250, loss is 2.296105
epoch: 1 step: 375, loss is 2.3065567
epoch: 1 step: 500, loss is 2.3062077
epoch: 1 step: 625, loss is 2.3096561
epoch: 1 step: 750, loss is 2.2847052
epoch: 1 step: 875, loss is 2.284628
epoch: 1 step: 1000, loss is 1.8122461
epoch: 1 step: 1125, loss is 0.4140602
epoch: 1 step: 1250, loss is 0.25238502
epoch: 1 step: 1375, loss is 0.17819008
epoch: 1 step: 1500, loss is 0.3202765
epoch: 1 step: 1625, loss is 0.12312577
epoch: 1 step: 1750, loss is 0.11027573
epoch: 1 step: 1875, loss is 0.2680659
{'Accuracy': 0.9598357371794872}