Kreinsen, Moritz / Binder Test / Commits / 5a791efe

Commit 5a791efe, authored 2 years ago by Kreinsen, Moritz

Update

parent a78c22b9

1 changed file: Next-Token-Prediction.ipynb (new file, 100644) — 395 additions, 0 deletions
%% Cell type:markdown id:b39e9d1f-05b5-43b4-b50a-036ae88657cd tags:

# Next-Token-Prediction

This is based on the following blog post:

* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671
%% Cell type:code id:9c35813f-ccb3-42b3-a7e1-6069dece172f tags:

``` python
import nltk
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```
%% Cell type:code id:b482fac0-b576-472e-ab5c-df951d0b2404 tags:

``` python
nltk.download('punkt')
```
%% Output
[nltk_data] Downloading package punkt to /home/container/nltk_data...
[nltk_data] Package punkt is already up-to-date!
True
%% Cell type:code id:f6a376c4-bd61-44d0-badc-8c451394b627 tags:

``` python
dataset = """
    Is Antwerp a city?,
    Is Antwerp a municipality?,
    Is Antwerp in Belgium?,
    What is Antwerp?,
    What is the population of the city of Antwerp?,
    Where is the city of Antwerp?,
    Why is Antwerp important to fashion?,
    Antwerp is to the east of what river?,
    How many municipalities does Antwerp have?,
"""
```
%% Cell type:code id:c5831581-84e1-41d1-ae4f-b87e5bab365c tags:

``` python
def get_all_possible_sequences(text):
    # Enumerate every contiguous window of the tokenized text and split it
    # into a (prefix, next word) training pair.
    seq = []
    words = nltk.word_tokenize(text)
    total_words = len(words)
    for i in range(1, total_words):
        for j in range(1, len(words) - i + 1):
            arr = words[j-1:j+i]
            seq.append((arr[:-1], arr[-1]))
    return seq

def build_vocabulary(docs):
    # Collect every distinct token, plus an 'UNK' entry used later for padding.
    vocabulary = []
    for doc in docs:
        for w in nltk.word_tokenize(doc):
            if w not in vocabulary:
                vocabulary.append(w)
    vocabulary.append('UNK')
    return vocabulary
```
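
To make the windowing concrete, here is a small illustration (not part of the commit) of the (prefix, next-word) pairs the function above produces for one tokenized question:

``` python
# Illustration only: assumes the cell above has been run.
pairs = get_all_possible_sequences("is antwerp a city ?")
for prefix, target in pairs[:4]:
    print(prefix, '->', target)
# ['is'] -> antwerp
# ['antwerp'] -> a
# ['a'] -> city
# ['city'] -> ?
```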
%% Cell type:code id:2171ca5a-ca9a-4e33-86d0-8fa4b612f8a0 tags:

``` python
docs = []
for row in dataset.split(","):
    docs.append(row.lower())

lst = []
for doc in docs:
    tmp_lst = get_all_possible_sequences(doc)
    lst = lst + tmp_lst

vocabulary = build_vocabulary(docs)
id2word = {idx: w for (idx, w) in enumerate(vocabulary)}
word2id = {w: idx for (idx, w) in enumerate(vocabulary)}

def seq2id(arr):
    return torch.tensor([word2id[i] for i in arr])

def get_max_seq():
    # The prefix lengths in lst run 1..max without gaps, so the number of
    # distinct lengths equals the longest prefix length.
    return len(list(set([len(i[0]) for i in lst])))

MAX_SEQ_LEN = get_max_seq()

def get_padded_x(data):
    # Right-pad the id sequence with the 'UNK' id up to MAX_SEQ_LEN.
    new_data = F.pad(input=data.view(1, -1), pad=(0, MAX_SEQ_LEN - data.shape[0], 0, 0), mode='constant', value=word2id['UNK'])
    return new_data

def get_xy_vector(arr):
    x = seq2id(arr[0])
    y = seq2id([arr[1]])
    return x, y
```
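
A quick illustration (not in the commit) of what the padding helper does, assuming the cells above have been run:

``` python
# Illustration only: a two-token prefix is right-padded with the 'UNK' id
# until it reaches MAX_SEQ_LEN, giving a (1, MAX_SEQ_LEN) tensor of token ids.
x = seq2id(['what', 'is'])      # shape (2,)
padded = get_padded_x(x)        # shape (1, MAX_SEQ_LEN)
print(padded.shape)
```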
%% Cell type:code id:bb8929f7-85be-4740-be07-7dd6b4ed3086 tags:

``` python
class NextWordModel(nn.Module):
    """ Prediction of the next word based on the MAX_SEQ_LEN sequence """
    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super(NextWordModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # The whole padded prefix is flattened into one GRU input of size
        # embedding_dim * MAX_SEQ_LEN.
        self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        gru_out, _ = self.gru(embeds.view(1, 1, -1))
        x = self.linear(gru_out.view(1, -1))
        return x
```
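
Because the embedded prefix is flattened into a single vector, the GRU sees a sequence of length one, so its recurrence is effectively unused and the network behaves like a feed-forward scorer. A shape walkthrough on a throwaway CPU instance (illustration only; the variable names here are ours):

``` python
# Illustration only: trace a padded prefix through a throwaway CPU model
# built with the helpers defined above.
m = NextWordModel(10, len(vocabulary), len(vocabulary))
x = get_padded_x(seq2id(['what', 'is']))  # (1, MAX_SEQ_LEN) token ids
emb = m.word_embeddings(x)                # (1, MAX_SEQ_LEN, 10) embeddings
flat = emb.view(1, 1, -1)                 # (1, 1, 10 * MAX_SEQ_LEN): one time step
out, _ = m.gru(flat)                      # (1, 1, hidden_dim)
logits = m.linear(out.view(1, -1))        # (1, vocab_size) next-token scores
print(logits.shape)
```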
%% Cell type:code id:08b6f539-75a0-4a18-b462-165ca44387d3 tags:

``` python
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
print(f'Running on {dev}')
# device that the model and all tensors will be moved to
device = torch.device(dev)
```
%% Output
Running on cuda:0
%% Cell type:code id:cf4f67b2-e0a4-4b57-abd9-4a2e40898511 tags:

``` python
EMBEDDING_DIM = 10
NO_OF_EPOCHS = 300
HIDDEN_DIM = len(vocabulary)

model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary))
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
model.to(device)

for epoch in range(NO_OF_EPOCHS):
    running_loss = 0.0
    i = 0
    for data in lst:
        model.zero_grad()
        x, y = get_xy_vector(data)
        # convert to max seq length with padding
        x = get_padded_x(x)

        x = x.to(device)
        y = y.to(device)

        predicted = model(x)

        loss = loss_function(predicted, y)
        loss.backward()
        optimizer.step()

        # use .item() so the running total does not keep the autograd graph alive
        running_loss += loss.item()
        i += 1
        if i % 100 == 0:
            #print(f'Loss at iteration {i} and epoch {epoch} is {running_loss / 100}')
            running_loss = 0
print('Finished')
```
%% Output
Finished
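
A quick sanity check after training (sketch, not part of the commit): query the trained model for the most likely continuation of a prefix that occurs in the training data.

``` python
# Sketch only: assumes the training cell above has finished.
with torch.no_grad():
    x = get_padded_x(seq2id(['is', 'antwerp'])).to(device)
    logits = model(x)
    print(id2word[int(logits.argmax())])  # e.g. 'a' or 'in', whichever was learned
```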
%% Cell type:code id:84e84dfc-5b7a-47c2-96cd-304c5630a002 tags:

``` python
with torch.no_grad():
    print('Type something here . . .')
    while True:
        inp = input("")
        inp = inp.strip()
        if inp == "q":
            break

        tokens = nltk.word_tokenize(inp.lower())
        # map words missing from the vocabulary to 'UNK' instead of raising a KeyError
        x = seq2id([t if t in word2id else 'UNK' for t in tokens])
        x = get_padded_x(x)

        x = x.to(device)
        predicted = model(x)

        predicted = predicted[0].cpu().numpy()

        print(f'Answer: {inp} {id2word[np.argmax(predicted)]} ')
```
%% Output
Type something here . . .
what
Answer :what is
what is
Answer :what is the
what is the
Answer :what is the population
what is the population
Answer :what is the population of
what is the population of
Answer :what is the population of the
what is the population of the
Answer :what is the population of the city
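
A natural extension (sketch, not part of the commit): feed each predicted token back into the prefix to generate several tokens greedily. The `generate` helper below is ours; it assumes the prefix stays within MAX_SEQ_LEN tokens and that its words are in the vocabulary:

``` python
# Sketch only: greedy multi-token generation with the trained model.
def generate(prefix, steps=5):
    tokens = nltk.word_tokenize(prefix.lower())
    with torch.no_grad():
        for _ in range(steps):
            x = get_padded_x(seq2id(tokens)).to(device)
            next_id = int(model(x)[0].argmax())
            tokens.append(id2word[next_id])
    return ' '.join(tokens)

print(generate('what is'))  # e.g. 'what is the population of the city'
```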
%% Cell type:code id:660e6502-7548-47df-a645-a1e2d33178f6 tags:

``` python
```