Commit 0b0ed96a authored by Kreinsen, Moritz

Update

parents 5a791efe 4ce825fa
%% Cell type:markdown id:0f901b18 tags:
# Next-Token-Prediction
This notebook builds a tiny next-word predictor: it turns a handful of questions about Antwerp into (context, next word) training pairs, trains a small GRU classifier over the vocabulary, and then predicts the next token for a typed prompt. It is based on the following blog posts:
* How ChatGPT Works: The Model Behind The Bot: https://towardsdatascience.com/how-chatgpt-works-the-models-behind-the-bot-1ce5fca96286
* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671
%% Cell type:code id:b2471e25 tags:
``` python
import nltk
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```
%% Cell type:code id:c0febc54 tags:
``` python
nltk.download('punkt')
```
%% Output
[nltk_data] Downloading package punkt to /home/container/nltk_data...
[nltk_data] Package punkt is already up-to-date!
True
%% Cell type:code id:8a781aa1 tags:
``` python
dataset = """
Is Antwerp a city?,
Is Antwerp a municipality?,
Is Antwerp in Belgium?,
What is Antwerp?,
What is the population of the city of Antwerp?,
Where is the city of Antwerp?,
Why is Antwerp important to fashion?,
Antwerp is to the east of what river?,
How many municipalities does Antwerp have?,
"""
```
%% Cell type:code id:39255d82 tags:
``` python
def get_all_possible_sequences(text):
    """Build every (context, next word) pair from the token sequence."""
    seq = []
    words = nltk.word_tokenize(text)
    total_words = len(words)
    for i in range(1, total_words):
        for j in range(1, len(words) - i + 1):
            arr = words[j - 1:j + i]
            seq.append((arr[:-1], arr[-1]))
    return seq

def build_vocabulary(docs):
    """Collect every distinct token, plus 'UNK' for padding and unknown words."""
    vocabulary = []
    for doc in docs:
        for w in nltk.word_tokenize(doc):
            if w not in vocabulary:
                vocabulary.append(w)
    vocabulary.append('UNK')
    return vocabulary
```
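%% Cell type:markdown tags:
As a quick sanity check (not part of the original notebook), the pair generation can be inspected on a single question; each context grows by one word and the last word becomes the prediction target:
%% Cell type:code tags:
``` python
# Hypothetical sanity check: list the (context, next word) pairs for one question.
for context, target in get_all_possible_sequences("Is Antwerp a city?"):
    print(context, '->', target)
```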
%% Cell type:code id:edb54d0d tags:
``` python
docs = []
for row in dataset.split(","):
    docs.append(row.lower())
# note: the trailing comma leaves an empty final chunk, which tokenizes to nothing

lst = []
for doc in docs:
    tmp_lst = get_all_possible_sequences(doc)
    lst = lst + tmp_lst

vocabulary = build_vocabulary(docs)
id2word = {idx: w for (idx, w) in enumerate(vocabulary)}
word2id = {w: idx for (idx, w) in enumerate(vocabulary)}

def seq2id(arr):
    # Map each token to its id; fall back to 'UNK' for out-of-vocabulary words.
    return torch.tensor([word2id.get(i, word2id['UNK']) for i in arr])

def get_max_seq():
    # Length of the longest context among the training pairs.
    return max(len(i[0]) for i in lst)

MAX_SEQ_LEN = get_max_seq()

def get_padded_x(data):
    # Right-pad the context with the 'UNK' id up to MAX_SEQ_LEN.
    new_data = F.pad(input=data.view(1, -1), pad=(0, MAX_SEQ_LEN - data.shape[0], 0, 0), mode='constant', value=word2id['UNK'])
    return new_data

def get_xy_vector(arr):
    x = seq2id(arr[0])
    y = seq2id([arr[1]])
    return x, y
```
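%% Cell type:markdown tags:
A minimal sketch (not in the original notebook) of what one training example looks like after vectorisation and padding:
%% Cell type:code tags:
``` python
# Hypothetical example: vectorise the first (context, next word) pair.
x, y = get_xy_vector(lst[0])
print('padded context ids:', get_padded_x(x))  # right-padded with the 'UNK' id
print('target id:', y)
```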
%% Cell type:code id:431fd558 tags:
``` python
class NextWordModel(nn.Module):
    """Prediction of the next word from a padded MAX_SEQ_LEN sequence."""

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(NextWordModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # The padded context is flattened into a single GRU input vector.
        self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        gru_out, _ = self.gru(embeds.view(1, 1, -1))
        x = self.linear(gru_out.view(1, -1))
        return x
```
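%% Cell type:markdown tags:
A quick shape check (hypothetical, not in the original notebook): one padded context goes in, one row of vocabulary logits comes out:
%% Cell type:code tags:
``` python
# Hypothetical shape check with throwaway weights.
demo_model = NextWordModel(10, len(vocabulary), len(vocabulary))
demo_x = get_padded_x(seq2id(['what', 'is']))
print(demo_model(demo_x).shape)  # torch.Size([1, len(vocabulary)])
```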
%% Cell type:code id:d891422e tags:
``` python
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
print(f'Running on {dev}')

# device to which the model and tensors are moved
device = torch.device(dev)
```
%% Output
Running on cuda:0
%% Cell type:code id:7c231c92 tags:
``` python
EMBEDDING_DIM = 10
NO_OF_EPOCHS = 300
HIDDEN_DIM = len(vocabulary)

model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary))
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.to(device)

for epoch in range(NO_OF_EPOCHS):
    running_loss = 0.0
    i = 0
    for data in lst:
        model.zero_grad()
        x, y = get_xy_vector(data)
        # convert to max seq length with padding
        x = get_padded_x(x)
        x = x.to(device)
        y = y.to(device)
        predicted = model(x)
        loss = loss_function(predicted, y)
        loss.backward()
        optimizer.step()
        # .item() detaches the loss so the graph is not kept across iterations
        running_loss += loss.item()
        i += 1
        if i % 100 == 0:
            #print(f'Loss at iteration {i} and epoch {epoch} is {running_loss / 100}')
            running_loss = 0
print('Finished')
```
%% Output
Finished
%% Cell type:code id:9d0ff01d tags:
``` python
with torch.no_grad():
    print('Type something here . . .')
    while True:
        inp = input("")
        inp = inp.strip()
        if inp == "q":
            break
        tokens = nltk.word_tokenize(inp.lower())
        x = seq2id(tokens)
        x = get_padded_x(x)
        x = x.to(device)
        predicted = model(x)
        predicted = predicted[0].cpu().numpy()
        print(f'Answer: {inp} {id2word[np.argmax(predicted)]}')
```
%% Output
Type something here . . .
what
Answer: what is
what is
Answer: what is the
what is the
Answer: what is the population
what is the population
Answer: what is the population of
what is the population of
Answer: what is the population of the
what is the population of the
Answer: what is the population of the city
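%% Cell type:markdown tags:
The loop above predicts a single token per prompt. A minimal sketch (not part of the original notebook) that applies the same prediction greedily to extend a prompt by several tokens; `greedy_continue` is a hypothetical helper built from the functions defined above:
%% Cell type:code tags:
``` python
def greedy_continue(prompt, n_tokens=5):
    # Hypothetical helper: repeatedly append the argmax next token.
    tokens = nltk.word_tokenize(prompt.lower())
    with torch.no_grad():
        for _ in range(n_tokens):
            # Keep at most MAX_SEQ_LEN tokens so the padding amount stays non-negative.
            x = get_padded_x(seq2id(tokens[-MAX_SEQ_LEN:])).to(device)
            logits = model(x)[0].cpu().numpy()
            tokens.append(id2word[np.argmax(logits)])
    return ' '.join(tokens)

print(greedy_continue('what is'))
```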