Navigation

    Gpushare.com

    • Register
    • Login
    • Search
    • Popular
    • Categories
    • Recent
    • Tags

    CIFAR-10数据集实战——构建LeNet5神经网络

    语音识别与语义处理领域
    1
    1
    39
    Loading More Posts
    • Oldest to Newest
    • Newest to Oldest
    • Most Votes
    Reply
    • Reply as topic
    Log in to reply
    This topic has been deleted. Only users with topic management privileges can see it.
    • 155****7220
      155****7220 last edited by

      CIFAR-10数据集网站

      如果从官网下载数据集很慢,可以使用国内的地址http://ai-atest.bj.bcebos.com/cifar-10-python.tar.gz

      MNIST数据集为0~9的手写数字,而CIFAR-10数据集用于10类物体的识别,包含飞机、汽车、鸟、猫等类别。每张照片都是32×32的彩色图片(三通道)。每个类别共有6000张照片,其中5000张用于训练(training),另外随机选出的1000张用于测试(testing)

      首先引入数据集

      import torch
      from torch.utils.data import DataLoader
      from torchvision import datasets, transforms
      
      # Number of images per mini-batch for both the training and the test loader.
      batch_size = 32

      # Shared preprocessing: force a 32x32 spatial size, then convert the PIL
      # image to a float tensor. The same pipeline is applied to both splits.
      cifar_transform = transforms.Compose([
          transforms.Resize([32, 32]),
          transforms.ToTensor(),
      ])

      # Training split; downloaded into ./cifar on first use.
      cifar_train = datasets.CIFAR10(root='cifar', train=True,
                                     transform=cifar_transform, download=True)
      # The name is rebound from the dataset to its DataLoader on purpose:
      # the training code iterates over `cifar_train` directly.
      cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

      # Test split, prepared exactly like the training split.
      cifar_test = datasets.CIFAR10(root='cifar', train=False,
                                    transform=cifar_transform, download=True)
      cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)

      # Sanity check: fetch one batch and print its shapes.
      # Use the built-in next(); iterator objects have no `.next()` method in
      # Python 3, and current DataLoader iterators do not provide one either.
      x, label = next(iter(cifar_train))
      print('x:', x.shape, 'label:', label.shape)
      

      引入数据集以后,接下来开始编写经典的LeNet5神经网络

      import torch
      from torch import nn, optim
      import torch.nn.functional as F
      
      class LeNet5(nn.Module):
          """LeNet-5 style convolutional network for CIFAR-10 sized inputs.

          The network expects a batch of 3-channel 32x32 images and returns one
          row of 10 unnormalised class scores (logits) per image. No softmax is
          applied here.

          Attributes:
              conv_unit: two Conv2d + AvgPool2d stages mapping
                  [b, 3, 32, 32] to [b, 16, 5, 5].
              fc_unit: three Linear layers (with ReLU in between) mapping
                  [b, 16*5*5] to [b, 10].
          """

          def __init__(self):
              super().__init__()
              self.conv_unit = nn.Sequential(
                  # x: [batchsize, 3, 32, 32] => [batchsize, 6, 28, 28]
                  nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=0),
                  # [batchsize, 6, 28, 28] => [batchsize, 6, 14, 14]
                  nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                  # [batchsize, 6, 14, 14] => [batchsize, 16, 10, 10]
                  nn.Conv2d(6, 16, 5, 1, 0),
                  # [batchsize, 16, 10, 10] => [batchsize, 16, 5, 5]
                  nn.AvgPool2d(2, 2, 0),
              )

              # Fully connected classifier head: 400 -> 120 -> 84 -> 10.
              self.fc_unit = nn.Sequential(
                  nn.Linear(in_features=16 * 5 * 5, out_features=120),
                  nn.ReLU(),
                  nn.Linear(120, 84),
                  nn.ReLU(),
                  nn.Linear(84, 10),
              )

          def forward(self, x):
              """Compute class logits for a batch of images.

              Args:
                  x: tensor of shape [b, 3, 32, 32].

              Returns:
                  Tensor of shape [b, 10] holding the raw class scores.
              """
              # [b, 3, 32, 32] => [b, 16, 5, 5]
              x = self.conv_unit(x)

              # [b, 16, 5, 5] => [b, 16*5*5]
              # flatten() falls back to a copy when the activation is not
              # view-compatible (e.g. channels_last input), where
              # view(b, -1) would raise a RuntimeError.
              x = x.flatten(start_dim=1)

              # [b, 16*5*5] => [b, 10]
              logits = self.fc_unit(x)

              return logits
              
      def main(epochs=1000, device=None):
          """Train LeNet5 on CIFAR-10 and print test accuracy after every epoch.

          Reads the module-level ``cifar_train`` and ``cifar_test`` loaders.
          After each epoch it prints the loss of the last training mini-batch
          and the accuracy over the whole test loader.

          Args:
              epochs: number of passes over the training loader (default 1000,
                  the value that used to be hard-coded).
              device: torch device (or device string) to run on. When None,
                  CUDA is used if available, otherwise the CPU.
          """
          if device is None:
              device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

          model = LeNet5().to(device)
          criteon = nn.CrossEntropyLoss()
          optimizer = optim.Adam(model.parameters(), 1e-3)

          for epoch in range(epochs):
              ##########  train  ##########
              model.train()
              for batchidx, (x, label) in enumerate(cifar_train):
                  x, label = x.to(device), label.to(device)
                  # logits: [b, 10]
                  # label:  [b]
                  logits = model(x)
                  loss = criteon(logits, label)

                  # backward
                  optimizer.zero_grad()
                  loss.backward()
                  optimizer.step()

              # Loss of the last mini-batch only, not an epoch average.
              print('train:', epoch, loss.item())

              ##########  test  ##########
              model.eval()
              # Gradients are not needed for evaluation.
              with torch.no_grad():
                  total_correct = 0
                  total_num = 0
                  for x, label in cifar_test:
                      x, label = x.to(device), label.to(device)

                      # [b, 10]
                      logits = model(x)
                      # [b]: index of the highest score per sample
                      pred = logits.argmax(dim=1)
                      # [b] vs [b]
                      total_correct += torch.eq(pred, label).float().sum().item()
                      total_num += x.size(0)
                  acc = total_correct / total_num
                  print('test:', epoch, acc)

      if __name__ == '__main__':
          main()
      

      从运行结果来看,测试准确率在缓慢上升,但并不稳定。有兴趣的读者可以尝试自己修改网络结构,使准确率更高

      1 Reply Last reply Reply Quote 0
      • First post
        Last post