PyTorch Tutorial RNN 2

Date:

22 Oct 2019

Download data
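
eng-fra.txt is the tab-separated English-French pair file from the PyTorch seq2seq tutorial. A minimal fetch sketch, assuming the tutorial's data.zip URL on pytorch.org is still live:

import urllib.request, zipfile, io, os

url="https://download.pytorch.org/tutorial/data.zip"
with urllib.request.urlopen(url) as r:
  archive=zipfile.ZipFile(io.BytesIO(r.read()))
archive.extract("data/eng-fra.txt")          # the archive stores it under data/
os.replace("data/eng-fra.txt","eng-fra.txt")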

from io import open;import re;import unicodedata
import numpy as np;import random

class Language:
  # vocabulary for one language: word<->index maps plus word counts;
  # indices 0 and 1 are reserved for the SOS and EOS tokens
  def __init__(self):
    self.word2index={}; self.word2count={}
    self.index2word={0:"SOS",1:"EOS"};self.nWords=2
  def addSentence(self,sentence):
    for word in sentence.split(' '):
      if word not in self.word2index:
        self.word2index[word]=self.nWords
        self.word2count[word]=1
        self.index2word[self.nWords]=word
        self.nWords+=1
      else:
        self.word2count[word]+=1
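
A quick illustration of how Language assigns indices (made-up sentence; new words start at index 2 because 0 and 1 are reserved):

lang=Language()
lang.addSentence("i am here i am")
print(lang.word2index) # {'i': 2, 'am': 3, 'here': 4}
print(lang.word2count) # {'i': 2, 'am': 2, 'here': 1}
print(lang.nWords)     # 5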

class GetTranslation:
  # reads tab-separated eng<TAB>fra lines, normalizes and filters them,
  # and builds a Language vocabulary for each side
  def __init__(self,file,maxLen=10):
    self.maxLen=maxLen
    lines=open(file,encoding='utf-8').read().strip().split('\n')
    pairs=[[self.normalizeString(s) for s in line.split('\t')]
           for line in lines]
    self.lang1=Language(); self.lang2=Language()
    self.pairs=self.filterPairs(pairs);print("len(pairs)=",len(pairs)) # unfiltered count
    for pair in self.pairs:
      self.lang1.addSentence(pair[0])
      self.lang2.addSentence(pair[1])
    print("lang1",self.lang1.nWords);print("lang2",self.lang2.nWords)

  def normalizeString(self,s):
    s=self.unicodeToAscii(s.lower().strip())
    s=re.sub(r"([.!?])",r" \1",s)     # separate .!? from the word
    s=re.sub(r"[^a-zA-Z.!?]+",r" ",s) # replace everything else with a space
    return s
  def unicodeToAscii(self,s):
    # strip accents: NFD-decompose, then drop combining marks (category Mn)
    return ''.join(c for c in unicodedata.normalize('NFD',s)
      if unicodedata.category(c)!='Mn')
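
Step-by-step effect of the normalization on a sample accented line (standalone sketch using the same regexes):

s="Ça alors !".lower().strip()               # "ça alors !"
s=''.join(c for c in unicodedata.normalize('NFD',s)
          if unicodedata.category(c)!='Mn')  # "ca alors !"
s=re.sub(r"([.!?])",r" \1",s)                # "ca alors  !"
s=re.sub(r"[^a-zA-Z.!?]+",r" ",s)            # "ca alors !"
print(s)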

  def filterPairs(self,pairs):
    return [pair for pair in pairs if self.keepPair(pair)]
  def keepPair(self,pair):
    # keep short pairs whose English side starts with a common
    # subject+verb prefix (same filter as the PyTorch tutorial)
    englishPrefixes = ("i am ", "i m ","he is", "he s ",
      "she is", "she s ","you are", "you re ",
      "we are", "we re ","they are", "they re ")
    return len(pair[0].split(' '))<self.maxLen\
      and len(pair[1].split(' '))<self.maxLen\
      and pair[0].startswith(englishPrefixes)

  def samplePair(self):
    pair=random.choice(self.pairs)
    inputSeq=self.sentenceToSeq(self.lang1,pair[0])
    targetSeq=self.sentenceToSeq(self.lang2,pair[1])
    return inputSeq,targetSeq
  def sentenceToSeq(self,language,sentence):
    seq=[language.word2index[word] for word in sentence.split(' ')]
    seq.append(1) # "EOS = 1"
    return seq

getTranslation=GetTranslation('eng-fra.txt')
getTranslation.samplePair()
len(pairs)= 135842
lang1 2803
lang2 4345
([129, 124, 713, 606, 695, 4, 1], [210, 925, 1328, 115, 528, 2744, 5, 1])
import torch; import torch.nn as nn; from torch import optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

class EncoderRNN(nn.Module):
  def __init__(self,inputSize,hiddenSize):
    super(EncoderRNN,self).__init__()
    self.hiddenSize=hiddenSize
    self.embedding=nn.Embedding(inputSize,hiddenSize)
    self.gru=nn.GRU(input_size=hiddenSize,hidden_size=hiddenSize)
  def forward(self,input,hidden):
    embedded=self.embedding(input).view([1,1,-1])
    return self.gru(embedded,hidden) # output, hidden
  def initHidden(self):
    return torch.zeros(1,1,self.hiddenSize,device=device)
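
A quick shape check for one encoder step (made-up sizes; the hidden state is built by hand so this runs on CPU before device is defined below):

enc=EncoderRNN(inputSize=2803,hiddenSize=256)
token=torch.tensor([5])            # one input word index
hidden=torch.zeros(1,1,256)        # what initHidden() returns, on CPU
output,hidden=enc(token,hidden)
print(output.shape,hidden.shape)   # both torch.Size([1, 1, 256])
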
"""
class DecoderRNN(nn.Module):
  def __init__(self,hiddenSize,outputSize):
    super(DecoderRNN,self).__init__()
    self.hiddenSize=hiddenSize
    self.embedding=nn.Embedding(outputSize,hiddenSize)
    self.gru=nn.GRU(input_size=hiddenSize,hidden_size=hiddenSize)
    self.outLinear=nn.Linear(hiddenSize,outputSize)
    self.softmax=nn.LogSoftmax(dim=1)
  def forward(self,input,hidden):
    embedded=F.relu(self.embedding(input).view(1,1,-1))
    output,hidden=self.gru(embedded,hidden)
    output=self.softmax(self.outLinear(output[0]))
    return output,hidden
  def initHidden(self):
    return torch.zeros(1,1,self.hiddenSize,device=device)
"""
class AttentionDecoder(nn.Module):
  def __init__(self,hiddenSize,outputSize,dropoutProb,maxLen):
    super(AttentionDecoder,self).__init__()
    self.hiddenSize=hiddenSize
    self.embedding=nn.Embedding(outputSize,hiddenSize)
    self.attention=nn.Linear(hiddenSize*2,maxLen)
    self.attentionCombine=nn.Linear(hiddenSize*2,hiddenSize)
    self.dropout=nn.Dropout(dropoutProb)
    self.gru=nn.GRU(input_size=hiddenSize,hidden_size=hiddenSize)
    self.outLinear=nn.Linear(hiddenSize,outputSize)
  def forward(self,input,hidden,encoderOutputs):
    embedded=self.dropout(self.embedding(input).view(1,1,-1))
    # attention: score each of the maxLen encoder positions from the
    # current input embedding and the previous hidden state
    attentionWeights=F.softmax(
        self.attention(
            torch.cat((embedded[0],hidden[0]),dim=1)
        ),dim=1
    )
    # context vector: weighted sum over the encoder outputs
    attentionApplied=torch.bmm(attentionWeights.unsqueeze(0),
                               encoderOutputs.unsqueeze(0))
    output=torch.cat((embedded[0],attentionApplied[0]),1)
    output=F.relu(self.attentionCombine(output).unsqueeze(0))
    output,hidden=self.gru(output,hidden)
    output=F.log_softmax(self.outLinear(output[0]),dim=1)
    return output,hidden,attentionWeights
  def initHidden(self):
    return torch.zeros([1,1,self.hiddenSize],device=device)
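
The attention step in shapes (made-up sizes, hiddenSize=256 and maxLen=10): the input embedding and hidden state are concatenated to score the maxLen encoder positions, and bmm takes the weighted sum:

H,maxLen=256,10
embedded=torch.zeros(1,1,H); hidden=torch.zeros(1,1,H)
encoderOutputs=torch.zeros(maxLen,H)
attn=nn.Linear(H*2,maxLen)
weights=F.softmax(attn(torch.cat((embedded[0],hidden[0]),dim=1)),dim=1)
print(weights.shape)               # torch.Size([1, 10])
context=torch.bmm(weights.unsqueeze(0),encoderOutputs.unsqueeze(0))
print(context.shape)               # torch.Size([1, 1, 256])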

class Agent:
  def __init__(self,maxLen=10,hiddenSize=256,
               teacherForcingProb=0.5,filename='eng-fra.txt',
               dropoutProb=0.01,learnRate=1e-2):
    self.maxLen=maxLen; self.hiddenSize=hiddenSize
    self.teacherForcingProb=teacherForcingProb
    self.getTranslation=GetTranslation(filename,maxLen)
    lang1nWords=self.getTranslation.lang1.nWords
    lang2nWords=self.getTranslation.lang2.nWords
    self.encoder=EncoderRNN(lang1nWords,hiddenSize).to(device)
    self.encoderOptimizer=optim.SGD(self.encoder.parameters(),
                                    lr=learnRate)
    self.decoder=AttentionDecoder(hiddenSize,lang2nWords,
                                  dropoutProb,maxLen).to(device)
    self.decoderOptimizer=optim.SGD(self.decoder.parameters(),
                                    lr=learnRate)
    self.criterion=nn.NLLLoss()

  def train(self,nIterations):
    self.losses=[];lossTotal=0
    for iteration in range(nIterations):
      pair=self.getTranslation.samplePair()
      inputTensor=torch.tensor(pair[0],dtype=torch.long,device=device).view(-1,1)
      targetTensor=torch.tensor(pair[1],dtype=torch.long,device=device).view(-1,1)
      loss=self.trainOne(inputTensor,targetTensor)
      lossTotal+=loss
      if iteration%100==0: # the first print covers only one sample
        self.losses.append(lossTotal)
        print("lossTotal:",lossTotal);lossTotal=0
    plt.plot(self.losses);plt.show()
  def trainOne(self,inputTensor,targetTensor):
    encoderHidden=self.encoder.initHidden()
    self.encoderOptimizer.zero_grad();self.decoderOptimizer.zero_grad()
    inputLen=inputTensor.size(0); targetLen=targetTensor.size(0)
    encoderOutputs=torch.zeros([self.maxLen,self.hiddenSize],device=device)
    loss=0
    for i in range(inputLen):
      encoderOutput,encoderHidden=self.encoder(
          inputTensor[i],encoderHidden)
      encoderOutputs[i]=encoderOutput[0,0]
    decoderInput=torch.tensor([[0]],device=device) # "SOS = 0"
    decoderHidden=encoderHidden
    useTeacherForcing=random.random()<self.teacherForcingProb
    if useTeacherForcing: # target as next input
      for i in range(targetLen):
        decoderOutput,decoderHidden,decoderAttention=self.decoder(
            decoderInput,decoderHidden,encoderOutputs)
        loss+=self.criterion(decoderOutput,targetTensor[i])
        decoderInput=targetTensor[i]
    else: # prediction as next input
      for i in range(targetLen):
        decoderOutput,decoderHidden,decoderAttention=self.decoder(
            decoderInput,decoderHidden,encoderOutputs)
        topValue,topIndex=decoderOutput.topk(1)
        decoderInput=topIndex.squeeze().detach() # detach from graph, just an input
        loss+=self.criterion(decoderOutput,targetTensor[i])
        if decoderInput.item()==1: break  # "EOS = 1"
    loss.backward()
    self.encoderOptimizer.step();self.decoderOptimizer.step()
    return loss.item()/targetLen
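
The evaluation loop isn't shown in this post; a minimal greedy-decoding sketch (hypothetical translate helper mirroring the no-teacher-forcing branch above; input words must appear in the training vocabulary):

def translate(agent,sentence):
  gt=agent.getTranslation
  seq=gt.sentenceToSeq(gt.lang1,gt.normalizeString(sentence))
  inputTensor=torch.tensor(seq,dtype=torch.long,device=device).view(-1,1)
  agent.encoder.eval();agent.decoder.eval() # disable dropout
  with torch.no_grad():
    hidden=agent.encoder.initHidden()
    encoderOutputs=torch.zeros([agent.maxLen,agent.hiddenSize],device=device)
    for i in range(inputTensor.size(0)):
      output,hidden=agent.encoder(inputTensor[i],hidden)
      encoderOutputs[i]=output[0,0]
    decoderInput=torch.tensor([[0]],device=device) # "SOS = 0"
    words=[]
    for _ in range(agent.maxLen):
      output,hidden,_=agent.decoder(decoderInput,hidden,encoderOutputs)
      topIndex=output.topk(1)[1]
      if topIndex.item()==1: break # "EOS = 1"
      words.append(gt.lang2.index2word[topIndex.item()])
      decoderInput=topIndex.squeeze().detach()
  return ' '.join(words)

For example, translate(agent,"i am cold .") after training.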

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent=Agent()
agent.train(5000)
len(pairs)= 135842
lang1 2803
lang2 4345
lossTotal: 8.382897271050346
lossTotal: 484.00090956952846
lossTotal: 415.5288609262497
lossTotal: 398.90406545078946
lossTotal: 392.74038437283235
lossTotal: 401.41351437265905
lossTotal: 413.63684746575706
lossTotal: 423.95259078987056
lossTotal: 387.9253147038201
lossTotal: 400.4369891666231
lossTotal: 375.9564390962088
lossTotal: 369.40357372420175
lossTotal: 375.159505275696
lossTotal: 368.3363914232405
lossTotal: 351.4931094124203
lossTotal: 345.763414468084
lossTotal: 340.3649337579335
lossTotal: 316.76592836039407
lossTotal: 336.15329839994024
lossTotal: 331.5485665635458
lossTotal: 338.0610604653282
lossTotal: 322.1913966966053
lossTotal: 319.4010247374338
lossTotal: 340.254363073243
lossTotal: 332.29101278441306
lossTotal: 337.0906412942068
lossTotal: 314.36265838524656
lossTotal: 325.3588539819868
lossTotal: 318.25481532036326
lossTotal: 315.2033801824328
lossTotal: 341.89669475025596
lossTotal: 317.06151219171187
lossTotal: 324.6426379143246
lossTotal: 335.2526238774497
lossTotal: 312.68326999043666
lossTotal: 311.76144389992675
lossTotal: 308.2723661736837
lossTotal: 317.88990014621186
lossTotal: 292.26333829259113
lossTotal: 296.26249667103343
lossTotal: 323.83813181585725
lossTotal: 319.55471047219766
lossTotal: 300.0589878778609
lossTotal: 298.32901718692165
lossTotal: 315.3292742059345
lossTotal: 291.8655293184614
lossTotal: 306.4530454408553
lossTotal: 319.7169124440542
lossTotal: 293.4895835475317
lossTotal: 308.7035841949402
[Figure: training loss, summed per 100 iterations, over 5000 iterations]