pytorch中cnn如何識(shí)別手寫的字并實(shí)現(xiàn)自建圖片數(shù)據(jù)-創(chuàng)新互聯(lián)

小編給大家分享一下pytorch中cnn如何識(shí)別手寫的字并實(shí)現(xiàn)自建圖片數(shù)據(jù),相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!

創(chuàng)新互聯(lián)建站堅(jiān)持“要么做到,要么別承諾”的工作理念,服務(wù)領(lǐng)域包括:成都網(wǎng)站制作、成都網(wǎng)站建設(shè)、外貿(mào)營(yíng)銷網(wǎng)站建設(shè)、企業(yè)官網(wǎng)、英文網(wǎng)站、手機(jī)端網(wǎng)站、網(wǎng)站推廣等服務(wù),滿足客戶于互聯(lián)網(wǎng)時(shí)代的上猶網(wǎng)站設(shè)計(jì)、移動(dòng)媒體設(shè)計(jì)的需求,幫助企業(yè)找到有效的互聯(lián)網(wǎng)解決方案。努力成為您成熟可靠的網(wǎng)絡(luò)建設(shè)合作伙伴!

具體如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

圖片和label 見上一篇文章《pytorch 把MNIST數(shù)據(jù)集轉(zhuǎn)換成圖片和txt》

結(jié)果如下:

pytorch中cnn如何識(shí)別手寫的字并實(shí)現(xiàn)自建圖片數(shù)據(jù)

以上是“pytorch中cnn如何識(shí)別手寫的字并實(shí)現(xiàn)自建圖片數(shù)據(jù)”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道!

當(dāng)前標(biāo)題:pytorch中cnn如何識(shí)別手寫的字并實(shí)現(xiàn)自建圖片數(shù)據(jù)-創(chuàng)新互聯(lián)
文章分享:http://bm7419.com/article10/ddpcgo.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供網(wǎng)站設(shè)計(jì)、域名注冊(cè)、靜態(tài)網(wǎng)站、App開發(fā)、網(wǎng)站排名網(wǎng)站內(nèi)鏈

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)

小程序開發(fā)