pytorch如何實(shí)現(xiàn)用CNN和LSTM對(duì)文本進(jìn)行分類(lèi)方式-創(chuàng)新互聯(lián)

小編給大家分享一下pytorch如何實(shí)現(xiàn)用CNN和LSTM對(duì)文本進(jìn)行分類(lèi)方式,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!

專(zhuān)業(yè)從事成都網(wǎng)站制作、網(wǎng)站建設(shè),高端網(wǎng)站制作設(shè)計(jì),微信小程序,網(wǎng)站推廣的成都做網(wǎng)站的公司。優(yōu)秀技術(shù)團(tuán)隊(duì)竭力真誠(chéng)服務(wù),采用H5場(chǎng)景定制+CSS3前端渲染技術(shù),成都響應(yīng)式網(wǎng)站建設(shè)公司,讓網(wǎng)站在手機(jī)、平板、PC、微信下都能呈現(xiàn)。建站過(guò)程建立專(zhuān)項(xiàng)小組,與您實(shí)時(shí)在線互動(dòng),隨時(shí)提供解決方案,暢聊想法和感受。

model.py:

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import torch
from torch import nn
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
 
class TextRNN(nn.Module):
  """文本分類(lèi),RNN模型"""
  def __init__(self):
    super(TextRNN, self).__init__()
    # 三個(gè)待輸入的數(shù)據(jù)
    self.embedding = nn.Embedding(5000, 64) # 進(jìn)行詞嵌入
    # self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bidirectional=True)
    self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2, bidirectional=True)
    self.f1 = nn.Sequential(nn.Linear(256,128),
                nn.Dropout(0.8),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128,10),
                nn.Softmax())
 
  def forward(self, x):
    x = self.embedding(x)
    x,_ = self.rnn(x)
    x = F.dropout(x,p=0.8)
    x = self.f1(x[:,-1,:])
    return self.f2(x)
 
class TextCNN(nn.Module):
  def __init__(self):
    super(TextCNN, self).__init__()
    self.embedding = nn.Embedding(5000,64)
    self.conv = nn.Conv1d(64,256,5)
    self.f1 = nn.Sequential(nn.Linear(256*596, 128),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128, 10),
                nn.Softmax())
  def forward(self, x):
    x = self.embedding(x)
    x = x.detach().numpy()
    x = np.transpose(x,[0,2,1])
    x = torch.Tensor(x)
    x = Variable(x)
    x = self.conv(x)
    x = x.view(-1,256*596)
    x = self.f1(x)
    return self.f2(x)

train.py:

# coding: utf-8
 
from __future__ import print_function
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import os
 
import numpy as np
 
from model import TextRNN,TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
 
base_dir = 'cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
 
 
def train():
  x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,600)#獲取訓(xùn)練數(shù)據(jù)每個(gè)字的id和對(duì)應(yīng)標(biāo)簽的oe-hot形式
  x_val, y_val = process_file(val_dir, word_to_id, cat_to_id,600)
  #使用LSTM或者CNN
  model = TextRNN()
  # model = TextCNN()
  #選擇損失函數(shù)
  Loss = nn.MultiLabelSoftMarginLoss()
  # Loss = nn.BCELoss()
  # Loss = nn.MSELoss()
  optimizer = optim.Adam(model.parameters(),lr=0.001)
  best_val_acc = 0
  for epoch in range(1000):
    batch_train = batch_iter(x_train, y_train,100)
    for x_batch, y_batch in batch_train:
      x = np.array(x_batch)
      y = np.array(y_batch)
      x = torch.LongTensor(x)
      y = torch.Tensor(y)
      # y = torch.LongTensor(y)
      x = Variable(x)
      y = Variable(y)
      out = model(x)
      loss = Loss(out,y)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      accracy = np.mean((torch.argmax(out,1)==torch.argmax(y,1)).numpy())
    #對(duì)模型進(jìn)行驗(yàn)證
    if (epoch+1)%20 == 0:
      batch_val = batch_iter(x_val, y_val, 100)
      for x_batch, y_batch in batch_train:
        x = np.array(x_batch)
        y = np.array(y_batch)
        x = torch.LongTensor(x)
        y = torch.Tensor(y)
        # y = torch.LongTensor(y)
        x = Variable(x)
        y = Variable(y)
        out = model(x)
        loss = Loss(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accracy = np.mean((torch.argmax(out, 1) == torch.argmax(y, 1)).numpy())
        if accracy > best_val_acc:
          torch.save(model.state_dict(),'model_params.pkl')
          best_val_acc = accracy
        print(accracy)
 
if __name__ == '__main__':
  #獲取文本的類(lèi)別及其對(duì)應(yīng)id的字典
  categories, cat_to_id = read_category()
  #獲取訓(xùn)練文本中所有出現(xiàn)過(guò)的字及其所對(duì)應(yīng)的id
  words, word_to_id = read_vocab(vocab_dir)
  #獲取字?jǐn)?shù)
  vocab_size = len(words)
  train()

test.py:

# coding: utf-8
 
from __future__ import print_function
 
import os
import tensorflow.contrib.keras as kr
import torch
from torch import nn
from cnews_loader import read_category, read_vocab
from model import TextRNN
from torch.autograd import Variable
import numpy as np
try:
  bool(type(unicode))
except NameError:
  unicode = str
 
base_dir = 'cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
 
class TextCNN(nn.Module):
  def __init__(self):
    super(TextCNN, self).__init__()
    self.embedding = nn.Embedding(5000,64)
    self.conv = nn.Conv1d(64,256,5)
    self.f1 = nn.Sequential(nn.Linear(152576, 128),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128, 10),
                nn.Softmax())
  def forward(self, x):
    x = self.embedding(x)
    x = x.detach().numpy()
    x = np.transpose(x,[0,2,1])
    x = torch.Tensor(x)
    x = Variable(x)
    x = self.conv(x)
    x = x.view(-1,152576)
    x = self.f1(x)
    return self.f2(x)
 
class CnnModel:
  def __init__(self):
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextCNN()
    self.model.load_state_dict(torch.load('model_params.pkl'))
 
  def predict(self, message):
    # 支持不論在python2還是python3下訓(xùn)練的模型都可以在2或者3的環(huán)境下運(yùn)行
    content = unicode(message)
    data = [self.word_to_id[x] for x in content if x in self.word_to_id]
    data = kr.preprocessing.sequence.pad_sequences([data],600)
    data = torch.LongTensor(data)
    y_pred_cls = self.model(data)
    class_index = torch.argmax(y_pred_cls[0]).item()
    return self.categories[class_index]
 
class RnnModel:
  def __init__(self):
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextRNN()
    self.model.load_state_dict(torch.load('model_rnn_params.pkl'))
 
  def predict(self, message):
    # 支持不論在python2還是python3下訓(xùn)練的模型都可以在2或者3的環(huán)境下運(yùn)行
    content = unicode(message)
    data = [self.word_to_id[x] for x in content if x in self.word_to_id]
    data = kr.preprocessing.sequence.pad_sequences([data], 600)
    data = torch.LongTensor(data)
    y_pred_cls = self.model(data)
    class_index = torch.argmax(y_pred_cls[0]).item()
    return self.categories[class_index]
 
 
if __name__ == '__main__':
  model = CnnModel()
  # model = RnnModel()
  test_demo = ['湖人助教力助科比恢復(fù)手感 他也是阿泰的精神導(dǎo)師新浪體育訊記者戴高樂(lè)報(bào)道 上賽季,科比的右手食指遭遇重創(chuàng),他的投籃手感也因此大受影響。不過(guò)很快科比就調(diào)整了自己的投籃手型,并通過(guò)這一方式讓自己的投籃命中率回升。而在這科比背后,有一位特別助教對(duì)科比幫助很大,他就是查克·珀森。珀森上賽季擔(dān)任湖人的特別助教,除了幫助科比調(diào)整投籃手型之外,他的另一個(gè)重要任務(wù)就是擔(dān)任阿泰的精神導(dǎo)師。來(lái)到湖人隊(duì)之后,阿泰收斂起了暴躁的脾氣,成為湖人奪冠路上不可或缺的一員,珀森的“心靈按摩”功不可沒(méi)。經(jīng)歷了上賽季的成功之后,珀森本賽季被“升職”成為湖人隊(duì)的全職助教,每場(chǎng)比賽,他都會(huì)坐在球場(chǎng)邊,幫助禪師杰克遜一起指揮湖人球員在場(chǎng)上拼殺。對(duì)于珀森的工作,禪師非常欣賞,“查克非常善于分析問(wèn)題,”菲爾·杰克遜說(shuō),“他總是在尋找問(wèn)題的答案,同時(shí)也在找造成這一問(wèn)題的原因,這是我們都非常樂(lè)于看到的。我會(huì)在平時(shí)把防守中出現(xiàn)的一些問(wèn)題交給他,然后他會(huì)通過(guò)組織球員練習(xí)找到解決的辦法。他在球員時(shí)代曾是一名很好的外線投手,不過(guò)現(xiàn)在他與內(nèi)線球員的配合也相當(dāng)不錯(cuò)。',
         '弗老大被裁美國(guó)媒體看熱鬧“特權(quán)”在中國(guó)像蠢蛋弗老大要走了。雖然他只在首鋼男籃效力了13天,而且表現(xiàn)毫無(wú)亮點(diǎn),大大地讓球迷和俱樂(lè)部失望了,但就像中國(guó)人常說(shuō)的“好聚好散”,隊(duì)友還是友好地與他告別,俱樂(lè)部與他和平分手,球迷還請(qǐng)他留下了在北京的最后一次簽名。相比之下,弗老大的同胞美國(guó)人卻沒(méi)那么“寬容”。他們嘲諷這位NBA前巨星的英雄遲暮,批評(píng)他在CBA的業(yè)余表現(xiàn),還驚訝于中國(guó)人的“大方”。今天,北京首鋼俱樂(lè)部將與弗朗西斯繼續(xù)商討解約一事。從昨日的進(jìn)展來(lái)看,雙方可以做到“買(mǎi)賣(mài)不成人意在”,但回到美國(guó)后,恐怕等待弗朗西斯的就沒(méi)有這么輕松的環(huán)境了。進(jìn)展@北京昨日與隊(duì)友告別 最后一次為球迷簽名弗朗西斯在13天里為首鋼隊(duì)打了4場(chǎng)比賽,3場(chǎng)的得分為0,只有一場(chǎng)得了2分。昨天是他來(lái)到北京的第14天,雖然他與首鋼還未正式解約,但雙方都明白“緣分已盡”。下午,弗朗西斯來(lái)到首鋼俱樂(lè)部與隊(duì)友們告別。弗朗西斯走到隊(duì)友身邊,依次與他們握手擁抱?!澳銈兌紝?duì)我很好,安排的條件也很好,我很喜歡這支球隊(duì),想融入你們,但我現(xiàn)在真的很不適應(yīng)。希望你們']
  for i in test_demo:
    print(i,":",model.predict(i))

cnews_loader.py:

# coding: utf-8
 
import sys
from collections import Counter
 
import numpy as np
import tensorflow.contrib.keras as kr
 
if sys.version_info[0] > 2:
  is_py3 = True
else:
  reload(sys)
  sys.setdefaultencoding("utf-8")
  is_py3 = False
 
 
def native_word(word, encoding='utf-8'):
  """如果在python2下面使用python3訓(xùn)練的模型,可考慮調(diào)用此函數(shù)轉(zhuǎn)化一下字符編碼"""
  if not is_py3:
    return word.encode(encoding)
  else:
    return word
 
 
def native_content(content):
  if not is_py3:
    return content.decode('utf-8')
  else:
    return content
 
 
def open_file(filename, mode='r'):
  """
  常用文件操作,可在python2和python3間切換.
  mode: 'r' or 'w' for read or write
  """
  if is_py3:
    return open(filename, mode, encoding='utf-8', errors='ignore')
  else:
    return open(filename, mode)
 
 
def read_file(filename):
  """讀取文件數(shù)據(jù)"""
  contents, labels = [], []
  with open_file(filename) as f:
    for line in f:
      try:
        label, content = line.strip().split('\t')
        if content:
          contents.append(list(native_content(content)))
          labels.append(native_content(label))
      except:
        pass
  return contents, labels
 
 
def build_vocab(train_dir, vocab_dir, vocab_size=5000):
  """根據(jù)訓(xùn)練集構(gòu)建詞匯表,存儲(chǔ)"""
  data_train, _ = read_file(train_dir)
 
  all_data = []
  for content in data_train:
    all_data.extend(content)
 
  counter = Counter(all_data)
  count_pairs = counter.most_common(vocab_size - 1)
  words, _ = list(zip(*count_pairs))
  # 添加一個(gè) <PAD> 來(lái)將所有文本pad為同一長(zhǎng)度
  words = ['<PAD>'] + list(words)
  open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')
 
 
def read_vocab(vocab_dir):
  """讀取詞匯表"""
  # words = open_file(vocab_dir).read().strip().split('\n')
  with open_file(vocab_dir) as fp:
    # 如果是py2 則每個(gè)值都轉(zhuǎn)化為unicode
    words = [native_content(_.strip()) for _ in fp.readlines()]
  word_to_id = dict(zip(words, range(len(words))))
  return words, word_to_id
 
 
def read_category():
  """讀取分類(lèi)目錄,固定"""
  categories = ['體育', '財(cái)經(jīng)', '房產(chǎn)', '家居', '教育', '科技', '時(shí)尚', '時(shí)政', '游戲', '娛樂(lè)']
 
  categories = [native_content(x) for x in categories]
 
  cat_to_id = dict(zip(categories, range(len(categories))))
 
  return categories, cat_to_id
 
 
def to_words(content, words):
  """將id表示的內(nèi)容轉(zhuǎn)換為文字"""
  return ''.join(words[x] for x in content)
 
 
def process_file(filename, word_to_id, cat_to_id, max_length=600):
  """將文件轉(zhuǎn)換為id表示"""
  contents, labels = read_file(filename)#讀取訓(xùn)練數(shù)據(jù)的每一句話及其所對(duì)應(yīng)的類(lèi)別
  data_id, label_id = [], []
  for i in range(len(contents)):
    data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])#將每句話id化
    label_id.append(cat_to_id[labels[i]])#每句話對(duì)應(yīng)的類(lèi)別的id
  #
  # # 使用keras提供的pad_sequences來(lái)將文本pad為固定長(zhǎng)度
  x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
  y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 將標(biāo)簽轉(zhuǎn)換為one-hot表示
  #
  return x_pad, y_pad
 
 
def batch_iter(x, y, batch_size=64):
  """生成批次數(shù)據(jù)"""
  data_len = len(x)
  num_batch = int((data_len - 1) / batch_size) + 1
 
  indices = np.random.permutation(np.arange(data_len))
  x_shuffle = x[indices]
  y_shuffle = y[indices]
 
  for i in range(num_batch):
    start_id = i * batch_size
    end_id = min((i + 1) * batch_size, data_len)
    yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

pytorch的優(yōu)點(diǎn)

1.PyTorch是相當(dāng)簡(jiǎn)潔且高效快速的框架;2.設(shè)計(jì)追求最少的封裝;3.設(shè)計(jì)符合人類(lèi)思維,它讓用戶(hù)盡可能地專(zhuān)注于實(shí)現(xiàn)自己的想法;4.與google的Tensorflow類(lèi)似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開(kāi)發(fā)更新;5.PyTorch作者親自維護(hù)的論壇 供用戶(hù)交流和求教問(wèn)題6.入門(mén)簡(jiǎn)單

以上是“pytorch如何實(shí)現(xiàn)用CNN和LSTM對(duì)文本進(jìn)行分類(lèi)方式”這篇文章的所有內(nèi)容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內(nèi)容對(duì)大家有所幫助,如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計(jì)公司行業(yè)資訊頻道!

另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無(wú)理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、高防服務(wù)器、香港服務(wù)器、美國(guó)服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡(jiǎn)單易用、服務(wù)可用性高、性?xún)r(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專(zhuān)為企業(yè)上云打造定制,能夠滿(mǎn)足用戶(hù)豐富、多元化的應(yīng)用場(chǎng)景需求。

當(dāng)前名稱(chēng):pytorch如何實(shí)現(xiàn)用CNN和LSTM對(duì)文本進(jìn)行分類(lèi)方式-創(chuàng)新互聯(lián)
文章路徑:http://bm7419.com/article0/ijiio.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供Google、ChatGPT網(wǎng)站內(nèi)鏈、網(wǎng)站策劃、微信公眾號(hào)、做網(wǎng)站

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶(hù)投稿、用戶(hù)轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)

小程序開(kāi)發(fā)