Skip to content
Snippets Groups Projects
Next-Token-Prediction.ipynb 7.27 KiB
Newer Older
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
{
 "cells": [
  {
   "cell_type": "markdown",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "0f901b18",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {},
   "source": [
    "# Next-Token-Prediction\n",
    "This notebook trains a small GRU-based model to predict the next word of a short question. It is based on the following blog posts:\n",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
    "* Predicting Next Word — NLP & Deep Learning: https://medium.com/@vijay2340025/predicting-next-word-nlp-deep-learning-85010d966671\n",
    "* How ChatGPT Works: The Model Behind The Bot: https://towardsdatascience.com/how-chatgpt-works-the-models-behind-the-bot-1ce5fca96286"
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "b2471e25",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "import pandas as pd\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "c0febc54",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {},
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "outputs": [],
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "source": [
    "# word_tokenize needs the Punkt sentence models. Newer NLTK releases (3.8.2+) load\n",
    "# them from 'punkt_tab' instead of the older 'punkt' package, so fetch both.\n",
    "for resource in ('punkt', 'punkt_tab'):\n",
    "    nltk.download(resource)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "8a781aa1",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Toy corpus: one question per line, each followed by a comma so the text\n",
    "# can later be split into individual questions.\n",
    "dataset = \"\\n\" + \"\".join(\n",
    "    f\"    {question},\\n\"\n",
    "    for question in (\n",
    "        \"Is Antwerp a city?\",\n",
    "        \"Is Antwerp a municipality?\",\n",
    "        \"Is Antwerp in Belgium?\",\n",
    "        \"What is Antwerp?\",\n",
    "        \"What is the population of the city of Antwerp?\",\n",
    "        \"Where is the city of Antwerp?\",\n",
    "        \"Why is Antwerp important to fashion?\",\n",
    "        \"Antwerp is to the east of what river?\",\n",
    "        \"How many municipalities does Antwerp have?\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "39255d82",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_all_possible_sequences(text):\n",
    "    \"\"\"Split `text` into every (context, next_word) training pair.\n",
    "\n",
    "    For each window width w in 1..len(tokens)-1 and each start position, the context\n",
    "    is the w consecutive tokens at that position and next_word is the token after them.\n",
    "    Pairs are ordered by width first, then by start position.\n",
    "    \"\"\"\n",
    "    tokens = nltk.word_tokenize(text)\n",
    "    return [\n",
    "        (tokens[start:start + width], tokens[start + width])\n",
    "        for width in range(1, len(tokens))\n",
    "        for start in range(len(tokens) - width)\n",
    "    ]\n",
    "\n",
    "def build_vocabulary(docs):\n",
    "    \"\"\"Return the unique tokens of `docs` in first-seen order, followed by 'UNK'.\n",
    "\n",
    "    'UNK' is the reserved token used as the padding value; it is only added when the\n",
    "    corpus does not already contain it, so every vocabulary entry is unique.\n",
    "    \"\"\"\n",
    "    # A dict keeps insertion order and gives O(1) membership checks, unlike\n",
    "    # scanning a list for every token.\n",
    "    vocabulary = dict.fromkeys(\n",
    "        token for doc in docs for token in nltk.word_tokenize(doc)\n",
    "    )\n",
    "    vocabulary.setdefault('UNK')\n",
    "    return list(vocabulary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "edb54d0d",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One lower-cased document per comma-separated question.\n",
    "docs = [row.lower() for row in dataset.split(\",\")]\n",
    "\n",
    "# All (context, next_word) training pairs; extend() appends in place instead of\n",
    "# rebuilding the whole list for every document.\n",
    "lst = []\n",
    "for doc in docs:\n",
    "    lst.extend(get_all_possible_sequences(doc))\n",
    "\n",
    "vocabulary = build_vocabulary(docs)\n",
    "id2word = dict(enumerate(vocabulary))\n",
    "word2id = {w: idx for idx, w in id2word.items()}\n",
    "\n",
    "def seq2id(arr):\n",
    "    \"\"\"Map a sequence of tokens to a 1-D LongTensor of vocabulary ids.\n",
    "\n",
    "    Tokens that are not in the vocabulary map to the 'UNK' id instead of raising\n",
    "    KeyError. The dtype is pinned to long so an empty sequence still yields an\n",
    "    integer tensor usable as an nn.Embedding index.\n",
    "    \"\"\"\n",
    "    unk_id = word2id['UNK']\n",
    "    return torch.tensor([word2id.get(token, unk_id) for token in arr], dtype=torch.long)\n",
    "\n",
    "def get_max_seq():\n",
    "    \"\"\"Return the length of the longest context in the training pairs `lst` (0 if empty).\n",
    "\n",
    "    This is the fixed window width every model input is padded to.\n",
    "    \"\"\"\n",
    "    return max((len(context) for context, _ in lst), default=0)\n",
    "\n",
    "MAX_SEQ_LEN = get_max_seq()\n",
    "\n",
    "def get_padded_x(data):\n",
    "    \"\"\"Return `data` as a (1, MAX_SEQ_LEN) tensor of ids, right-padded with the 'UNK' id.\n",
    "\n",
    "    Inputs longer than MAX_SEQ_LEN keep only their last MAX_SEQ_LEN ids (the ones\n",
    "    closest to the word being predicted); a negative pad would instead crop from the\n",
    "    end. Empty inputs become an all-'UNK' row.\n",
    "    \"\"\"\n",
    "    ids = data.reshape(-1).long()\n",
    "    ids = ids[max(ids.shape[0] - MAX_SEQ_LEN, 0):]\n",
    "    new_data = F.pad(input=ids.unsqueeze(0), pad=(0, MAX_SEQ_LEN - ids.shape[0], 0, 0), mode='constant', value=word2id['UNK'])\n",
    "    return new_data\n",
    "\n",
    "def get_xy_vector(arr):\n",
    "    \"\"\"Turn one (context_words, next_word) pair into (context_ids, target_ids) tensors.\n",
    "\n",
    "    The target word is wrapped in a list so it becomes a 1-element id tensor.\n",
    "    \"\"\"\n",
    "    context, target = arr[0], arr[1]\n",
    "    return seq2id(context), seq2id([target])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "431fd558",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class NextWordModel(nn.Module):\n",
    "    \"\"\"Predict the next word from a fixed-length window of `max_seq_len` token ids.\n",
    "\n",
    "    The embeddings of the whole window are flattened into a single GRU time step,\n",
    "    so the GRU processes one concatenated vector rather than stepping token by token.\n",
    "\n",
    "    Args:\n",
    "        embedding_dim: size of each token embedding.\n",
    "        hidden_dim: size of the GRU hidden state.\n",
    "        vocab_size: number of vocabulary entries (embedding rows and output logits).\n",
    "        max_seq_len: window length the input is padded to; defaults to the\n",
    "            notebook-level MAX_SEQ_LEN when omitted.\n",
    "    \"\"\"\n",
    "    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len=None):\n",
    "        super().__init__()\n",
    "        if max_seq_len is None:\n",
    "            max_seq_len = MAX_SEQ_LEN\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.max_seq_len = max_seq_len\n",
    "        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
    "        # Input size is the whole flattened window, matching the view() in forward().\n",
    "        self.gru = nn.GRU(embedding_dim * max_seq_len, hidden_dim)\n",
    "        self.linear = nn.Linear(hidden_dim, vocab_size)\n",
    "\n",
    "    def forward(self, sentence):\n",
    "        \"\"\"Return next-word logits of shape (1, vocab_size) for a single window.\n",
    "\n",
    "        `sentence` must hold exactly `max_seq_len` token ids (e.g. the output of\n",
    "        get_padded_x); only one example per call is supported.\n",
    "        \"\"\"\n",
    "        embeds = self.word_embeddings(sentence)\n",
    "        # (seq_len=1, batch=1, max_seq_len * embedding_dim)\n",
    "        gru_out, _ = self.gru(embeds.view(1, 1, -1))\n",
    "        return self.linear(gru_out.view(1, -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "d891422e",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {
    "tags": []
   },
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "outputs": [],
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "source": [
    "# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.\n",
    "dev = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f'Running on {dev}')\n",
    "device = torch.device(dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "7c231c92",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {
    "tags": []
   },
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "outputs": [],
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "source": [
    "EMBEDDING_DIM = 10\n",
    "NO_OF_EPOCHS = 300\n",
    "HIDDEN_DIM = len(vocabulary)\n",
    "LEARNING_RATE = 0.1\n",
    "LOG_EVERY = 50  # report the mean loss every LOG_EVERY epochs\n",
    "SEED = 42\n",
    "\n",
    "# Seed before building the model so the weight initialisation is reproducible.\n",
    "torch.manual_seed(SEED)\n",
    "model = NextWordModel(EMBEDDING_DIM, HIDDEN_DIM, len(vocabulary)).to(device)\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "for epoch in range(NO_OF_EPOCHS):\n",
    "    epoch_loss = 0.0\n",
    "    for data in lst:\n",
    "        model.zero_grad()\n",
    "        x, y = get_xy_vector(data)\n",
    "        # Pad the context to MAX_SEQ_LEN ids so the model always sees a fixed-size window.\n",
    "        x = get_padded_x(x).to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        predicted = model(x)\n",
    "        loss = loss_function(predicted, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # .item() yields a plain float, so the running total does not keep\n",
    "        # this step's autograd graph alive.\n",
    "        epoch_loss += loss.item()\n",
    "\n",
    "    if (epoch + 1) % LOG_EVERY == 0:\n",
    "        mean_loss = epoch_loss / max(len(lst), 1)\n",
    "        print(f'Epoch {epoch + 1}/{NO_OF_EPOCHS}: mean loss {mean_loss:.4f}')\n",
    "print('Finished')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "9d0ff01d",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {},
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "outputs": [],
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    print('Type something here . . .')\n",
    "    while True:\n",
    "        inp = input(\"\").strip()\n",
    "        if inp == \"q\":\n",
    "            break\n",
    "\n",
    "        # Out-of-vocabulary words are mapped to 'UNK', and only the last MAX_SEQ_LEN\n",
    "        # tokens are kept because that is the window the model was trained on.\n",
    "        tokens = [t if t in word2id else 'UNK' for t in nltk.word_tokenize(inp.lower())]\n",
    "        tokens = tokens[-MAX_SEQ_LEN:]\n",
    "        if not tokens:\n",
    "            continue  # nothing to condition the prediction on\n",
    "\n",
    "        x = get_padded_x(seq2id(tokens)).to(device)\n",
    "        predicted = model(x)\n",
    "\n",
    "        next_id = int(predicted[0].argmax())\n",
    "        print(f'Answer: {inp} {id2word[next_id]} ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "id": "c1e828f0",
Kreinsen, Moritz's avatar
Kreinsen, Moritz committed
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}