"""Next-Token-Prediction.

Train a small GRU-based model on a handful of questions about Antwerp and
then interactively predict the next word for a typed prefix.

Based on the blog post "Predicting Next Word — NLP & Deep Learning":
https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671
"""
# Script form of the notebook Next-Token-Prediction.ipynb.

import nltk
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Tokenizer data for the nltk.word_tokenize calls below.
nltk.download('punkt')

# Training corpus: one question per row, rows separated by commas.
dataset = """
    Is Antwerp a city?,
    Is Antwerp a municipality?,
    Is Antwerp in Belgium?,
    What is Antwerp?,
    What is the population of the city of Antwerp?,
    Where is the city of Antwerp?,
    Why is Antwerp important to fashion?,
    Antwerp is to the east of what river?,
    How many municipalities does Antwerp have?,
"""


def get_all_possible_sequences(text):
    """Return every (context, next_word) training pair found in *text*.

    The text is tokenized with ``nltk.word_tokenize``. For every contiguous
    window of at least two tokens, the last token is the target and the
    tokens before it are the context. Pairs are ordered by context length
    first and by start position second.

    Returns:
        A list of ``(list_of_context_tokens, target_token)`` tuples; empty
        when the text has fewer than two tokens.
    """
    words = nltk.word_tokenize(text)
    seq = []
    for context_len in range(1, len(words)):
        for start in range(len(words) - context_len):
            window = words[start:start + context_len + 1]
            seq.append((window[:-1], window[-1]))
    return seq


def build_vocabulary(docs):
    """Return the distinct tokens of *docs* in first-seen order, plus 'UNK'.

    'UNK' is appended once (it is not duplicated if a document already
    contains that exact token).
    """
    # A dict keeps insertion order and gives O(1) membership checks.
    seen = {}
    for doc in docs:
        for w in nltk.word_tokenize(doc):
            seen.setdefault(w, None)
    seen.setdefault('UNK', None)
    return list(seen)


# Lower-cased rows; whitespace-only rows (e.g. after the last comma) carry
# no tokens and are dropped.
docs = [row.lower() for row in dataset.split(",") if row.strip()]

# All (context, next_word) pairs over the whole corpus.
lst = []
for doc in docs:
    lst.extend(get_all_possible_sequences(doc))

vocabulary = build_vocabulary(docs)
id2word = dict(enumerate(vocabulary))
word2id = {w: idx for idx, w in enumerate(vocabulary)}


def seq2id(arr):
    """Convert a sequence of tokens into a 1-D tensor of vocabulary ids.

    Tokens that are not in the vocabulary map to the id of 'UNK' instead of
    raising ``KeyError``.
    """
    unk_id = word2id['UNK']
    # Explicit integer dtype: an empty list would otherwise produce a float
    # tensor, which cannot be used as embedding indices.
    return torch.tensor([word2id.get(w, unk_id) for w in arr], dtype=torch.long)


def get_max_seq():
    """Return the length of the longest context in ``lst`` (0 if empty)."""
    return max((len(context) for context, _ in lst), default=0)


MAX_SEQ_LEN = get_max_seq()


def get_padded_x(data):
    """Return *data* as a ``(1, MAX_SEQ_LEN)`` tensor, right-padded with 'UNK'.

    *data* is a 1-D tensor of token ids. If it holds more than MAX_SEQ_LEN
    ids, only the last MAX_SEQ_LEN (the most recent words) are kept so the
    pad width never becomes negative.
    """
    overflow = data.shape[0] - MAX_SEQ_LEN
    if overflow > 0:
        data = data[overflow:]
    # NOTE: 'UNK' doubles as the padding token.
    new_data = F.pad(
        input=data.view(1, -1),
        pad=(0, MAX_SEQ_LEN - data.shape[0], 0, 0),
        mode='constant',
        value=word2id['UNK'],
    )
    return new_data


def get_xy_vector(arr):
    """Turn a ``(context_tokens, target_token)`` pair into id tensors.

    Returns:
        ``(x, y)`` where ``x`` holds the context ids and ``y`` is a
        one-element tensor with the target id.
    """
    x = seq2id(arr[0])
    y = seq2id([arr[1]])
    return x, y


class NextWordModel(nn.Module):
    """ Prediction of Next word based on the MAX_SEQ_LEN Sequence """

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # The GRU receives all MAX_SEQ_LEN embeddings concatenated into a
        # single input vector (see forward).
        self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, sentence):
        """Return a ``(1, vocab_size)`` tensor of scores for the next word.

        *sentence* is one padded id sequence as produced by
        ``get_padded_x``.
        """
        embeds = self.word_embeddings(sentence)
        # Flatten the embeddings into one time step for a batch of one.
        gru_out, _ = self.gru(embeds.view(1, 1, -1))
        x = self.linear(gru_out.view(1, -1))
        return x


if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
print(f'Running on {dev}')
device = torch.device(dev)

EMBEDDING_DIM = 10
NO_OF_EPOCHS = 300
HIDDEN_DIM = len(vocabulary)
model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary))
# Move the model to its device before the optimizer captures its parameters.
model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# The padded inputs and targets do not change between epochs, so build them
# and move them to the device once.
training_pairs = []
for data in lst:
    x, y = get_xy_vector(data)
    training_pairs.append((get_padded_x(x).to(device), y.to(device)))

for epoch in range(NO_OF_EPOCHS):
    # One optimizer step per (context, next_word) pair.
    for x, y in training_pairs:
        optimizer.zero_grad()
        predicted = model(x)
        loss = loss_function(predicted, y)
        loss.backward()
        optimizer.step()
print('Finished')

# Interactive loop: read a prefix, print it with the predicted next word.
# Enter "q" to quit.
model.eval()
with torch.no_grad():
    print('Type something here . . .')
    while True:
        inp = input("").strip()
        if inp == "q":
            break

        tokens = nltk.word_tokenize(inp.lower())
        x = get_padded_x(seq2id(tokens)).to(device)
        predicted = model(x)

        # Highest-scoring vocabulary id for the single input row.
        best_id = int(predicted[0].argmax())

        print(f'Answer: {inp} {id2word[best_id]} ')