%% Cell type:markdown id:0f901b18 tags:

# Next-Token-Prediction

This notebook is based on the following blog posts:

* How ChatGPT Works: The Model Behind The Bot: https://towardsdatascience.com/how-chatgpt-works-the-models-behind-the-bot-1ce5fca96286
* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671
%% Cell type:code id:b2471e25 tags:

``` python
import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```
%% Cell type:code id:c0febc54 tags:

``` python
nltk.download('punkt')
# On newer NLTK releases the tokenizer data was renamed; if word_tokenize
# complains about a missing resource, nltk.download('punkt_tab') may be
# needed as well.
```
%% Cell type:code id:8a781aa1 tags:

``` python
dataset = """
Is Antwerp a city?,
Is Antwerp a municipality?,
Is Antwerp in Belgium?,
What is Antwerp?,
What is the population of the city of Antwerp?,
Where is the city of Antwerp?,
Why is Antwerp important to fashion?,
Antwerp is to the east of what river?,
How many municipalities does Antwerp have?,
"""
```
%% Cell type:code id:39255d82 tags:

``` python
def get_all_possible_sequences(text):
    """Enumerate every (context, next-word) training pair in a sentence."""
    seq = []
    words = nltk.word_tokenize(text)
    total_words = len(words)
    for i in range(1, total_words):              # context length
        for j in range(1, total_words - i + 1):  # context start position
            arr = words[j-1:j+i]
            seq.append((arr[:-1], arr[-1]))      # (context words, next word)
    return seq


def build_vocabulary(docs):
    """Collect every distinct token, plus 'UNK' for padding and unknown words."""
    vocabulary = []
    for doc in docs:
        for w in nltk.word_tokenize(doc):
            if w not in vocabulary:
                vocabulary.append(w)
    vocabulary.append('UNK')
    return vocabulary
```
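%% Cell type:markdown tags:

As a quick sanity check (a minimal sketch, not part of the original notebook), the helper can be applied to one short question to see the (context, next-word) pairs it produces:

%% Cell type:code tags:

``` python
# Illustrative check: list every (context, next-word) pair for one question.
for context, target in get_all_possible_sequences("is antwerp a city ?"):
    print(context, "->", target)

# Expected output:
# ['is'] -> antwerp
# ['antwerp'] -> a
# ...
# ['is', 'antwerp', 'a', 'city'] -> ?
```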
%% Cell type:code id:edb54d0d tags:

``` python
# Split the dataset into individual (lower-cased) questions.
docs = []
for row in dataset.split(","):
    docs.append(row.lower())

# Build the full list of (context, next-word) training pairs.
lst = []
for doc in docs:
    lst = lst + get_all_possible_sequences(doc)

vocabulary = build_vocabulary(docs)
id2word = {idx: w for (idx, w) in enumerate(vocabulary)}
word2id = {w: idx for (idx, w) in enumerate(vocabulary)}

def seq2id(arr):
    # Map words to ids, falling back to 'UNK' for out-of-vocabulary words.
    return torch.tensor([word2id.get(i, word2id['UNK']) for i in arr])

def get_max_seq():
    # Length of the longest context among the training pairs.
    return max(len(i[0]) for i in lst)

MAX_SEQ_LEN = get_max_seq()

def get_padded_x(data):
    # Right-pad the id sequence with the 'UNK' id up to MAX_SEQ_LEN.
    new_data = F.pad(input=data.view(1, -1), pad=(0, MAX_SEQ_LEN - data.shape[0]),
                     mode='constant', value=word2id['UNK'])
    return new_data

def get_xy_vector(arr):
    x = seq2id(arr[0])
    y = seq2id([arr[1]])
    return x, y
```
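%% Cell type:markdown tags:

To make the padding behaviour concrete, here is a small illustrative check (again not in the original notebook, and assuming the cells above have run): the first training pair is converted to ids and its context right-padded with the `UNK` id up to `MAX_SEQ_LEN`.

%% Cell type:code tags:

``` python
# Illustrative check of the preprocessing pipeline on the first training pair.
context, target = lst[0]
x, y = get_xy_vector((context, target))
print(context, '->', target)
print('context ids:', x)
print('padded ids: ', get_padded_x(x))   # shape (1, MAX_SEQ_LEN)
print('MAX_SEQ_LEN =', MAX_SEQ_LEN, '| vocabulary size =', len(vocabulary))
```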
%% Cell type:code id:431fd558 tags:

``` python
class NextWordModel(nn.Module):
    """Predicts the next word from a padded context of MAX_SEQ_LEN tokens."""

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(NextWordModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # The padded context is flattened into a single step, so the GRU
        # input size is embedding_dim * MAX_SEQ_LEN.
        self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        gru_out, _ = self.gru(embeds.view(1, 1, -1))
        x = self.linear(gru_out.view(1, -1))  # one logit per vocabulary word
        return x
```
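%% Cell type:markdown tags:

Before training, a single forward pass on random ids is a cheap way to confirm the shapes; a minimal sketch (the throwaway model and its dimensions here are illustrative, not tuned):

%% Cell type:code tags:

``` python
# Shape check: a padded input of MAX_SEQ_LEN token ids should map to
# one logit per vocabulary word.
_tmp = NextWordModel(embedding_dim=10, hidden_dim=32, vocab_size=len(vocabulary))
_x = torch.randint(0, len(vocabulary), (1, MAX_SEQ_LEN))
print(_tmp(_x).shape)  # torch.Size([1, len(vocabulary)])
```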
%% Cell type:code id:d891422e tags:

``` python
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
print(f'Running on {dev}')

# Device that the model and all tensors will be moved to.
device = torch.device(dev)
```
%% Cell type:code id:7c231c92 tags:

``` python
EMBEDDING_DIM = 10
NO_OF_EPOCHS = 300
HIDDEN_DIM = len(vocabulary)

model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary))
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.to(device)

for epoch in range(NO_OF_EPOCHS):
    running_loss = 0.0
    i = 0
    for data in lst:
        model.zero_grad()
        x, y = get_xy_vector(data)
        # Convert to max sequence length with padding.
        x = get_padded_x(x)
        x = x.to(device)
        y = y.to(device)
        predicted = model(x)
        loss = loss_function(predicted, y)
        loss.backward()
        optimizer.step()
        # .item() detaches the loss so the computation graph is not kept alive.
        running_loss += loss.item()
        i += 1
        if i % 100 == 0:
            #print(f'Loss at iteration {i} and epoch {epoch} is {running_loss / 100}')
            running_loss = 0
print('Finished')
```
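%% Cell type:markdown tags:

Since the training set is tiny, a quick way to judge whether the model has memorised it is to measure argmax accuracy over all (context, next-word) pairs; a minimal sketch, not part of the original notebook:

%% Cell type:code tags:

``` python
# Fraction of training pairs where the argmax prediction matches the target.
correct = 0
with torch.no_grad():
    for data in lst:
        x, y = get_xy_vector(data)
        x = get_padded_x(x).to(device)
        pred = model(x).argmax(dim=1).item()
        correct += int(pred == y.item())
print(f'Training accuracy: {correct / len(lst):.2%}')
```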
%% Cell type:code id:9d0ff01d tags:

``` python
with torch.no_grad():
    print('Type something here . . .')
    while True:
        inp = input("")
        inp = inp.strip()
        if inp == "q":
            break
        tokens = nltk.word_tokenize(inp.lower())
        # Keep at most the last MAX_SEQ_LEN words so padding cannot go negative.
        tokens = tokens[-MAX_SEQ_LEN:]
        x = seq2id(tokens)
        x = get_padded_x(x)
        x = x.to(device)
        predicted = model(x)[0].cpu().numpy()
        print(f'Answer: {inp} {id2word[np.argmax(predicted)]}')
```
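%% Cell type:markdown tags:

The same model can also generate several words by feeding each prediction back into the context, which is the next-token-prediction loop that larger language models run at scale; a hedged sketch (greedy decoding with an arbitrary cap of 5 new words, not part of the original notebook):

%% Cell type:code tags:

``` python
# Greedy multi-token generation: repeatedly append the argmax word.
def generate(prompt, max_new_words=5):
    tokens = nltk.word_tokenize(prompt.lower())
    with torch.no_grad():
        for _ in range(max_new_words):
            x = get_padded_x(seq2id(tokens[-MAX_SEQ_LEN:])).to(device)
            next_word = id2word[model(x)[0].argmax().item()]
            tokens.append(next_word)
            if next_word == '?':  # the training questions all end with '?'
                break
    return ' '.join(tokens)

print(generate('where is'))
```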