PyTorch Tutorial: Transformer

Date:

23 Oct 2019

import os; from io import open
import math
import torch; import torch.nn as nn; import torch.nn.functional as F

Transformer Model

class PosEncoder(nn.Module):
  def __init__(self,embedDim,dropoutProb=0.1,maxPos=5000):
    super(PosEncoder,self).__init__()
    self.dropout=nn.Dropout(p=dropoutProb)
    posEnc=torch.zeros(maxPos,embedDim)
    position=torch.arange(0,maxPos).float().unsqueeze(1)
    divTerm=torch.exp(torch.arange(0,embedDim,2).float()*\
                      (-math.log(1e4)/embedDim))
    posEnc[:,0::2]=torch.sin(position*divTerm)
    posEnc[:,1::2]=torch.cos(position*divTerm)
    posEnc=posEnc.unsqueeze(0).transpose(0,1)
    self.register_buffer('posEnc',posEnc)
    # register_buffer stores posEnc in the module's state_dict
    # without treating it as a trainable model parameter
  def forward(self,x):
    # x.shape = [seqLen, batchSize, embedDim]
    # output.shape = x.shape
    return self.dropout(x+self.posEnc[:x.size(0), :])
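
A quick shape check of the module above (a sketch; the sizes are arbitrary): the output keeps the shape of the input, and the sinusoidal table lives in a buffer, so it is saved with state_dict() but never trained.

posEncDemo=PosEncoder(embedDim=16)
demoX=torch.zeros(10,4,16)              # [seqLen=10, batchSize=4, embedDim=16]
print(posEncDemo(demoX).shape)          # torch.Size([10, 4, 16])
print(list(posEncDemo.parameters()))    # [] -- no trainable parameters
print(list(posEncDemo.state_dict()))    # ['posEnc'] -- but the buffer is checkpointed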

class Transformer(nn.Module):
  def __init__(self,numEmbeds,embedDim,nHeads,feedforwardDim,
               nLayers,dropoutProb=0.5):
    super(Transformer,self).__init__()
    self.embedDim=embedDim; self.numEmbeds=numEmbeds

    from torch.nn import TransformerEncoder, TransformerEncoderLayer
    self.inputSeqMask=None
    self.posEncoder=PosEncoder(embedDim,dropoutProb)
    encoderLayers=TransformerEncoderLayer(embedDim,nHeads,
                                          feedforwardDim,dropoutProb)
    self.transformerEncoder=TransformerEncoder(encoderLayers,nLayers)
    self.embedding=nn.Embedding(numEmbeds,embedDim)
    self.decoderLinear=nn.Linear(embedDim,numEmbeds)
    self.initWeights()
  def initWeights(self):
    initRange=0.1
    self.embedding.weight.data.uniform_(-initRange,initRange)
    self.decoderLinear.bias.data.zero_()
    self.decoderLinear.weight.data.uniform_(-initRange,initRange)
  def forward(self,inputSeq,hasMask=True):
    if hasMask:
      device=inputSeq.device
      if self.inputSeqMask is None or self.inputSeqMask.size(0)!=len(inputSeq):
        self.inputSeqMask=self.squareSubsequentMask(len(inputSeq)).to(device)
    else:
      self.inputSeqMask=None
    inputSeq=self.posEncoder(self.embedding(inputSeq)*math.sqrt(self.embedDim)) # scale by sqrt(d_model)
    outputSeq=self.transformerEncoder(inputSeq,self.inputSeqMask)
    return F.log_softmax(self.decoderLinear(outputSeq),dim=-1)
  def squareSubsequentMask(self,size):
    # causal mask: 0 where attention is allowed, -inf where it is blocked
    mask=(torch.triu(torch.ones(size,size))==1).transpose(0,1)
    return mask.float().masked_fill(mask==0,float('-inf'))\
                     .masked_fill(mask==1,float(0.0))
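
What squareSubsequentMask produces, written out for size 4 (a standalone sketch of the same torch.triu construction): zeros on and below the diagonal, -inf above it, so position i can only attend to positions <= i.

demoMask=(torch.triu(torch.ones(4,4))==1).transpose(0,1)
demoMask=demoMask.float().masked_fill(demoMask==0,float('-inf'))\
                         .masked_fill(demoMask==1,float(0.0))
print(demoMask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])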

Data And Training

class Dictionary:
  def __init__(self):
    self.wordToIndex={}; self.indexToWord=[]
  def addWord(self,word):
    if word not in self.wordToIndex:
      self.indexToWord.append(word)
      self.wordToIndex[word]=len(self.indexToWord)-1
    return self.wordToIndex[word]

class Corpus:
  def __init__(self,path,batchSize):
    self.batchSize=batchSize

    self.dictionary=Dictionary()
    self.trainData=self.fileToIndices(os.path.join(path,'train.txt'))
    self.validData=self.fileToIndices(os.path.join(path,'valid.txt'))
    self.testData=self.fileToIndices(os.path.join(path,'test.txt'))
  def fileToIndices(self,path):
    assert os.path.exists(path)
    with open(path,'r',encoding='utf8') as myFile:
      for line in myFile:
        words=line.split()+['<eos>'] # end of seq
        for word in words:
          self.dictionary.addWord(word)
    with open(path,'r',encoding='utf8') as myFile:
      fileToIndices=[]
      for line in myFile:
        words=line.split()+['<eos>']
        for word in words:
          fileToIndices.append(self.dictionary.wordToIndex[word])
      nSlices=len(fileToIndices)//self.batchSize
    return torch.tensor(fileToIndices[:nSlices*self.batchSize])\
      .type(torch.int64).view(self.batchSize,-1).t().contiguous().to(device)

# alphabet as the sequence, batchSize=4
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
# │ e k q w │
# └ f l r x ┘.
# The model treats these columns as independent sequences, so a dependence such
# as 'g' following 'f' cannot be learned across the column boundary, but this
# allows more efficient batch processing.
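
The same reshape fileToIndices performs, shown on a toy sequence of 24 indices standing in for the letters a..x (a sketch, not part of the training script): with batchSize=4 the result is a [6, 4] tensor whose columns are the four independent sub-sequences, and with 26 letters the leftover y and z would be trimmed, just as fileToIndices drops the remainder.

demoTokens=torch.arange(24)                    # stand-ins for the word indices of a..x
print(demoTokens.view(4,-1).t().contiguous())
# tensor([[ 0,  6, 12, 18],
#         [ 1,  7, 13, 19],
#         [ 2,  8, 14, 20],
#         [ 3,  9, 15, 21],
#         [ 4, 10, 16, 22],
#         [ 5, 11, 17, 23]])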

class Trainer:
  def __init__(self,batchSize,seqLen,clip,learnRate,
               embedDim,nHead,feedforwardDim,nLayers,dropoutProb):
    self.seqLen=seqLen; self.clip=clip; self.learnRate=learnRate

    self.corpus=Corpus('',batchSize) # expects train.txt/valid.txt/test.txt in the working directory
    print(self.corpus.trainData.shape)

    self.numEmbeds=len(self.corpus.dictionary.indexToWord)
    self.transformer=Transformer(self.numEmbeds,embedDim,nHead,
                                 feedforwardDim,nLayers,
                                 dropoutProb).to(device)
    self.criterion=nn.CrossEntropyLoss()
    self.savePath='model.pt' # checkpoint file for the best model so far (used in train)

  def getSeq(self,fullData,startIndex):
    seqLen=min(self.seqLen,len(fullData)-startIndex-1)
    inputSeq =fullData[startIndex  :startIndex+seqLen]
    targetSeq=fullData[startIndex+1:startIndex+seqLen+1].view(-1)
    #print('@getSeq:inputSeq,targetSeq',inputSeq.shape,targetSeq.shape)
    return inputSeq,targetSeq
# self.seqLen=2, startIndex=0
#  inputSeq    targetSeq
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
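# The diagram shows targetSeq before the .view(-1): in the code it is flattened
# to [b, h, n, t, c, i, o, u] because nn.CrossEntropyLoss expects class indices
# of shape [N] against scores of shape [N, numEmbeds], with N = seqLen*batchSize.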

  def evaluate(self,fullData):
    self.transformer.eval() # turn on eval mode which disables dropout
    totalLoss=0.
    with torch.no_grad():
      for startIndex in range(0,fullData.size(0)-1,self.seqLen):
        inputSeq,targetSeq=self.getSeq(fullData,startIndex)
        outputSeq=self.transformer(inputSeq)
        outputSeqFlat=outputSeq.view(-1,self.numEmbeds)
        totalLoss+=len(inputSeq)*\
          self.criterion(outputSeqFlat,targetSeq).item()
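    # totalLoss/len(fullData) approximates the mean cross-entropy per token;
    # math.exp() of that value is the word-level perplexity usually reported.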
    return totalLoss/len(fullData)

  def trainOneEpoch(self,fullData):
    self.transformer.train() # turn on training mode which enables dropout
    totalLoss=0.
    for batchIndex, startIndex in enumerate(range(0,fullData.size(0)-1,self.seqLen)):
      inputSeq,targetSeq=self.getSeq(fullData,startIndex)
      self.transformer.zero_grad()
      outputSeq=self.transformer(inputSeq)
      outputSeqFlat=outputSeq.view(-1,self.numEmbeds)
      loss=self.criterion(outputSeqFlat,targetSeq)
      loss.backward()
      # clip_grad_norm_ rescales gradients whose norm exceeds self.clip,
      # which helps prevent exploding gradients
      torch.nn.utils.clip_grad_norm_(self.transformer.parameters(),self.clip)
      # plain SGD step without an optimizer object (no momentum, no weight decay)
      for p in self.transformer.parameters():
        p.data.add_(-self.learnRate*p.grad.data)
      totalLoss+=loss.item()
      if batchIndex%100==0 and batchIndex>0:
        print('loss {:5.2f}'.format(totalLoss/100)); totalLoss=0
  def train(self,nEpochs):
    bestValidLoss=None
    try:
      for i in range(nEpochs):
        print('epoch',i+1)
        self.trainOneEpoch(self.corpus.trainData)
        validLoss=self.evaluate(self.corpus.validData)
        print('-'*69)
        print('valid loss {:5.2f}'.format(validLoss))
        if bestValidLoss is None or validLoss<bestValidLoss:
          with open(self.savePath,'wb') as myFile:
            torch.save(self.transformer,myFile)
          bestValidLoss=validLoss
        else:
          self.learnRate/=4.0 # anneal learnRate if no improvement
    except KeyboardInterrupt:
      print('exiting from training early')


device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t=Trainer(batchSize=20,seqLen=35,clip=0.5,learnRate=5.0,
          embedDim=200,nHead=2,feedforwardDim=200,nLayers=2,dropoutProb=0.2)
t.train(2)
Output:

torch.Size([104431, 20])
epoch 1
loss  8.01
loss  7.30
loss  6.97
loss  6.79
loss  6.70
loss  6.56
loss  6.49
loss  6.45
loss  6.39
loss  6.34
loss  6.23
loss  6.23
loss  6.18
loss  6.07
loss  6.10
loss  6.11
loss  6.05
loss  5.96
loss  5.92
loss  5.94
loss  5.83
loss  5.77
loss  5.73
loss  5.80
loss  5.75
loss  5.75
loss  5.66
loss  5.62
loss  5.64