From 5a791efed1fc1537176a765647b5bd76a76f4e1c Mon Sep 17 00:00:00 2001
From: Moritz Kreinsen <moritz.kreinsen@uni-hamburg.de>
Date: Mon, 6 Feb 2023 11:18:21 +0000
Subject: [PATCH] Update

---
 Next-Token-Prediction.ipynb | 395 ++++++++++++++++++++++++++++++++++++
 1 file changed, 395 insertions(+)
 create mode 100644 Next-Token-Prediction.ipynb

diff --git a/Next-Token-Prediction.ipynb b/Next-Token-Prediction.ipynb
new file mode 100644
index 0000000..fcfd817
--- /dev/null
+++ b/Next-Token-Prediction.ipynb
@@ -0,0 +1,395 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "b39e9d1f-05b5-43b4-b50a-036ae88657cd",
+   "metadata": {},
+   "source": [
+    "# Next-Token-Prediction\n",
+    "This is based on the following blog posts: \n",
+    "* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9c35813f-ccb3-42b3-a7e1-6069dece172f",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import nltk\n",
+    "import pandas as pd\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "b482fac0-b576-472e-ab5c-df951d0b2404",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[nltk_data] Downloading package punkt to /home/container/nltk_data...\n",
+      "[nltk_data]   Package punkt is already up-to-date!\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "nltk.download('punkt')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "f6a376c4-bd61-44d0-badc-8c451394b627",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "dataset = \"\"\"\n",
+    "    Is Antwerp a city?,\n",
+    "    Is Antwerp a municipality?,\n",
+    "    Is Antwerp in Belgium?,\n",
+    "    What is Antwerp?,\n",
+    "    What is the population of the city of Antwerp?,\n",
+    "    Where is the city of Antwerp?,\n",
+    "    Why is Antwerp important to fashion?,\n",
+    "    Antwerp is to the east of what river?,\n",
+    "    How many municipalities does Antwerp have?,\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "c5831581-84e1-41d1-ae4f-b87e5bab365c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def get_all_possible_sequences(text):\n",
+    "    seq = []\n",
+    "    words = nltk.word_tokenize(text)\n",
+    "    total_words = len(words)\n",
+    "    for i in range(1, total_words):\n",
+    "        for j in range(1, len(words)-i+1):\n",
+    "            arr = words[j-1:j+i]\n",
+    "            seq.append((arr[:-1], arr[-1]))\n",
+    "    return seq\n",
+    "def build_vocabulary(docs):\n",
+    "    vocabulary = []\n",
+    "    for doc in docs:\n",
+    "        for w in nltk.word_tokenize(doc):\n",
+    "            if w not in vocabulary:\n",
+    "                vocabulary.append(w)\n",
+    "    vocabulary.append('UNK')\n",
+    "    return vocabulary"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "2171ca5a-ca9a-4e33-86d0-8fa4b612f8a0",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "docs = []\n",
+    "for row in dataset.split(\",\"):\n",
+    "    docs.append(row.lower())\n",
+    "\n",
+    "lst = []\n",
+    "for doc in docs:\n",
+    "    tmp_lst = get_all_possible_sequences(doc)\n",
+    "    lst = lst + tmp_lst\n",
+    "\n",
+    "vocabulary = build_vocabulary(docs)\n",
+    "id2word = {idx: w for (idx, w) in enumerate(vocabulary)}\n",
+    "word2id = {w: idx for (idx, w) in enumerate(vocabulary)}\n",
+    "def seq2id(arr):\n",
+    "    return torch.tensor([word2id[i] for i in arr])\n",
+    "def get_max_seq():\n",
+    "    return len(list(set([len(i[0]) for i in lst])))\n",
+    "MAX_SEQ_LEN = get_max_seq()\n",
+    "def get_padded_x(data):\n",
+    "    new_data = F.pad(input=data.view(1,-1), pad=(0, MAX_SEQ_LEN-data.shape[0], 0, 0), mode='constant', value=word2id['UNK'])\n",
+    "    return new_data\n",
+    "def get_xy_vector(arr):\n",
+    "    x = seq2id(arr[0])\n",
+    "    y = seq2id([arr[1]])\n",
+    "    return x, y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "bb8929f7-85be-4740-be07-7dd6b4ed3086",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "class NextWordModel(nn.Module):\n",
+    "    \"\"\" Prediction of Next word based on the MAX_SEQ_LEN Sequence \"\"\"\n",
+    "    def __init__(self, embedding_dim, hidden_dim, vocab_size):\n",
+    "        super(NextWordModel, self).__init__()\n",
+    "        self.hidden_dim = hidden_dim\n",
+    "        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
+    "        self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)\n",
+    "        self.linear = nn.Linear(hidden_dim, vocab_size)\n",
+    "\n",
+    "    def forward(self, sentence):\n",
+    "        embeds = self.word_embeddings(sentence)\n",
+    "        lstm_out, _ = self.gru(embeds.view(1, 1, -1))\n",
+    "        x = self.linear(lstm_out.view(1, -1))\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "08b6f539-75a0-4a18-b462-165ca44387d3",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running on cuda:0\n"
+     ]
+    }
+   ],
+   "source": [
+    "if torch.cuda.is_available():\n",
+    "  dev = \"cuda:0\"\n",
+    "else:\n",
+    "  dev = \"cpu\"\n",
+    "print(f'Running on {dev}')\n",
+    "# set the model to be copied on GPU\n",
+    "device = torch.device(dev)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "cf4f67b2-e0a4-4b57-abd9-4a2e40898511",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Finished\n"
+     ]
+    }
+   ],
+   "source": [
+    "EMBEDDING_DIM = 10\n",
+    "NO_OF_EPOCHS = 300\n",
+    "HIDDEN_DIM = len(vocabulary)\n",
+    "model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary))\n",
+    "loss_function = nn.CrossEntropyLoss()\n",
+    "optimizer = optim.SGD(model.parameters(), lr=0.1)\n",
+    "model.to(device)\n",
+    "for epoch in range(NO_OF_EPOCHS):\n",
+    "    running_loss = 0.0\n",
+    "    i = 0\n",
+    "    for data in lst:\n",
+    "        model.zero_grad()\n",
+    "        x, y = get_xy_vector(data)\n",
+    "# convert to max seq length with padding\n",
+    "        x = get_padded_x(x)\n",
+    "\n",
+    "        x = x.to(device)\n",
+    "        y = y.to(device)\n",
+    "\n",
+    "        predicted = model(x)\n",
+    "\n",
+    "        loss = loss_function(predicted, y)\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "\n",
+    "        running_loss += loss\n",
+    "        i += 1\n",
+    "        if i % 100 == 0:\n",
+    "            #print(f'Loss at iteration {i} and epoch {epoch} is {running_loss / 100}')\n",
+    "            running_loss = 0\n",
+    "print('Finished')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "84e84dfc-5b7a-47c2-96cd-304c5630a002",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Type something here . . .\n"
+     ]
+    },
+    {
+     "name": "stdin",
+     "output_type": "stream",
+     "text": [
+      " what\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Answer :what is \n"
+     ]
+    },
+    {
+     "name": "stdin",
+     "output_type": "stream",
+     "text": [
+      " what is\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Answer :what is the \n"
+     ]
+    },
+    {
+     "name": "stdin",
+     "output_type": "stream",
+     "text": [
+      " what is the\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Answer :what is the population \n"
+     ]
+    },
+    {
+     "name": "stdin",
+     "output_type": "stream",
+     "text": [
+      " what is the population\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Answer :what is the population of \n"
+     ]
+    },
+    {
+     "name": "stdin",
+     "output_type": "stream",
+     "text": [
+      " what is the population of\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Answer :what is the population of the \n"
+     ]
+    },
+    {
+     "name": "stdin",
+     "output_type": "stream",
+     "text": [
+      " what is the population of the\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Answer :what is the population of the city \n"
+     ]
+    }
+   ],
+   "source": [
+    "with torch.no_grad():\n",
+    "    print('Type something here . . .')\n",
+    "    while True:\n",
+    "        inp = input(\"\")\n",
+    "        inp = inp.strip()\n",
+    "        if inp == \"q\":\n",
+    "            break\n",
+    "\n",
+    "        tokens = nltk.word_tokenize(inp.lower())\n",
+    "        x = seq2id(tokens)\n",
+    "        x = get_padded_x(x)\n",
+    "\n",
+    "        x = x.to(device)\n",
+    "        predicted = model(x).to(device)\n",
+    "\n",
+    "        predicted = predicted[0].cpu().numpy()\n",
+    "\n",
+    "        print(f'Answer: {inp} {id2word[np.argmax(predicted)]} ')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "660e6502-7548-47df-a645-a1e2d33178f6",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
-- 
GitLab