
#AI Challenge Camp First Stop# MNIST handwriting recognition and model conversion based on PyTorch

Handwriting recognition is the technology by which computers receive and interpret human handwriting from paper, photos, touch screens, or other devices. It has a wide range of everyday applications, such as document processing, input on mobile devices, personalized signatures, education, and assistive technology.

This tutorial explains the accompanying PyTorch code (train.py and save2onnx.py, attached at the end of the post) in detail and walks you through training a fully connected neural network to recognize the handwritten digits in the MNIST dataset.

1. Dataset download and preprocessing

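Download the MNIST handwritten-digit dataset and store it in the ./dataset folder.

The dataset needs to be divided into a training set and a test set. MNIST already ships this way: 60,000 training images and 10,000 test images, each a 28*28 grayscale digit.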

Why not use all the data for training?

In machine learning, it is common to partition a dataset into training and test sets. It may seem intuitive to use all of the data for training, but partitioning it into separate subsets is critical to developing good models, for several reasons:

1. Evaluate model generalization ability:

The main purpose of the test set is to evaluate the generalization ability of the trained model, that is, how well it performs on new data. This is what matters in practical applications.

If you train on the entire dataset, including the data used for evaluation, a model that has simply memorized the training examples will still score well, so you cannot tell whether it generalizes to new data. Memorizing the training data instead of learning general patterns is called overfitting.

By using a separate test set, you can evaluate how well your model performs on data it has never seen, providing a more realistic estimate of performance.

2. Prevent overfitting:

Overfitting occurs when a model fits its particular training data too closely, capturing noise and irrelevant patterns rather than learning the underlying relationships in the data.

Holding data out of training helps detect and prevent overfitting. In practice, a validation set is usually carved out of the training data and used during training to monitor the model's performance on data it is not trained on.

When a model starts to overfit the training data, its performance on the validation set will usually start to degrade. This can serve as an early warning sign to stop training and prevent further overfitting.

3. Improve model selection efficiency:

During training, you might try different hyperparameters, model architectures, or training techniques.

Evaluating every candidate on the same held-out data (typically the validation set) lets you compare models or training configurations objectively and pick the one best suited to the task. The test set is then reserved for a final, unbiased estimate of the chosen model.

4. Retain data for future evaluation:

In some cases, you may want to retain some of your data for future evaluation, such as comparing different models or techniques over time.

By keeping an isolated test set, you ensure that even if your overall dataset grows or changes over time, you always have consistent and unbiased data for future comparisons.
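Point 2 above describes using the validation set as an early warning sign to stop training. This is commonly implemented as early stopping. The following is a minimal sketch, assuming the validation loss is checked once per epoch and training stops after "patience" checks without improvement; the class name EarlyStopping, its parameters, and the loss values in the usage lines are illustrative, not part of the original code.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # checks to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # new best: reset the counter
            self.bad_checks = 0
        else:
            self.bad_checks += 1       # no improvement this time
        return self.bad_checks >= self.patience


# Usage with a made-up sequence of validation losses
stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([0.50, 0.40, 0.35, 0.36, 0.37, 0.30]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch + 1}")
        break
```

In practice you would also save the model weights whenever best_loss improves, so the checkpoint from before the validation loss started rising can be restored.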


import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Preprocessing: convert each image to a tensor, then normalize it
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor of shape (1, 28, 28) with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # normalize with the MNIST mean and standard deviation
])

# Training + validation data (60,000 images), downloaded into ./dataset
training_set_full = datasets.MNIST('dataset/', train=True, transform=transform, download=True)
# Test set (10,000 images)
test_set = datasets.MNIST('dataset/', train=False, transform=transform, download=True)

# Split the training data: hold out 10% as a validation set
val_size = int(0.1 * len(training_set_full))
train_size = len(training_set_full) - val_size
training_set, validation_set = random_split(training_set_full, [train_size, val_size])

# Data loaders that yield batches of 64 images
training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)
test_loader = DataLoader(test_set, batch_size=64)
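Normalize(mean, std) maps each pixel value x (already scaled to [0, 1] by ToTensor) to (x - mean) / std. A small sketch of that arithmetic in plain NumPy, assuming the commonly used MNIST training-set mean 0.1307 and standard deviation 0.3081; it mirrors the formula rather than calling the torchvision transform.

```python
import numpy as np

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

def normalize(pixels):
    """Apply the same per-pixel formula as transforms.Normalize: (x - mean) / std."""
    return (np.asarray(pixels, dtype=np.float64) - MNIST_MEAN) / MNIST_STD

# Black (0.0), mid-gray (0.5) and white (1.0) pixels after ToTensor scaling
print(normalize([0.0, 0.5, 1.0]).round(4))
```

Black pixels end up at about -0.42 and white pixels at about 2.82, so the inputs are roughly centered on 0 with unit spread, which usually makes training more stable.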

Visualization of the dataset

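import numpy as np
import matplotlib.pyplot as plt

# Pick a random training image and show it with its label as the title
SAMPLE_IMG_ID = np.random.choice(len(training_set))
image, label = training_set[SAMPLE_IMG_ID]
plt.imshow(image.squeeze(0), cmap='gray')  # "squeeze" removes the channel dimension: (1, 28, 28) => (28, 28)
plt.title(f'Label: {label}')
plt.show()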

2. Model definition

Define the model as a class: the layers are created as attributes in __init__, and forward() passes the data through them in order. Passing data from one layer to the next is really just nesting one function inside another.

The entire model is therefore a deeply nested function, but one with adjustable parameters (weights) that determine the network's output. Training adjusts these parameters so that the network's output closely approximates the true labels. This case is a ten-class problem: the input is the 28*28 pixels of an image, and the output is 10 scores, one per digit. Applying softmax to these scores turns them into the probabilities of each label, which sum to 1.
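The post's trained model is a fully connected network (it uses no convolutional layers), while the class shown next is a convolutional network. A minimal sketch of a fully connected model that matches the description above, assuming one hidden layer of 128 units with ReLU; the class name MLP and the layer sizes are illustrative, not taken from the author's train.py. It also shows how torch.softmax turns the 10 raw scores into probabilities that sum to 1.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Hypothetical fully connected classifier: 784 pixels -> 128 hidden -> 10 classes."""

    def __init__(self, hidden=128):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.fc2 = nn.Linear(hidden, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)       # flatten (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.fc1(x))     # layer 1, then a nonlinearity
        return self.fc2(x)              # layer 2: 10 raw scores (logits)

model = MLP()
images = torch.randn(4, 1, 28, 28)      # a fake batch of 4 images
logits = model(images)                  # shape (4, 10)
probs = torch.softmax(logits, dim=1)    # each row now sums to 1
predicted_labels = probs.argmax(dim=1)  # most probable digit per image
print(logits.shape, probs.sum(dim=1))
```

Because nn.CrossEntropyLoss applies log-softmax internally, the model should return the raw logits during training; softmax is only needed when you want readable probabilities.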

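import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Block 1: (N, 1, 28, 28) -> conv -> (N, 32, 28, 28) -> pool -> (N, 32, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Block 2: (N, 32, 14, 14) -> conv -> (N, 64, 14, 14) -> pool -> (N, 64, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Two max-pooling layers shrink 28x28 to 7x7, so the flattened size is 7*7*64
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)  # 10 output scores, one per digit

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 7*7*64)
        x = self.fc1(x)
        x = self.fc2(x)
        return x  # raw scores (logits); nn.CrossEntropyLoss applies softmax internally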
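Why the first fully connected layer of the CNN takes 7 * 7 * 64 inputs: each Conv2d or MaxPool2d layer changes the spatial size by out = floor((in + 2 * padding - kernel_size) / stride) + 1. A small pure-Python sketch of that formula (assuming dilation 1, as in the CNN above), traced through the two convolution blocks:

```python
def out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of Conv2d / MaxPool2d (dilation=1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 28                                         # MNIST images are 28 x 28
size = out_size(size, kernel_size=5, padding=2)   # conv1: 28 -> 28
size = out_size(size, kernel_size=2, stride=2)    # pool1: 28 -> 14
size = out_size(size, kernel_size=5, padding=2)   # conv2: 14 -> 14
size = out_size(size, kernel_size=2, stride=2)    # pool2: 14 -> 7
print(size, size * size * 64)                     # 7 3136 -> in_features of fc1
```

With kernel_size=5 and padding=2 the convolutions preserve the size, so only the two pooling layers shrink it, halving 28 to 14 and then to 7.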

3. Model Training

Define some hyperparameters, such as EPOCHS, the number of training epochs, and EVALUATION_FREQ, the number of training batches between evaluations on the validation set.

The outer loop iterates over the specified number of training epochs (EPOCHS). In each epoch, the inner loop processes the training set batch by batch.
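The loop below uses names that are defined elsewhere in the author's train.py and are not shown in the post (device, model, loss_function, optimizer, evaluate, EPOCHS, EVALUATION_FREQ). A minimal sketch of that setup, assuming cross-entropy loss, the Adam optimizer with learning rate 1e-3, and an evaluate helper that returns (predicted labels, batch accuracy, loss tensor). All of these values and the helper's body are assumptions, and the one-layer placeholder model should be replaced by your own network.

```python
import torch
import torch.nn as nn

EPOCHS = 5             # assumed number of training epochs
EVALUATION_FREQ = 100  # assumed: validate every 100 training batches

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model: swap in your own network
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def evaluate(model, loss_function, data, labels):
    """Hypothetical helper: forward pass, then return (predictions, accuracy, loss).

    The loss keeps its autograd graph so the caller can call loss.backward().
    """
    logits = model(data)
    loss = loss_function(logits, labels)
    predictions = logits.argmax(dim=1)
    accuracy = (predictions == labels).float().mean().item()
    return predictions, accuracy, loss
```

Returning the loss as a tensor (not a float) is what lets the training loop call loss.backward() on it, while the loop itself calls loss.item() when it only needs the number.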

import numpy as np
import torch

# History lists for the averaged metrics
training_acc_lst, training_loss_lst = [], []
validation_acc_lst, validation_loss_lst = [], []

# Loop over the training epochs
for epoch in range(EPOCHS):
  print(f'Epoch {epoch + 1}')
  epoch_acc = []  # accuracy of every batch in this epoch
  training_acc_checkpoint, training_loss_checkpoint = [], []  # per-batch accuracy and loss since the last evaluation, used to compute averages

  # Iterate over every batch in the training set
  for batch_idx, (data, labels) in enumerate(training_loader):
    # Move data and labels to the selected device (CPU or GPU)
    data, labels = data.to(device), labels.to(device)

    # Forward pass: get predictions, accuracy and loss for this batch
    predictions, acc, loss = evaluate(model, loss_function, data, labels)
    training_acc_checkpoint.append(acc)
    epoch_acc.append(acc)
    training_loss_checkpoint.append(loss.item())

    # Backpropagation: compute the gradients
    loss.backward()

    # Update the model parameters
    optimizer.step()

    # Reset the gradients so they do not accumulate into the next batch
    optimizer.zero_grad()  # or model.zero_grad() (if all model parameters are in the optimizer)

    # Periodically evaluate on the validation set
    if batch_idx % EVALUATION_FREQ == 0:
      # Compute and store the average training accuracy and loss
      training_acc_lst.append(np.mean(training_acc_checkpoint))
      training_loss_lst.append(np.mean(training_loss_checkpoint))
      # Clear the temporary per-batch lists
      training_acc_checkpoint, training_loss_checkpoint = [], []

      # Evaluate on the validation set (switch to evaluation mode and disable gradient tracking)
      model.train(mode=False)  # evaluation mode, same as model.eval() (see: https://stackoverflow.com/a/55627781/900394)
      with torch.no_grad():  # temporarily disable gradient tracking
        validation_acc_checkpoint, validation_loss_checkpoint = [], []
        validation_predictions = []  # predictions kept for displaying results later
        for val_batch_idx, (val_data, val_labels) in enumerate(validation_loader):
          val_data, val_labels = val_data.to(device), val_labels.to(device)

          # Evaluate a single validation batch
          val_predictions, validation_acc, validation_loss = evaluate(model, loss_function, val_data, val_labels)

          validation_loss_checkpoint.append(validation_loss.item())
          validation_acc_checkpoint.append(validation_acc)
          validation_predictions.extend(val_predictions)  # extend (not append) so the list holds individual predictions, not one tensor per batch

        # Compute and store the average validation accuracy and loss
        validation_acc_lst.append(np.mean(validation_acc_checkpoint))
        validation_loss_lst.append(np.mean(validation_loss_checkpoint))

      # Switch back to training mode
      model.train(mode=True)

      # Print the latest training and validation results
      print(f'Training acc: {training_acc_lst[-1]:.2f}, training loss: {training_loss_lst[-1]:.2f}, validation acc: {validation_acc_lst[-1]:.2f}, validation loss: {validation_loss_lst[-1]:.2f}')

4. Model Saving

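Model saving stores the trained model in a file so that it can be loaded and used later. Saving model.state_dict(), as done below, writes only the learned parameters, not the model's code structure, so the model class must be defined again when loading. Saving is essential for the following purposes: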

  • Model deployment: deploy the model to a production environment for prediction or inference.
  • Model sharing: share your models with others so they can use or tweak them.
  • Model reproduction: recreate a trained model at a later time.

# Save the model parameters (state_dict) to a file
torch.save(model.state_dict(), 'my_model.pth')
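Because a state_dict holds only parameter tensors, loading it means rebuilding the same architecture first and then copying the weights in. A minimal sketch of the save/load round trip, assuming a small placeholder model (build_model is a hypothetical helper) and a temporary file instead of the post's my_model.pth:

```python
import os
import tempfile

import torch
import torch.nn as nn

def build_model():
    """Placeholder architecture; must match the one whose weights were saved."""
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

model = build_model()
path = os.path.join(tempfile.mkdtemp(), "my_model.pth")
torch.save(model.state_dict(), path)              # save weights only

restored = build_model()                          # same architecture, fresh weights
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()                                   # switch to inference mode before predicting

x = torch.randn(2, 1, 28, 28)
with torch.no_grad():
    print(torch.equal(model(x), restored(x)))     # same weights give the same outputs
```

map_location="cpu" lets a model trained on a GPU be loaded on a machine without one.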

5. Model Conversion

What is ONNX?

ONNX (Open Neural Network Exchange) is an open format for representing deep learning models. It allows models to be stored in files, shared, and converted between different frameworks (such as PyTorch, TensorFlow, MXNet, etc.).

The goal of ONNX is to simplify the deployment and sharing of deep learning models. It enables researchers and developers to focus on building models without worrying about how to deploy them to a specific platform or framework.

Benefits of ONNX

There are many advantages to using ONNX, including:

  • Portability: models can be easily transferred between different frameworks without retraining or code modifications.
  • Reusability: models can be stored and shared for use by others or fine-tuned in other projects.
  • Interoperability: different frameworks can work together and use each other's models.
  • Simplified deployment: models can be easily deployed to production environments without having to worry about the underlying framework.

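In short, ONNX is an open format: models from many frameworks can be converted to ONNX, and ONNX models can in turn be loaded by other frameworks and runtimes, which makes the model portable.

PyTorch provides the torch.onnx.export function for this conversion. It feeds a dummy input through the model to record the computation graph.

The dummy input's shape must match what the model's forward() expects. My custom model takes the 28*28 = 784 pixels as one flattened vector, hence the shape (1, 1, 784); the CNN class defined above would instead need torch.randn(1, 1, 28, 28).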

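import torch

# Export to an ONNX model
model.eval()  # put layers such as dropout/batch norm into inference mode before exporting
dummy_input = torch.randn(1, 1, 784)  # batch of 1, 1 channel, 28*28 = 784 flattened pixels

torch.onnx.export(model, dummy_input, "my_model.onnx", verbose=False)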

Use Netron to view the ONNX network structure

Netron is a viewer for deep learning model files. Among others, it supports model files in the following formats:

  • ONNX (.onnx, .pb)
  • Keras (.h5, .keras)
  • CoreML (.mlmodel)
  • TensorFlow Lite (.tflite)

Netron cannot display the network structure of a file saved with torch.save(model.state_dict(), ...), because such a file holds only the weights, not the computation graph. Therefore, to inspect a PyTorch model in Netron, export it to ONNX format first; the torch.onnx module can be used for this.

The visualization is very useful for analyzing the network structure and the parameters of each layer.

Netron is available both as a web version and as a desktop application. In either one, you can drag and drop the model file to load and visualize it.

6. Model training and verification results

Since the trained model is a fully connected network that uses no convolutional or recurrent layers, its accuracy topped out at about 97%. Networks with convolutional layers have stronger fitting capability and can reach about 99% accuracy.

Attachments:

my_model.onnx (155.83 KB): ONNX model file

my_model.pth (157.34 KB): PyTorch model file

save2onnx.py (3.05 KB): model conversion script

train.py (9.23 KB): model and training script

This post is from Embedded System

Latest reply (published 2024-4-13 08:59):

You are awesome sir, it seems like the development board belongs to you.
Copyright © 2005-2024 EEWORLD.com.cn, Inc. All rights reserved