%% Cell type:markdown id:b39e9d1f-05b5-43b4-b50a-036ae88657cd tags:
# Next-Token-Prediction
A small GRU model is trained to predict the next word of a question, given all preceding words as context. This notebook is based on the following blog post:
* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671
%% Cell type:code id:9c35813f-ccb3-42b3-a7e1-6069dece172f tags:
``` python
import nltk
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```
%% Cell type:code id:b482fac0-b576-472e-ab5c-df951d0b2404 tags:
``` python
nltk.download('punkt')
```
%% Output
[nltk_data] Downloading package punkt to /home/container/nltk_data...
[nltk_data] Package punkt is already up-to-date!
True
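%% Cell type:markdown id:tokenize-example tags:
For reference, `punkt` is the pretrained tokenizer model behind `nltk.word_tokenize`. A minimal sketch of what tokenization produces on one of the questions used below (the expected output is shown as a comment):
%% Cell type:code id:tokenize-example-code tags:
``` python
# word_tokenize splits on whitespace and punctuation, so the trailing
# question mark becomes a token of its own.
print(nltk.word_tokenize("What is Antwerp?".lower()))
# sketch of the expected output: ['what', 'is', 'antwerp', '?']
```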
%% Cell type:code id:f6a376c4-bd61-44d0-badc-8c451394b627 tags:
``` python
dataset = """
Is Antwerp a city?,
Is Antwerp a municipality?,
Is Antwerp in Belgium?,
What is Antwerp?,
What is the population of the city of Antwerp?,
Where is the city of Antwerp?,
Why is Antwerp important to fashion?,
Antwerp is to the east of what river?,
How many municipalities does Antwerp have?,
"""
```
%% Cell type:code id:c5831581-84e1-41d1-ae4f-b87e5bab365c tags:
``` python
def get_all_possible_sequences(text):
    """Return every (context, next word) pair over all contiguous windows."""
    seq = []
    words = nltk.word_tokenize(text)
    total_words = len(words)
    for i in range(1, total_words):
        for j in range(1, len(words) - i + 1):
            arr = words[j - 1:j + i]
            seq.append((arr[:-1], arr[-1]))
    return seq

def build_vocabulary(docs):
    """Collect each distinct token once; 'UNK' is appended for padding."""
    vocabulary = []
    for doc in docs:
        for w in nltk.word_tokenize(doc):
            if w not in vocabulary:
                vocabulary.append(w)
    vocabulary.append('UNK')
    return vocabulary
```
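%% Cell type:markdown id:sequence-example tags:
Before wiring these helpers together, a quick sketch of what `get_all_possible_sequences` emits for a short question: each contiguous window is split into a context and the word that follows it (expected output shown as comments):
%% Cell type:code id:sequence-example-code tags:
``` python
for context, target in get_all_possible_sequences("Is Antwerp?".lower()):
    print(context, '->', target)
# sketch of the expected output:
# ['is'] -> 'antwerp'
# ['antwerp'] -> '?'
# ['is', 'antwerp'] -> '?'
```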
%% Cell type:code id:2171ca5a-ca9a-4e33-86d0-8fa4b612f8a0 tags:
``` python
docs = []
for row in dataset.split(","):
    docs.append(row.lower())

lst = []
for doc in docs:
    tmp_lst = get_all_possible_sequences(doc)
    lst = lst + tmp_lst

vocabulary = build_vocabulary(docs)
id2word = {idx: w for (idx, w) in enumerate(vocabulary)}
word2id = {w: idx for (idx, w) in enumerate(vocabulary)}

def seq2id(arr):
    return torch.tensor([word2id[i] for i in arr])

def get_max_seq():
    # longest context among all training pairs
    return max(len(i[0]) for i in lst)

MAX_SEQ_LEN = get_max_seq()

def get_padded_x(data):
    # right-pad the id sequence with the 'UNK' id up to MAX_SEQ_LEN
    new_data = F.pad(input=data.view(1, -1), pad=(0, MAX_SEQ_LEN - data.shape[0], 0, 0), mode='constant', value=word2id['UNK'])
    return new_data

def get_xy_vector(arr):
    x = seq2id(arr[0])
    y = seq2id([arr[1]])
    return x, y
```
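%% Cell type:markdown id:encoding-example tags:
A short sketch of the encoding path for a single training pair: `seq2id` turns tokens into vocabulary ids, `get_padded_x` right-pads the context with the `UNK` id to a fixed length of `MAX_SEQ_LEN`, and `get_xy_vector` bundles the id lookup for context and target (the printed ids depend on vocabulary order):
%% Cell type:code id:encoding-example-code tags:
``` python
x, y = get_xy_vector(lst[0])
print('context ids :', x)                # shape (len(context),)
print('padded input:', get_padded_x(x))  # shape (1, MAX_SEQ_LEN)
print('target id   :', y)                # shape (1,)
```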
%% Cell type:code id:bb8929f7-85be-4740-be07-7dd6b4ed3086 tags:
``` python
class NextWordModel(nn.Module):
    """Predict the next word from a padded MAX_SEQ_LEN context."""

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(NextWordModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # the padded context is flattened into a single input vector per step
        self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        gru_out, _ = self.gru(embeds.view(1, 1, -1))
        x = self.linear(gru_out.view(1, -1))
        return x
```
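%% Cell type:markdown id:shape-check tags:
As a quick shape check (a sketch; `tmp_model` is a throwaway CPU instance, not part of the training setup below), one padded context should map to a single row of vocabulary-sized logits:
%% Cell type:code id:shape-check-code tags:
``` python
tmp_model = NextWordModel(10, len(vocabulary), len(vocabulary))
x, _ = get_xy_vector(lst[0])
logits = tmp_model(get_padded_x(x))
print(logits.shape)  # torch.Size([1, len(vocabulary)]): one logit per word
```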
%% Cell type:code id:08b6f539-75a0-4a18-b462-165ca44387d3 tags:
``` python
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
print(f'Running on {dev}')

# device that the model and all tensors will be moved to
device = torch.device(dev)
```
%% Output
Running on cuda:0
%% Cell type:code id:cf4f67b2-e0a4-4b57-abd9-4a2e40898511 tags:
``` python
EMBEDDING_DIM = 10
NO_OF_EPOCHS = 300
HIDDEN_DIM = len(vocabulary)

model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary))
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.to(device)

for epoch in range(NO_OF_EPOCHS):
    running_loss = 0.0
    i = 0
    for data in lst:
        model.zero_grad()
        x, y = get_xy_vector(data)
        # convert to max seq length with padding
        x = get_padded_x(x)
        x = x.to(device)
        y = y.to(device)
        predicted = model(x)
        loss = loss_function(predicted, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()  # .item() avoids keeping the graph alive
        i += 1
        if i % 100 == 0:
            #print(f'Loss at iteration {i} and epoch {epoch} is {running_loss / 100}')
            running_loss = 0
print('Finished')
```
%% Output
Finished
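%% Cell type:markdown id:accuracy-check tags:
The training cell only reports that it finished; here is a small sketch for checking how well the model has memorised the training pairs (this evaluation loop is illustrative, not part of the original notebook):
%% Cell type:code id:accuracy-check-code tags:
``` python
correct = 0
with torch.no_grad():
    for data in lst:
        x, y = get_xy_vector(data)
        x = get_padded_x(x).to(device)
        # compare the greedy prediction against the true next word
        correct += int(model(x).argmax(dim=1).item() == y.item())
print(f'Training accuracy: {correct / len(lst):.2%}')
```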
%% Cell type:code id:84e84dfc-5b7a-47c2-96cd-304c5630a002 tags:
``` python
with torch.no_grad():
    print('Type something here . . .')
    while True:
        inp = input("")
        inp = inp.strip()
        if inp == "q":
            break
        # note: words outside the vocabulary would raise a KeyError in seq2id
        tokens = nltk.word_tokenize(inp.lower())
        x = seq2id(tokens)
        x = get_padded_x(x)
        x = x.to(device)
        predicted = model(x)
        predicted = predicted[0].cpu().numpy()
        # greedy choice: the word with the highest logit
        print(f'Answer: {inp} {id2word[np.argmax(predicted)]}')
```
%% Output
Type something here . . .
what
Answer: what is
what is
Answer: what is the
what is the
Answer: what is the population
what is the population
Answer: what is the population of
what is the population of
Answer: what is the population of the
what is the population of the
Answer: what is the population of the city
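%% Cell type:markdown id:greedy-generation tags:
The loop above predicts a single token per prompt; feeding each prediction back into the model gives greedy multi-word generation. A sketch, assuming every prompt token appears in the vocabulary (the `generate` helper is illustrative):
%% Cell type:code id:greedy-generation-code tags:
``` python
def generate(prompt, max_new_tokens=10):
    """Greedily extend `prompt` one predicted token at a time."""
    tokens = nltk.word_tokenize(prompt.lower())  # must all be in word2id
    with torch.no_grad():
        for _ in range(max_new_tokens):
            x = get_padded_x(seq2id(tokens[-MAX_SEQ_LEN:])).to(device)
            next_word = id2word[model(x).argmax(dim=1).item()]
            tokens.append(next_word)
            if next_word == '?':  # every training question ends with '?'
                break
    return ' '.join(tokens)

print(generate('what is'))
# sketch of the expected output: what is the population of the city of antwerp ?
```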
%% Cell type:code id:660e6502-7548-47df-a645-a1e2d33178f6 tags:
``` python
```