Skip to content
Snippets Groups Projects
Commit 0b0ed96a authored by Kreinsen, Moritz's avatar Kreinsen, Moritz
Browse files

Update

parents 5a791efe 4ce825fa
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id:b39e9d1f-05b5-43b4-b50a-036ae88657cd tags: %% Cell type:markdown id:0f901b18 tags:
# Next-Token-Prediction # Next-Token-Prediction
This is based on the following blog posts: This is based on the following blog posts:
* How ChatGPT Works: The Model Behind The Bot: https://towardsdatascience.com/how-chatgpt-works-the-models-behind-the-bot-1ce5fca96286
* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671 * Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671
%% Cell type:code id:9c35813f-ccb3-42b3-a7e1-6069dece172f tags: %% Cell type:code id:b2471e25 tags:
``` python ``` python
import nltk import nltk
import pandas as pd import pandas as pd
import torch import torch
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
``` ```
%% Cell type:code id:b482fac0-b576-472e-ab5c-df951d0b2404 tags: %% Cell type:code id:c0febc54 tags:
``` python ``` python
nltk.download('punkt') nltk.download('punkt')
``` ```
%% Output %% Cell type:code id:8a781aa1 tags:
[nltk_data] Downloading package punkt to /home/container/nltk_data...
[nltk_data] Package punkt is already up-to-date!
True
%% Cell type:code id:f6a376c4-bd61-44d0-badc-8c451394b627 tags:
``` python ``` python
dataset = """ dataset = """
Is Antwerp a city?, Is Antwerp a city?,
Is Antwerp a municipality?, Is Antwerp a municipality?,
Is Antwerp in Belgium?, Is Antwerp in Belgium?,
What is Antwerp?, What is Antwerp?,
What is the population of the city of Antwerp?, What is the population of the city of Antwerp?,
Where is the city of Antwerp?, Where is the city of Antwerp?,
Why is Antwerp important to fashion?, Why is Antwerp important to fashion?,
Antwerp is to the east of what river?, Antwerp is to the east of what river?,
How many municipalities does Antwerp have?, How many municipalities does Antwerp have?,
""" """
``` ```
%% Cell type:code id:c5831581-84e1-41d1-ae4f-b87e5bab365c tags: %% Cell type:code id:39255d82 tags:
``` python ``` python
def get_all_possible_sequences(text): def get_all_possible_sequences(text):
seq = [] seq = []
words = nltk.word_tokenize(text) words = nltk.word_tokenize(text)
total_words = len(words) total_words = len(words)
for i in range(1, total_words): for i in range(1, total_words):
for j in range(1, len(words)-i+1): for j in range(1, len(words)-i+1):
arr = words[j-1:j+i] arr = words[j-1:j+i]
seq.append((arr[:-1], arr[-1])) seq.append((arr[:-1], arr[-1]))
return seq return seq
def build_vocabulary(docs): def build_vocabulary(docs):
vocabulary = [] vocabulary = []
for doc in docs: for doc in docs:
for w in nltk.word_tokenize(doc): for w in nltk.word_tokenize(doc):
if w not in vocabulary: if w not in vocabulary:
vocabulary.append(w) vocabulary.append(w)
vocabulary.append('UNK') vocabulary.append('UNK')
return vocabulary return vocabulary
``` ```
%% Cell type:code id:2171ca5a-ca9a-4e33-86d0-8fa4b612f8a0 tags: %% Cell type:code id:edb54d0d tags:
``` python ``` python
docs = [] docs = []
for row in dataset.split(","): for row in dataset.split(","):
docs.append(row.lower()) docs.append(row.lower())
lst = [] lst = []
for doc in docs: for doc in docs:
tmp_lst = get_all_possible_sequences(doc) tmp_lst = get_all_possible_sequences(doc)
lst = lst + tmp_lst lst = lst + tmp_lst
vocabulary = build_vocabulary(docs) vocabulary = build_vocabulary(docs)
id2word = {idx: w for (idx, w) in enumerate(vocabulary)} id2word = {idx: w for (idx, w) in enumerate(vocabulary)}
word2id = {w: idx for (idx, w) in enumerate(vocabulary)} word2id = {w: idx for (idx, w) in enumerate(vocabulary)}
def seq2id(arr): def seq2id(arr):
return torch.tensor([word2id[i] for i in arr]) return torch.tensor([word2id[i] for i in arr])
def get_max_seq(): def get_max_seq():
return len(list(set([len(i[0]) for i in lst]))) return len(list(set([len(i[0]) for i in lst])))
MAX_SEQ_LEN = get_max_seq() MAX_SEQ_LEN = get_max_seq()
def get_padded_x(data): def get_padded_x(data):
new_data = F.pad(input=data.view(1,-1), pad=(0, MAX_SEQ_LEN-data.shape[0], 0, 0), mode='constant', value=word2id['UNK']) new_data = F.pad(input=data.view(1,-1), pad=(0, MAX_SEQ_LEN-data.shape[0], 0, 0), mode='constant', value=word2id['UNK'])
return new_data return new_data
def get_xy_vector(arr): def get_xy_vector(arr):
x = seq2id(arr[0]) x = seq2id(arr[0])
y = seq2id([arr[1]]) y = seq2id([arr[1]])
return x, y return x, y
``` ```
%% Cell type:code id:bb8929f7-85be-4740-be07-7dd6b4ed3086 tags: %% Cell type:code id:431fd558 tags:
``` python ``` python
class NextWordModel(nn.Module): class NextWordModel(nn.Module):
""" Prediction of Next word based on the MAX_SEQ_LEN Sequence """ """ Prediction of Next word based on the MAX_SEQ_LEN Sequence """
def __init__(self, embedding_dim, hidden_dim, vocab_size): def __init__(self, embedding_dim, hidden_dim, vocab_size):
super(NextWordModel, self).__init__() super(NextWordModel, self).__init__()
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim) self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)
self.linear = nn.Linear(hidden_dim, vocab_size) self.linear = nn.Linear(hidden_dim, vocab_size)
def forward(self, sentence): def forward(self, sentence):
embeds = self.word_embeddings(sentence) embeds = self.word_embeddings(sentence)
lstm_out, _ = self.gru(embeds.view(1, 1, -1)) lstm_out, _ = self.gru(embeds.view(1, 1, -1))
x = self.linear(lstm_out.view(1, -1)) x = self.linear(lstm_out.view(1, -1))
return x return x
``` ```
%% Cell type:code id:08b6f539-75a0-4a18-b462-165ca44387d3 tags: %% Cell type:code id:d891422e tags:
``` python ``` python
if torch.cuda.is_available(): if torch.cuda.is_available():
dev = "cuda:0" dev = "cuda:0"
else: else:
dev = "cpu" dev = "cpu"
print(f'Running on {dev}') print(f'Running on {dev}')
# set the model to be copied on GPU # set the model to be copied on GPU
device = torch.device(dev) device = torch.device(dev)
``` ```
%% Output %% Cell type:code id:7c231c92 tags:
Running on cuda:0
%% Cell type:code id:cf4f67b2-e0a4-4b57-abd9-4a2e40898511 tags:
``` python ``` python
EMBEDDING_DIM = 10 EMBEDDING_DIM = 10
NO_OF_EPOCHS = 300 NO_OF_EPOCHS = 300
HIDDEN_DIM = len(vocabulary) HIDDEN_DIM = len(vocabulary)
model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary)) model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary))
loss_function = nn.CrossEntropyLoss() loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1) optimizer = optim.SGD(model.parameters(), lr=0.1)
model.to(device) model.to(device)
for epoch in range(NO_OF_EPOCHS): for epoch in range(NO_OF_EPOCHS):
running_loss = 0.0 running_loss = 0.0
i = 0 i = 0
for data in lst: for data in lst:
model.zero_grad() model.zero_grad()
x, y = get_xy_vector(data) x, y = get_xy_vector(data)
# convert to max seq length with padding # convert to max seq length with padding
x = get_padded_x(x) x = get_padded_x(x)
x = x.to(device) x = x.to(device)
y = y.to(device) y = y.to(device)
predicted = model(x) predicted = model(x)
loss = loss_function(predicted, y) loss = loss_function(predicted, y)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
running_loss += loss running_loss += loss
i += 1 i += 1
if i % 100 == 0: if i % 100 == 0:
#print(f'Loss at iteration {i} and epoch {epoch} is {running_loss / 100}') #print(f'Loss at iteration {i} and epoch {epoch} is {running_loss / 100}')
running_loss = 0 running_loss = 0
print('Finished') print('Finished')
``` ```
%% Output %% Cell type:code id:9d0ff01d tags:
Finished
%% Cell type:code id:84e84dfc-5b7a-47c2-96cd-304c5630a002 tags:
``` python ``` python
with torch.no_grad(): with torch.no_grad():
print('Type something here . . .') print('Type something here . . .')
while True: while True:
inp = input("") inp = input("")
inp = inp.strip() inp = inp.strip()
if inp == "q": if inp == "q":
break break
tokens = nltk.word_tokenize(inp.lower()) tokens = nltk.word_tokenize(inp.lower())
x = seq2id(tokens) x = seq2id(tokens)
x = get_padded_x(x) x = get_padded_x(x)
x = x.to(device) x = x.to(device)
predicted = model(x).to(device) predicted = model(x).to(device)
predicted = predicted[0].cpu().numpy() predicted = predicted[0].cpu().numpy()
print(f'Answer: {inp} {id2word[np.argmax(predicted)]} ') print(f'Answer: {inp} {id2word[np.argmax(predicted)]} ')
``` ```
%% Output %% Cell type:code id:c1e828f0 tags:
Type something here . . .
what
Answer :what is
what is
Answer :what is the
what is the
Answer :what is the population
what is the population
Answer :what is the population of
what is the population of
Answer :what is the population of the
what is the population of the
Answer :what is the population of the city
%% Cell type:code id:660e6502-7548-47df-a645-a1e2d33178f6 tags:
``` python ``` python
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment