#AI Challenge Camp First Stop# pytorch training MNIST dataset to achieve handwritten digit recognition

Inviting @chenzhufly @skywalker_lee @wsdymg @bigbat to reply

This post was last edited by LitchiCheng on 2024-4-18 22:28

Download the MNIST dataset

# MNIST training set; fetch batches of 60 images
        self._train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST('./data/', train=True, download=True,
                                    transform=torchvision.transforms.Compose([
                                        torchvision.transforms.ToTensor(),
                                        torchvision.transforms.Normalize(
                                            (0.1307,), (0.3081,))
                                    ])),
            batch_size=60, shuffle=True)
        # Test set; fetch batches of 500 images
        self._test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST('./data/', train=False, download=True,
                                    transform=torchvision.transforms.Compose([
                                        torchvision.transforms.ToTensor(),
                                        torchvision.transforms.Normalize(
                                            (0.1307,), (0.3081,))
                                    ])),
            batch_size=500, shuffle=True)
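
The two constants 0.1307 and 0.3081 passed to Normalize are the mean and standard deviation of the raw MNIST training pixels. A minimal throwaway check, assuming the dataset has been downloaded to ./data/ as above, confirms them:

import torch
import torchvision

# Load the raw training images (ToTensor only, no Normalize) and compute the stats
mnist = torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=torchvision.transforms.ToTensor())
imgs = torch.stack([img for img, _ in mnist])   # shape (60000, 1, 28, 28)
print(imgs.mean().item(), imgs.std().item())    # ~0.1307, ~0.3081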

Defining the network

# Sequential container: chains the layers in order
        self._conv1_layer = nn.Sequential(
            # Convolution: 1 input channel, 15 output channels, 5x5 kernel
            nn.Conv2d(1, 15, 5),
            # Activation function
            nn.ReLU(),
            # Max pooling keeps the largest value in each 2x2 window;
            # a form of downsampling that shrinks the feature maps
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        self._conv2_layer = nn.Sequential(
            nn.Conv2d(15, 30, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self._full_layer = nn.Sequential(
            # Conv layers produce 4-D tensors; flatten to 2-D for the linear layers.
            # 28x28 -> conv 5x5 -> 24x24 -> pool -> 12x12 -> conv 5x5 -> 8x8 -> pool -> 4x4,
            # so the flattened size is 30 channels * 4 * 4 = 480
            nn.Flatten(),
            nn.Linear(in_features=480, out_features=60),
            nn.ReLU(),
            nn.Linear(in_features=60, out_features=10),
        )
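
Since in_features=480 is easy to get wrong, a minimal shape check with a dummy input, reusing the same layer definitions as above, verifies where the number comes from:

import torch
import torch.nn as nn

conv1 = nn.Sequential(nn.Conv2d(1, 15, 5), nn.ReLU(), nn.MaxPool2d(2, 2))
conv2 = nn.Sequential(nn.Conv2d(15, 30, 5), nn.ReLU(), nn.MaxPool2d(2, 2))

x = torch.randn(1, 1, 28, 28)        # one dummy MNIST-sized image
print(conv1(x).shape)                # torch.Size([1, 15, 12, 12])
print(conv2(conv1(x)).shape)         # torch.Size([1, 30, 4, 4]) -> 30*4*4 = 480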

Determine whether GPU training is possible

        if torch.cuda.is_available():
            print("Use CUDA training!")
            self._device = torch.device("cuda")
        else:
            print("Use CPU training!")
            self._device = torch.device("cpu")
        # Move the model onto the chosen device as well; without this,
        # CUDA training fails with a device-mismatch error
        self._cnn.to(self._device)

Training

def train(self):
        loss_d = []
        for epoch in range(1, self._epochs + 1):
            self._cnn.train(mode=True)
            for idx, (train_img, train_label) in enumerate(self._train_loader):
                # Copy the batch to the training device
                train_img = train_img.to(self._device)
                train_label = train_label.to(self._device)
                outputs = self._cnn(train_img)
                # Clear accumulated gradients
                self._optim.zero_grad()
                loss = self._loss_func(outputs, train_label)
                # Backpropagation
                loss.backward()
                # Update the weights
                self._optim.step()
                # print('Train epoch {}: loss: {:.6f}'.format(epoch, loss.item()))
                loss_d.append(loss.item())
        plt.plot(range(0, len(loss_d)), loss_d)
        plt.show()

Training loss distribution [figure]

Test loss and accuracy [figure]

Recognition results [figure]

Saving the .pth and ONNX models

    def savePthModel(self, pth_name:str):
        torch.save(self._cnn.state_dict(), pth_name)

    def saveOnnxModel(self, onnx_name:str):
        # The dummy input must live on the same device as the model
        dummy_input = torch.randn(1, 1, 28, 28).to(self._device)
        torch.onnx.export(self._cnn, dummy_input, onnx_name, verbose=True)
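
To sanity-check the exported ONNX file outside PyTorch, a minimal sketch with onnxruntime (assuming the onnxruntime package is installed) runs one dummy inference:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name            # query the exported input name
dummy = np.random.randn(1, 1, 28, 28).astype(np.float32)
logits = sess.run(None, {input_name: dummy})[0]   # shape (1, 10)
print("predicted digit:", logits.argmax())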

Complete code

import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms
import matplotlib.pyplot as plt

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Sequential container: chains the layers in order
        self._conv1_layer = nn.Sequential(
            # Convolution: 1 input channel, 15 output channels, 5x5 kernel
            nn.Conv2d(1, 15, 5),
            # Activation function
            nn.ReLU(),
            # Max pooling keeps the largest value in each 2x2 window;
            # a form of downsampling that shrinks the feature maps
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        self._conv2_layer = nn.Sequential(
            nn.Conv2d(15, 30, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self._full_layer = nn.Sequential(
            # Conv layers produce 4-D tensors; flatten to 2-D for the linear layers.
            # 28x28 -> conv 5x5 -> 24x24 -> pool -> 12x12 -> conv 5x5 -> 8x8 -> pool -> 4x4,
            # so the flattened size is 30 channels * 4 * 4 = 480
            nn.Flatten(),
            nn.Linear(in_features=480, out_features=60),
            nn.ReLU(),
            nn.Linear(in_features=60, out_features=10),
        )
    
    def forward(self, x):
        # Chain the blocks: two conv blocks, then the fully connected block
        output = self._conv1_layer(x)
        output = self._conv2_layer(output)
        output = self._full_layer(output)
        return output

class Test:
    def __init__(self):
        # MNIST training set; fetch batches of 60 images
        self._train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST('./data/', train=True, download=True,
                                    transform=torchvision.transforms.Compose([
                                        torchvision.transforms.ToTensor(),
                                        torchvision.transforms.Normalize(
                                            (0.1307,), (0.3081,))
                                    ])),
            batch_size=60, shuffle=True)
        # Test set; fetch batches of 500 images
        self._test_loader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST('./data/', train=False, download=True,
                                    transform=torchvision.transforms.Compose([
                                        torchvision.transforms.ToTensor(),
                                        torchvision.transforms.Normalize(
                                            (0.1307,), (0.3081,))
                                    ])),
            batch_size=500, shuffle=True)
        # Number of training epochs
        self._epochs = 3
        self._cnn = CNN()
        # Cross-entropy loss measures the distance between two probability
        # distributions; the smaller it is, the closer the distributions
        self._loss_func = nn.CrossEntropyLoss()
        # Optimizer
        self._optim = torch.optim.Adam(self._cnn.parameters(), lr=0.01)
        if torch.cuda.is_available():
            print("Use CUDA training!")
            self._device = torch.device("cuda")
        else:
            print("Use CPU training!")
            self._device = torch.device("cpu")
        # Move the model onto the chosen device as well; without this,
        # CUDA training fails with a device-mismatch error
        self._cnn.to(self._device)
        
    def train(self):
        loss_d = []
        for epoch in range(1, self._epochs + 1):
            self._cnn.train(mode=True)
            for idx, (train_img, train_label) in enumerate(self._train_loader):
                # Copy the batch to the training device
                train_img = train_img.to(self._device)
                train_label = train_label.to(self._device)
                outputs = self._cnn(train_img)
                # Clear accumulated gradients
                self._optim.zero_grad()
                loss = self._loss_func(outputs, train_label)
                # Backpropagation
                loss.backward()
                # Update the weights
                self._optim.step()
                # print('Train epoch {}: loss: {:.6f}'.format(epoch, loss.item()))
                loss_d.append(loss.item())
        plt.plot(range(0, len(loss_d)), loss_d)
        plt.show()

    def test(self):
        correct_num = 0
        total_num = 0
        loss_d = []
        # Switch to evaluation mode for testing
        self._cnn.eval()

        with torch.no_grad():
            for idx, (test_img, test_label) in enumerate(self._test_loader):
                test_img = test_img.to(self._device)
                test_label = test_label.to(self._device)

                total_num += test_label.size(0)

                outputs = self._cnn(test_img)
                loss = self._loss_func(outputs, test_label)
                loss_d.append(loss.item())

                predictions = torch.argmax(outputs, dim=1)
                correct_num += torch.sum(predictions == test_label)
        acc_num = (correct_num.item() / total_num) * 100
        plt.title("Accuracy: {:.2f}%".format(acc_num))
        plt.plot(range(0, len(loss_d)), loss_d)
        plt.show()
            
    def plotTestResult(self):
        iteration = enumerate(self._test_loader)
        idx, (test_img, test_label) = next(iteration)
        test_img = test_img.to(self._device)

        with torch.no_grad():
            outputs = self._cnn(test_img)
            predictions = torch.argmax(outputs, dim=1)

            fig = plt.figure()
            # Show the first 8 test images with their labels and predictions
            for i in range(4 * 2):
                plt.subplot(4, 2, i + 1)
                plt.tight_layout()
                plt.imshow(test_img[i][0].cpu(), cmap='gray', interpolation='none')
                plt.title('real: {}, predict: {}'.format(
                    test_label[i].item(), predictions[i].item()
                ))
                plt.xticks([])
                plt.yticks([])
            plt.show()

    def savePthModel(self, pth_name:str):
        torch.save(self._cnn.state_dict(), pth_name)

    def saveOnnxModel(self, onnx_name:str):
        # The dummy input must live on the same device as the model
        dummy_input = torch.randn(1, 1, 28, 28).to(self._device)
        torch.onnx.export(self._cnn, dummy_input, onnx_name, verbose=True)

    
if __name__ == "__main__":
    mt = Test()
    mt.train()
    mt.test()
    mt.plotTestResult()
    mt.savePthModel("model.pth")
    mt.saveOnnxModel("model.onnx")
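
To reuse the trained weights later without retraining, the saved state dict can be loaded back into a fresh CNN instance. A minimal sketch, assuming the CNN class above is in scope and model.pth exists:

# Rebuild the network and load the trained weights
cnn = CNN()
cnn.load_state_dict(torch.load("model.pth", map_location="cpu"))
cnn.eval()

# Classify the first image of the test set
test_set = torchvision.datasets.MNIST(
    './data/', train=False, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))]))
img, label = test_set[0]
with torch.no_grad():
    pred = torch.argmax(cnn(img.unsqueeze(0)), dim=1).item()
print("label:", label, "prediction:", pred)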

Video explanation [embedded video]



Reply 2 (bigbat, 2024-4-19 11:21):

I have no interest in Python, so I won't work on PyTorch. But I hope you will actively take part in my post on what AI is really like.


Reply 3 (LitchiCheng, OP, 2024-4-19 15:09):

bigbat posted on 2024-4-19 11:21: I have no interest in python, so I won't do pytorch, but I hope my post on AI can help you...

666, make a "hair"


Reply 4 (Qintianqintian0303, 2024-4-19 23:33):

If this were deployed on an MCU, how would you generate the C file?


Reply 5 (LitchiCheng, OP, 2024-4-21 10:29):

Qintianqintian0303 posted on 2024-4-19 23:33: If this were deployed on an MCU, how would you generate the C file?

For C there are other inference frameworks; PyTorch itself is definitely not available there.


Reply 6 (chejm, 2024-4-21 21:25):

Thank you for sharing this technical content; it is very detailed, of great practical value, and well worth learning from.


Reply 7 (LitchiCheng, OP, 2024-4-21 21:54):

chejm posted on 2024-4-21 21:25: Thank you for sharing this technical content, which is very detailed and of great practical value. It is worth learning from.

Thank you for your support, let’s make progress together


Reply 8 (crimsonsnow, 2024-4-25 11:12):

The best contribution in my mind goes to the OP... very helpful.


Reply 9 (LitchiCheng, OP, 2024-4-29 22:10):

crimsonsnow posted on 2024-4-25 11:12: The best contribution in my mind goes to the OP... very helpful.

Hahahaha, thanks


Reply 10 (Tongtu Technology, 2024-10-29 21:11):

Study hard, make progress every day, come on everyone, come on yourself, come on!!!


Reply 11 (LitchiCheng, OP, 2024-10-29 21:27):

Tongtu Technology posted on 2024-10-29 21:11: Study hard, make progress every day, come on everyone, come on yourself, come on!!!

come on
