{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b39e9d1f-05b5-43b4-b50a-036ae88657cd",
   "metadata": {},
   "source": [
    "# Next-Token Prediction\n",
    "\n",
    "Trains a small GRU-based model on a handful of questions about Antwerp to predict the next word from the words that precede it, then lets you try it interactively.\n",
    "\n",
    "This is based on the following blog post:\n",
    "* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9c35813f-ccb3-42b3-a7e1-6069dece172f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "import pandas as pd\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b482fac0-b576-472e-ab5c-df951d0b2404",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/container/nltk_data...\n",
      "[nltk_data] Package punkt is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f6a376c4-bd61-44d0-badc-8c451394b627",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dataset = \"\"\"\n",
    " Is Antwerp a city?,\n",
    " Is Antwerp a municipality?,\n",
    " Is Antwerp in Belgium?,\n",
    " What is Antwerp?,\n",
    " What is the population of the city of Antwerp?,\n",
    " Where is the city of Antwerp?,\n",
    " Why is Antwerp important to fashion?,\n",
    " Antwerp is to the east of what river?,\n",
    " How many municipalities does Antwerp have?,\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c5831581-84e1-41d1-ae4f-b87e5bab365c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_all_possible_sequences(text):\n",
    "    \"\"\"Return every (context, next_word) training pair found in `text`.\n",
    "\n",
    "    `text` is tokenized with `nltk.word_tokenize`. For every context length\n",
    "    k = 1 .. n-1 (n = number of tokens) and every start position, the k\n",
    "    consecutive tokens form the context and the token right after them is the\n",
    "    target. Pairs are ordered by context length, then by start position.\n",
    "    \"\"\"\n",
    "    words = nltk.word_tokenize(text)\n",
    "    seq = []\n",
    "    for ctx_len in range(1, len(words)):\n",
    "        for start in range(len(words) - ctx_len):\n",
    "            seq.append((words[start:start + ctx_len], words[start + ctx_len]))\n",
    "    return seq\n",
    "\n",
    "\n",
    "def build_vocabulary(docs):\n",
    "    \"\"\"Return the distinct tokens of `docs` in first-seen order, followed by 'UNK'.\n",
    "\n",
    "    'UNK' is the padding id used by `get_padded_x` and the fallback id that\n",
    "    `seq2id` assigns to out-of-vocabulary words.\n",
    "    \"\"\"\n",
    "    # dict.fromkeys de-duplicates while keeping first-occurrence order in O(n),\n",
    "    # instead of scanning the list with `w not in vocabulary` for every token.\n",
    "    vocabulary = list(dict.fromkeys(w for doc in docs for w in nltk.word_tokenize(doc)))\n",
    "    vocabulary.append('UNK')\n",
    "    return vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2171ca5a-ca9a-4e33-86d0-8fa4b612f8a0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One document per comma-separated question, lower-cased; the empty chunk after\n",
    "# the trailing comma is dropped.\n",
    "docs = [row.lower() for row in dataset.split(\",\") if row.strip()]\n",
    "\n",
    "# All (context, next_word) training pairs, question by question.\n",
    "lst = [pair for doc in docs for pair in get_all_possible_sequences(doc)]\n",
    "\n",
    "vocabulary = build_vocabulary(docs)\n",
    "id2word = {idx: w for (idx, w) in enumerate(vocabulary)}\n",
    "word2id = {w: idx for (idx, w) in enumerate(vocabulary)}\n",
    "\n",
    "\n",
    "def seq2id(arr):\n",
    "    \"\"\"Map a list of tokens to a 1-D int64 tensor of vocabulary ids.\n",
    "\n",
    "    Out-of-vocabulary tokens map to the 'UNK' id instead of raising KeyError,\n",
    "    and an empty list gives an empty int64 tensor (not a float one).\n",
    "    \"\"\"\n",
    "    unk_id = word2id['UNK']\n",
    "    return torch.tensor([word2id.get(w, unk_id) for w in arr], dtype=torch.long)\n",
    "\n",
    "\n",
    "def get_max_seq():\n",
    "    \"\"\"Return the length of the longest training context in `lst`.\"\"\"\n",
    "    return max(len(context) for context, _ in lst)\n",
    "\n",
    "\n",
    "MAX_SEQ_LEN = get_max_seq()\n",
    "\n",
    "\n",
    "def get_padded_x(data):\n",
    "    \"\"\"Return a 1-D id tensor as shape (1, MAX_SEQ_LEN), right-padded with the 'UNK' id.\n",
    "\n",
    "    Longer inputs keep only their last MAX_SEQ_LEN ids, i.e. the words closest\n",
    "    to the one being predicted.\n",
    "    \"\"\"\n",
    "    data = data[-MAX_SEQ_LEN:]\n",
    "    return F.pad(input=data.unsqueeze(0), pad=(0, MAX_SEQ_LEN - data.shape[0]),\n",
    "                 mode='constant', value=word2id['UNK'])\n",
    "\n",
    "\n",
    "def get_xy_vector(arr):\n",
    "    \"\"\"Convert a (context_tokens, next_token) pair into (context_ids, target_ids) tensors.\"\"\"\n",
    "    context, target = arr\n",
    "    return seq2id(context), seq2id([target])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bb8929f7-85be-4740-be07-7dd6b4ed3086",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class NextWordModel(nn.Module):\n",
    "    \"\"\"Predict the next word from a padded context of `max_seq_len` token ids.\n",
    "\n",
    "    The context's word embeddings are concatenated into one vector and fed to\n",
    "    a GRU as a single time step; a linear layer maps the GRU output to one\n",
    "    logit per vocabulary word.\n",
    "\n",
    "    Args:\n",
    "        embedding_dim: size of each word embedding.\n",
    "        hidden_dim: size of the GRU hidden state.\n",
    "        vocab_size: number of words (embedding rows and output logits).\n",
    "        max_seq_len: context length the model expects; defaults to the\n",
    "            notebook-level MAX_SEQ_LEN.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len=None):\n",
    "        super().__init__()\n",
    "        if max_seq_len is None:\n",
    "            max_seq_len = MAX_SEQ_LEN\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
    "        self.gru = nn.GRU(embedding_dim * max_seq_len, hidden_dim)\n",
    "        self.linear = nn.Linear(hidden_dim, vocab_size)\n",
    "\n",
    "    def forward(self, sentence):\n",
    "        \"\"\"Return logits of shape (1, vocab_size) for a (1, max_seq_len) id tensor.\"\"\"\n",
    "        embeds = self.word_embeddings(sentence)\n",
    "        # Flatten to (seq_len=1, batch=1, max_seq_len * embedding_dim): the whole\n",
    "        # context is one GRU time step, so only a batch of one is supported.\n",
    "        gru_out, _ = self.gru(embeds.view(1, 1, -1))\n",
    "        return self.linear(gru_out.view(1, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "08b6f539-75a0-4a18-b462-165ca44387d3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Use the first GPU when one is available, otherwise fall back to the CPU.\n",
    "dev = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f'Running on {dev}')\n",
    "device = torch.device(dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cf4f67b2-e0a4-4b57-abd9-4a2e40898511",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished\n"
     ]
    }
   ],
   "source": [
    "SEED = 42\n",
    "EMBEDDING_DIM = 10\n",
    "NO_OF_EPOCHS = 300\n",
    "HIDDEN_DIM = len(vocabulary)\n",
    "LEARNING_RATE = 0.1\n",
    "\n",
    "torch.manual_seed(SEED)  # reproducible weight initialisation\n",
    "# Move the model to the device *before* building the optimizer over its parameters.\n",
    "model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary)).to(device)\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Encode, pad and move every training pair once rather than on every epoch.\n",
    "train_data = [\n",
    "    (get_padded_x(x).to(device), y.to(device))\n",
    "    for x, y in map(get_xy_vector, lst)\n",
    "]\n",
    "\n",
    "loss_history = []  # mean training loss of each epoch\n",
    "model.train()\n",
    "for epoch in range(NO_OF_EPOCHS):\n",
    "    running_loss = 0.0\n",
    "    for x, y in train_data:\n",
    "        optimizer.zero_grad()\n",
    "        loss = loss_function(model(x), y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # .item() accumulates a plain float, so no autograd graph is kept alive.\n",
    "        running_loss += loss.item()\n",
    "    loss_history.append(running_loss / len(train_data))\n",
    "print('Finished')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7c4e2b1-5d3f-4e8a-9b6c-1f2e3d4c5b6a",
   "metadata": {},
   "source": [
    "## Try it\n",
    "\n",
    "Type the beginning of a question and press Enter to see it extended with the predicted next word. Type `q` to stop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84e84dfc-5b7a-47c2-96cd-304c5630a002",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type something here . . .\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer :what is \n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what is\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer :what is the \n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what is the\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer :what is the population \n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what is the population\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer :what is the population of \n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what is the population of\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer :what is the population of the \n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what is the population of the\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer :what is the population of the city \n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    print('Type something here . . .')\n",
    "    while True:\n",
    "        inp = input(\"\").strip()\n",
    "        if inp == \"q\":\n",
    "            break\n",
    "        if not inp:\n",
    "            continue  # nothing to predict from\n",
    "\n",
    "        tokens = nltk.word_tokenize(inp.lower())\n",
    "        x = get_padded_x(seq2id(tokens)).to(device)\n",
    "        # argmax over the vocabulary axis of the (1, vocab_size) logits\n",
    "        predicted_id = model(x).argmax(dim=1).item()\n",
    "\n",
    "        print(f'Answer: {inp} {id2word[predicted_id]} ')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}