
A First Look at MindSpore: Training LeNet on the MNIST Dataset

As the title says, this follows the official quick-start tutorial:

https://www.mindspore.cn/tutorial/zh-CN/r1.2/quick_start.html

Download the dataset:

mkdir -p ./datasets/MNIST_Data/train ./datasets/MNIST_Data/test
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/train https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/train-images-idx3-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-labels-idx1-ubyte
wget -NP ./datasets/MNIST_Data/test https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/mnist/t10k-images-idx3-ubyte
tree ./datasets/MNIST_Data
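
Before plugging the files into a training pipeline, it can be worth checking that they load at all. The short sketch below is my own addition (not part of the tutorial); it assumes the files landed in ./datasets/MNIST_Data/train as above and reads them with mindspore.dataset.MnistDataset:

import mindspore.dataset as ds

# MnistDataset reads the raw idx-ubyte files directly.
train_path = "./datasets/MNIST_Data/train"
mnist_ds = ds.MnistDataset(train_path)

print("training samples:", mnist_ds.get_dataset_size())  # expected: 60000

# Peek at one sample: "image" is a 28x28x1 uint8 array, "label" a scalar.
for sample in mnist_ds.create_dict_iterator(num_epochs=1, output_numpy=True):
    print("image shape:", sample["image"].shape, "label:", sample["label"])
    break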

My consolidated version of the code, as a single script:

#!/usr/bin/env python
# encoding: UTF-8

# Handle the input hyperparameters
import os
import argparse

# Set the execution context
from mindspore import context

# Preprocess the dataset
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Build the neural network
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Save model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Set the execution context for MindSpore
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    # Define the dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the map operations to the dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure"""

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Define the operations needed
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Build the forward pass from the operations defined above
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network
net = LeNet5()
# Define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Configure checkpointing:
# save the model parameters every 125 steps and keep at most 15 files
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Apply the checkpoint configuration
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Define the training procedure"""
    # Load the training dataset
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Define the evaluation procedure"""
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))


mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
test_net(net, model, mnist_path)
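
One detail worth spelling out: the two Rescale operations in create_dataset first map raw pixels into [0, 1] and then apply the usual MNIST normalisation with mean 0.1307 and std 0.3081 (CV.Rescale(rescale, shift) computes image * rescale + shift). A small NumPy check, purely for illustration, shows that the two steps compose to (x / 255 - mean) / std:

import numpy as np

mean, std = 0.1307, 0.3081
pixel = np.array([0.0, 128.0, 255.0], dtype=np.float32)

step1 = pixel * (1.0 / 255.0) + 0.0                  # CV.Rescale(rescale, shift)
step2 = step1 * (1.0 / std) + (-1.0 * mean / std)    # CV.Rescale(rescale_nml, shift_nml)

direct = (pixel / 255.0 - mean) / std
print(np.allclose(step2, direct))  # True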

Training results:

epoch: 1 step: 125, loss is 2.2982173
epoch: 1 step: 250, loss is 2.296105
epoch: 1 step: 375, loss is 2.3065567
epoch: 1 step: 500, loss is 2.3062077
epoch: 1 step: 625, loss is 2.3096561
epoch: 1 step: 750, loss is 2.2847052
epoch: 1 step: 875, loss is 2.284628
epoch: 1 step: 1000, loss is 1.8122461
epoch: 1 step: 1125, loss is 0.4140602
epoch: 1 step: 1250, loss is 0.25238502
epoch: 1 step: 1375, loss is 0.17819008
epoch: 1 step: 1500, loss is 0.3202765
epoch: 1 step: 1625, loss is 0.12312577
epoch: 1 step: 1750, loss is 0.11027573
epoch: 1 step: 1875, loss is 0.2680659
{'Accuracy': 0.9598357371794872}
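
With the checkpoints written by ModelCheckpoint, the trained network can be reloaded for inference. The sketch below is my own addition, meant to be appended to the end of the script above (it reuses net, model, create_dataset and mnist_path); the checkpoint file name is an assumption based on this run's 1875 steps per epoch, so check the name actually produced on disk:

import os
import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net

# Load the saved weights back into the network.
# checkpoint_lenet-1_1875.ckpt is the typical name of the last checkpoint of
# epoch 1 here; adjust it to whatever file your run produced.
param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
load_param_into_net(net, param_dict)

# Take one batch from the test split and predict.
ds_test = create_dataset(os.path.join(mnist_path, "test")).create_dict_iterator()
data = next(ds_test)
labels = data["label"].asnumpy()

output = model.predict(Tensor(data["image"]))
predicted = np.argmax(output.asnumpy(), axis=1)
print("Predicted:", predicted[:10])
print("Actual:   ", labels[:10])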