{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0f901b18",
   "metadata": {},
   "source": [
    "# Next-Token-Prediction\n",
    "\n",
    "Trains a small GRU-based model on a handful of questions about Antwerp and then predicts the next word for a prompt typed by the user.\n",
    "\n",
    "This is based on the following blog posts:\n",
    "\n",
    "* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671\n",
    "* How ChatGPT Works: The Model Behind The Bot: https://towardsdatascience.com/how-chatgpt-works-the-models-behind-the-bot-1ce5fca96286"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b2471e25",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "import pandas as pd\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c0febc54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer data used by nltk.word_tokenize.\n",
    "# NOTE(review): newer NLTK releases may additionally need the 'punkt_tab' resource - confirm for your version.\n",
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8a781aa1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Training corpus: one question per row, rows separated by commas.\n",
    "dataset = \"\"\"\n",
    "    Is Antwerp a city?,\n",
    "    Is Antwerp a municipality?,\n",
    "    Is Antwerp in Belgium?,\n",
    "    What is Antwerp?,\n",
    "    What is the population of the city of Antwerp?,\n",
    "    Where is the city of Antwerp?,\n",
    "    Why is Antwerp important to fashion?,\n",
    "    Antwerp is to the east of what river?,\n",
    "    How many municipalities does Antwerp have?,\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "39255d82",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_all_possible_sequences(text):\n",
    "    \"\"\"Return every (context, next_word) pair found in ``text``.\n",
    "\n",
    "    The text is tokenized with ``nltk.word_tokenize``; every contiguous run of\n",
    "    two or more tokens is split into its leading tokens (the context) and its\n",
    "    last token (the word to predict).\n",
    "\n",
    "    Args:\n",
    "        text: One document/sentence as a string.\n",
    "\n",
    "    Returns:\n",
    "        A list of ``(context_tokens, next_token)`` tuples ordered by context\n",
    "        length, then by start position. Empty if ``text`` has fewer than two\n",
    "        tokens.\n",
    "    \"\"\"\n",
    "    words = nltk.word_tokenize(text)\n",
    "    total_words = len(words)\n",
    "    return [\n",
    "        (words[start:start + context_len], words[start + context_len])\n",
    "        for context_len in range(1, total_words)\n",
    "        for start in range(total_words - context_len)\n",
    "    ]\n",
    "\n",
    "\n",
    "def build_vocabulary(docs):\n",
    "    \"\"\"Collect the distinct tokens of ``docs`` in first-seen order.\n",
    "\n",
    "    Args:\n",
    "        docs: Iterable of document strings.\n",
    "\n",
    "    Returns:\n",
    "        A list of unique tokens that always contains the special 'UNK' entry,\n",
    "        which is used for padding and for words that were never seen.\n",
    "    \"\"\"\n",
    "    # A dict keeps insertion order and gives O(1) membership tests,\n",
    "    # unlike the repeated `w not in list` scan.\n",
    "    seen = {}\n",
    "    for doc in docs:\n",
    "        for w in nltk.word_tokenize(doc):\n",
    "            seen.setdefault(w, None)\n",
    "    # Avoid a duplicate entry if the corpus itself contains the literal token.\n",
    "    seen.setdefault('UNK', None)\n",
    "    return list(seen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "edb54d0d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Lower-case each comma-separated row; skip blank fragments such as the one\n",
    "# left behind by the trailing comma of the corpus.\n",
    "docs = [row.strip().lower() for row in dataset.split(',') if row.strip()]\n",
    "\n",
    "# All (context, next_word) training pairs of the corpus.\n",
    "lst = []\n",
    "for doc in docs:\n",
    "    lst.extend(get_all_possible_sequences(doc))\n",
    "assert lst, 'dataset produced no training sequences'\n",
    "\n",
    "vocabulary = build_vocabulary(docs)\n",
    "id2word = dict(enumerate(vocabulary))\n",
    "word2id = {w: idx for idx, w in enumerate(vocabulary)}\n",
    "\n",
    "\n",
    "def seq2id(arr):\n",
    "    \"\"\"Map a list of tokens to a 1-D ``torch.long`` tensor of vocabulary ids.\n",
    "\n",
    "    Tokens that are not in the vocabulary are mapped to the id of 'UNK'\n",
    "    instead of raising ``KeyError``.\n",
    "    \"\"\"\n",
    "    unk_id = word2id['UNK']\n",
    "    # Explicit dtype: an empty list would otherwise yield a float tensor,\n",
    "    # which nn.Embedding rejects.\n",
    "    return torch.tensor([word2id.get(w, unk_id) for w in arr], dtype=torch.long)\n",
    "\n",
    "\n",
    "def get_max_seq():\n",
    "    \"\"\"Return the length of the longest context in ``lst`` (0 if empty).\"\"\"\n",
    "    return max((len(context) for context, _ in lst), default=0)\n",
    "\n",
    "\n",
    "MAX_SEQ_LEN = get_max_seq()\n",
    "assert MAX_SEQ_LEN > 0\n",
    "\n",
    "\n",
    "def get_padded_x(data):\n",
    "    \"\"\"Return ``data`` as a ``(1, MAX_SEQ_LEN)`` tensor.\n",
    "\n",
    "    Shorter inputs are right-padded with the id of 'UNK'. Longer inputs are\n",
    "    truncated to their last ``MAX_SEQ_LEN`` ids so that the words closest to\n",
    "    the prediction point are kept.\n",
    "\n",
    "    Args:\n",
    "        data: 1-D tensor of vocabulary ids.\n",
    "    \"\"\"\n",
    "    overflow = data.shape[0] - MAX_SEQ_LEN\n",
    "    if overflow > 0:\n",
    "        data = data[overflow:]\n",
    "    return F.pad(\n",
    "        input=data.view(1, -1),\n",
    "        pad=(0, MAX_SEQ_LEN - data.shape[0], 0, 0),\n",
    "        mode='constant',\n",
    "        value=word2id['UNK'],\n",
    "    )\n",
    "\n",
    "\n",
    "def get_xy_vector(arr):\n",
    "    \"\"\"Convert one ``(context_tokens, next_token)`` pair into id tensors.\n",
    "\n",
    "    Returns:\n",
    "        ``(x, y)`` where ``x`` holds the context ids and ``y`` is a\n",
    "        one-element tensor with the id of the next token.\n",
    "    \"\"\"\n",
    "    context, target = arr\n",
    "    return seq2id(context), seq2id([target])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "431fd558",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class NextWordModel(nn.Module):\n",
    "    \"\"\"Prediction of the next word based on a MAX_SEQ_LEN-long sequence.\n",
    "\n",
    "    The embeddings of all MAX_SEQ_LEN positions are flattened into one vector\n",
    "    and passed to the GRU as a single time step; a linear layer then maps the\n",
    "    GRU output to one logit per vocabulary entry.\n",
    "\n",
    "    Args:\n",
    "        embedding_dim: Size of each word embedding.\n",
    "        hidden_dim: Size of the GRU hidden state.\n",
    "        vocab_size: Number of entries in the vocabulary.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embedding_dim, hidden_dim, vocab_size):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
    "        self.gru = nn.GRU(embedding_dim * MAX_SEQ_LEN, hidden_dim)\n",
    "        self.linear = nn.Linear(hidden_dim, vocab_size)\n",
    "\n",
    "    def forward(self, sentence):\n",
    "        \"\"\"Return logits of shape ``(1, vocab_size)`` for one padded context.\n",
    "\n",
    "        Args:\n",
    "            sentence: Tensor of MAX_SEQ_LEN vocabulary ids (one sample).\n",
    "        \"\"\"\n",
    "        embeds = self.word_embeddings(sentence)\n",
    "        # (seq_len=1, batch=1, embedding_dim * MAX_SEQ_LEN)\n",
    "        gru_out, _ = self.gru(embeds.view(1, 1, -1))\n",
    "        return self.linear(gru_out.view(1, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d891422e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'Running on {dev}')\n",
    "# Device the model and all input tensors are moved to.\n",
    "device = torch.device(dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7c231c92",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "EMBEDDING_DIM = 10\n",
    "NO_OF_EPOCHS = 300\n",
    "HIDDEN_DIM = len(vocabulary)\n",
    "LEARNING_RATE = 0.1\n",
    "LOG_EVERY_EPOCHS = 50\n",
    "SEED = 42\n",
    "\n",
    "# Fixed seed so that weight initialisation is repeatable on a re-run.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary)).to(device)\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "# Convert, pad and move every sample once instead of once per epoch.\n",
    "training_pairs = []\n",
    "for data in lst:\n",
    "    x, y = get_xy_vector(data)\n",
    "    training_pairs.append((get_padded_x(x).to(device), y.to(device)))\n",
    "\n",
    "model.train()\n",
    "for epoch in range(NO_OF_EPOCHS):\n",
    "    running_loss = 0.0\n",
    "    for x, y in training_pairs:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        predicted = model(x)\n",
    "        loss = loss_function(predicted, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # .item() yields a plain float, so no tensors are accumulated.\n",
    "        running_loss += loss.item()\n",
    "\n",
    "    if (epoch + 1) % LOG_EVERY_EPOCHS == 0:\n",
    "        mean_loss = running_loss / max(len(training_pairs), 1)\n",
    "        print(f'Epoch {epoch + 1}/{NO_OF_EPOCHS} - mean loss {mean_loss:.4f}')\n",
    "print('Finished')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d0ff01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_next_word(text):\n",
    "    \"\"\"Return the most likely next word for ``text``.\n",
    "\n",
    "    The text is lower-cased and tokenized like the training data; unknown\n",
    "    words map to 'UNK' and over-long prompts keep their last MAX_SEQ_LEN\n",
    "    tokens.\n",
    "\n",
    "    Returns:\n",
    "        The predicted word, or ``None`` if ``text`` contains no tokens.\n",
    "    \"\"\"\n",
    "    tokens = nltk.word_tokenize(text.lower())\n",
    "    if not tokens:\n",
    "        return None\n",
    "    x = get_padded_x(seq2id(tokens)).to(device)\n",
    "    with torch.no_grad():\n",
    "        logits = model(x)\n",
    "    return id2word[int(np.argmax(logits[0].cpu().numpy()))]\n",
    "\n",
    "\n",
    "# Interactive loop: waits for keyboard input, enter q to leave it.\n",
    "model.eval()\n",
    "print('Type something here . . .')\n",
    "while True:\n",
    "    inp = input('').strip()\n",
    "    if inp == 'q':\n",
    "        break\n",
    "\n",
    "    next_word = predict_next_word(inp)\n",
    "    if next_word is None:\n",
    "        # Nothing to predict from an empty prompt.\n",
    "        continue\n",
    "\n",
    "    print(f'Answer: {inp} {next_word} ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1e828f0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}