diff --git a/playground/GPT DEV_1_become_one_with_data copy.ipynb b/playground/GPT DEV_1_become_one_with_data copy.ipynb new file mode 100644 index 0000000..77d86e4 --- /dev/null +++ b/playground/GPT DEV_1_become_one_with_data copy.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "- [A Recipe for Training Neural Networks\n", + "](https://karpathy.github.io/2019/04/25/recipe/)\n", + "- [Harvard CS197 AI Research Experiences](https://docs.google.com/document/d/1uvAbEhbgS_M-uDMTzmOWRlYxqCkogKRXdbKYYT98ooc/edit#heading=h.2z3yllpny6or)\n", + "- [Unit tests for machine learning research](https://semla.polymtl.ca/wp-content/uploads/2022/11/Pablo-Unit-tests-for-ML-code-SEMLA-talk.pdf)\n", + "- [CS 329S: Machine Learning Systems Design](https://stanford-cs329s.github.io/syllabus.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Become one with the data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "length of dataset in characters: 1115394\n", + "First Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You\n" + ] + } + ], + "source": [ + "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", + "\n", + "with open('input.txt', 'r', encoding='utf-8') as f:\n", + " text = f.read()\n", + "\n", + "print(\"length of dataset in characters: \", len(text))\n", + "print(text[:100])\n", + "train_data = text[:int(len(text)*0.9)]\n", + "val_data = text[int(len(text)*0.9):]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['First', ' ', 'Citizen:', '\\n', 'Before', ' ', 'we', ' ', 'proceed', ' ', 'any', ' ', 'further,', ' ', 
'hear', ' ', 'me', ' ', 'speak.']\n", + "[(' ', 169892), ('\\n', 40000), ('', 7242), ('the', 5437), ('I', 4403)]\n", + "[('open;', 1), ('standing,', 1), ('moving,', 1), ('sleep--die,', 1), (\"wink'st\", 1)]\n", + "splitted 419785 unique_word 25673\n" + ] + } + ], + "source": [ + "import re\n", + "\n", + "def split_string(input_string):\n", + " # 正規表現で改行(\\n)やスペース( )で区切り、それらも結果に含める\n", + " split_list = re.split(r'(\\s)', input_string)\n", + " return split_list\n", + "\n", + "first_period_index = text.index('.')\n", + "print(split_string(text[:first_period_index+1]))\n", + "unique_words = list(set(split_string(text)))\n", + "\n", + "word_count_dict = {}\n", + "for word in split_string(text):\n", + " if word in word_count_dict:\n", + " word_count_dict[word] += 1\n", + " else:\n", + " word_count_dict[word] = 1\n", + "# 多い順に並べ替え\n", + "word_count_dict = dict(sorted(word_count_dict.items(), key=lambda x: -x[1]))\n", + "# 上位・下位5件を表示\n", + "print(list(word_count_dict.items())[:5])\n", + "print(list(word_count_dict.items())[-5:])\n", + "print('splitted', len(split_string(text)), 'unique_word', len(unique_words))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13]\n", + "['First', ' Citizen', ':', '\\n', 'Before', ' we', ' proceed', ' any', ' further', ',', ' hear', ' me', ' speak', '.']\n", + "[(198, {'count': 39996, 'token_id': '\\n'}), (11, {'count': 19777, 'token_id': ','}), (25, {'count': 10291, 'token_id': ':'}), (13, {'count': 7811, 'token_id': '.'}), (262, {'count': 5370, 'token_id': ' the'})]\n", + "[(16558, {'count': 1, 'token_id': ' sphere'}), (31960, {'count': 1, 'token_id': ' Wond'}), (22194, {'count': 1, 'token_id': ' possesses'}), (29708, {'count': 1, 'token_id': ' eyel'}), (30757, {'count': 1, 'token_id': 'stroke'})]\n", + "splitted 338025 unique_token 11706 vocab_size 50257\n" 
+ ] + } + ], + "source": [ + "import tiktoken\n", + "enc = tiktoken.get_encoding(\"gpt2\")\n", + "encoded_ids = enc.encode(text[:first_period_index+1])\n", + "decoded_text = [enc.decode([encoded_id]) for encoded_id in encoded_ids]\n", + "print(encoded_ids)\n", + "print(decoded_text)\n", + "\n", + "\n", + "unique_tokens = list(set(enc.encode(text)))\n", + "\n", + "token_count_dict = {}\n", + "for token in enc.encode(text):\n", + " if token in token_count_dict:\n", + " token_count_dict[token]['count'] += 1\n", + " else:\n", + " token_count_dict[token] = {'count': 1, 'token_id': enc.decode([token])}\n", + "# 多い順に並べ替え\n", + "token_count_dict = dict(sorted(token_count_dict.items(), key=lambda x: -x[1]['count']))\n", + "# 上位・下位5件を表示\n", + "print(list(token_count_dict.items())[:5])\n", + "print(list(token_count_dict.items())[-5:])\n", + "print('splitted', len(enc.encode(text)), 'unique_token', len(unique_tokens), 'vocab_size', enc.n_vocab)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params of unigram 50257\n", + "2525766049 126937424324593\n" + ] + } + ], + "source": [ + "from ngram import Ngram\n", + "vocab = list(range(enc.n_vocab))\n", + "unigram = Ngram(1, vocab)\n", + "tokens = enc.encode(text)\n", + "unigram.train(tokens)\n", + "print('params of unigram', len(unigram.ngram)) \n", + "\n", + "\n", + "print(enc.n_vocab ** 2, enc.n_vocab ** 3)\n", + "# bigram = Ngram(2, vocab)\n", + "# bigram.train(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[('\\n', 39997), (',', 19778), (':', 10292), ('.', 7812), (' the', 5371)]\n", + "[('ominated', 1), (' regress', 1), (' Collider', 1), (' informants', 1), ('<|endoftext|>', 1)]\n" + ] + } + ], + "source": [ + "# 上位・下位5件を表示\n", + "unigram_info = unigram.ngram\n", + "unigram_info = 
dict(sorted(unigram_info.items(), key=lambda x: -x[1]))\n", + "top_unigram = list(unigram_info.items())[:5]\n", + "bottom_unigram = list(unigram_info.items())[-5:]\n", + "print([(enc.decode([token[0]]), count) for token, count in top_unigram])\n", + "print([(enc.decode([token[0]]), count) for token, count in bottom_unigram])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "50257" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "enc.n_vocab" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input\n", + "torch.Size([4, 8])\n", + "tensor([[ 198, 30313, 262, 22397, 282, 290, 884, 3790],\n", + " [ 4151, 438, 198, 10418, 329, 511, 11989, 11],\n", + " [ 3355, 322, 12105, 287, 3426, 6729, 198, 3886],\n", + " [ 290, 15581, 8636, 13, 198, 198, 35510, 4221]])\n", + "[['\\n', 'Except', ' the', ' marsh', 'al', ' and', ' such', ' officers'], [' eye', '--', '\\n', 'Men', ' for', ' their', ' sons', ','], [' wall', 'ow', ' naked', ' in', ' December', ' snow', '\\n', 'By'], [' and', ' noble', ' estimate', '.', '\\n', '\\n', 'NOR', 'TH']]\n", + "target\n", + "torch.Size([4, 8])\n", + "tensor([[30313, 262, 22397, 282, 290, 884, 3790, 198],\n", + " [ 438, 198, 10418, 329, 511, 11989, 11, 17743],\n", + " [ 322, 12105, 287, 3426, 6729, 198, 3886, 3612],\n", + " [15581, 8636, 13, 198, 198, 35510, 4221, 5883]])\n", + "[['Except', ' the', ' marsh', 'al', ' and', ' such', ' officers', '\\n'], ['--', '\\n', 'Men', ' for', ' their', ' sons', ',', ' wives'], ['ow', ' naked', ' in', ' December', ' snow', '\\n', 'By', ' thinking'], [' noble', ' estimate', '.', '\\n', '\\n', 'NOR', 'TH', 'UM']]\n", + "input: ['\\n'] target: 'Except'\n", + "input: ['\\n', 'Except'] target: ' the'\n", + "input: ['\\n', 'Except', ' the'] target: ' marsh'\n", + "input: 
['\\n', 'Except', ' the', ' marsh'] target: 'al'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al'] target: ' and'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al', ' and'] target: ' such'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al', ' and', ' such'] target: ' officers'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al', ' and', ' such', ' officers'] target: '\\n'\n" + ] + } + ], + "source": [ + "import torch\n", + "seed = 1337\n", + "torch.manual_seed(seed) \n", + "batch_size = 4\n", + "context_length = 8\n", + "data = torch.tensor(enc.encode(text), dtype=torch.long)\n", + "n = int(0.9*len(data)) # first 90% will be train, rest val\n", + "train_data = data[:n]\n", + "val_data = data[n:]\n", + "\n", + "\n", + "def get_batch(split):\n", + " data = train_data if split == 'train' else val_data\n", + " index = torch.randint(len(data) - context_length, (batch_size,))\n", + " x = torch.stack([data[i:i+context_length] for i in index])\n", + " y = torch.stack([data[i+1:i+1+context_length] for i in index])\n", + " return x, y\n", + "\n", + "\n", + "x, y = get_batch('train')\n", + "print('input')\n", + "print(x.shape)\n", + "print(x)\n", + "print([[enc.decode([token])for token in sequence] for sequence in x])\n", + "print('target')\n", + "print(y.shape)\n", + "print(y)\n", + "print([[enc.decode([token])for token in sequence] for sequence in y])\n", + "\n", + "for t in range(context_length):\n", + " context = x[0, :t+1]\n", + " target = y[0, t]\n", + " print('input: ', [enc.decode([token]) for token in context], 'target: ', repr(enc.decode([target])))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + 
"nbformat_minor": 2 +} diff --git a/playground/GPT DEV_1_become_one_with_data.ipynb b/playground/GPT DEV_1_become_one_with_data.ipynb new file mode 100644 index 0000000..77d86e4 --- /dev/null +++ b/playground/GPT DEV_1_become_one_with_data.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "- [A Recipe for Training Neural Networks\n", + "](https://karpathy.github.io/2019/04/25/recipe/)\n", + "- [Harvard CS197 AI Research Experiences](https://docs.google.com/document/d/1uvAbEhbgS_M-uDMTzmOWRlYxqCkogKRXdbKYYT98ooc/edit#heading=h.2z3yllpny6or)\n", + "- [Unit tests for machine learning research](https://semla.polymtl.ca/wp-content/uploads/2022/11/Pablo-Unit-tests-for-ML-code-SEMLA-talk.pdf)\n", + "- [CS 329S: Machine Learning Systems Design](https://stanford-cs329s.github.io/syllabus.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Become one with the data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "length of dataset in characters: 1115394\n", + "First Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You\n" + ] + } + ], + "source": [ + "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", + "\n", + "with open('input.txt', 'r', encoding='utf-8') as f:\n", + " text = f.read()\n", + "\n", + "print(\"length of dataset in characters: \", len(text))\n", + "print(text[:100])\n", + "train_data = text[:int(len(text)*0.9)]\n", + "val_data = text[int(len(text)*0.9):]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['First', ' ', 'Citizen:', '\\n', 'Before', ' ', 'we', ' ', 'proceed', ' ', 'any', ' ', 
'further,', ' ', 'hear', ' ', 'me', ' ', 'speak.']\n", + "[(' ', 169892), ('\\n', 40000), ('', 7242), ('the', 5437), ('I', 4403)]\n", + "[('open;', 1), ('standing,', 1), ('moving,', 1), ('sleep--die,', 1), (\"wink'st\", 1)]\n", + "splitted 419785 unique_word 25673\n" + ] + } + ], + "source": [ + "import re\n", + "\n", + "def split_string(input_string):\n", + " # 正規表現で改行(\\n)やスペース( )で区切り、それらも結果に含める\n", + " split_list = re.split(r'(\\s)', input_string)\n", + " return split_list\n", + "\n", + "first_period_index = text.index('.')\n", + "print(split_string(text[:first_period_index+1]))\n", + "unique_words = list(set(split_string(text)))\n", + "\n", + "word_count_dict = {}\n", + "for word in split_string(text):\n", + " if word in word_count_dict:\n", + " word_count_dict[word] += 1\n", + " else:\n", + " word_count_dict[word] = 1\n", + "# 多い順に並べ替え\n", + "word_count_dict = dict(sorted(word_count_dict.items(), key=lambda x: -x[1]))\n", + "# 上位・下位5件を表示\n", + "print(list(word_count_dict.items())[:5])\n", + "print(list(word_count_dict.items())[-5:])\n", + "print('splitted', len(split_string(text)), 'unique_word', len(unique_words))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13]\n", + "['First', ' Citizen', ':', '\\n', 'Before', ' we', ' proceed', ' any', ' further', ',', ' hear', ' me', ' speak', '.']\n", + "[(198, {'count': 39996, 'token_id': '\\n'}), (11, {'count': 19777, 'token_id': ','}), (25, {'count': 10291, 'token_id': ':'}), (13, {'count': 7811, 'token_id': '.'}), (262, {'count': 5370, 'token_id': ' the'})]\n", + "[(16558, {'count': 1, 'token_id': ' sphere'}), (31960, {'count': 1, 'token_id': ' Wond'}), (22194, {'count': 1, 'token_id': ' possesses'}), (29708, {'count': 1, 'token_id': ' eyel'}), (30757, {'count': 1, 'token_id': 'stroke'})]\n", + "splitted 338025 unique_token 11706 
vocab_size 50257\n" + ] + } + ], + "source": [ + "import tiktoken\n", + "enc = tiktoken.get_encoding(\"gpt2\")\n", + "encoded_ids = enc.encode(text[:first_period_index+1])\n", + "decoded_text = [enc.decode([encoded_id]) for encoded_id in encoded_ids]\n", + "print(encoded_ids)\n", + "print(decoded_text)\n", + "\n", + "\n", + "unique_tokens = list(set(enc.encode(text)))\n", + "\n", + "token_count_dict = {}\n", + "for token in enc.encode(text):\n", + " if token in token_count_dict:\n", + " token_count_dict[token]['count'] += 1\n", + " else:\n", + " token_count_dict[token] = {'count': 1, 'token_id': enc.decode([token])}\n", + "# 多い順に並べ替え\n", + "token_count_dict = dict(sorted(token_count_dict.items(), key=lambda x: -x[1]['count']))\n", + "# 上位・下位5件を表示\n", + "print(list(token_count_dict.items())[:5])\n", + "print(list(token_count_dict.items())[-5:])\n", + "print('splitted', len(enc.encode(text)), 'unique_token', len(unique_tokens), 'vocab_size', enc.n_vocab)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "params of unigram 50257\n", + "2525766049 126937424324593\n" + ] + } + ], + "source": [ + "from ngram import Ngram\n", + "vocab = list(range(enc.n_vocab))\n", + "unigram = Ngram(1, vocab)\n", + "tokens = enc.encode(text)\n", + "unigram.train(tokens)\n", + "print('params of unigram', len(unigram.ngram)) \n", + "\n", + "\n", + "print(enc.n_vocab ** 2, enc.n_vocab ** 3)\n", + "# bigram = Ngram(2, vocab)\n", + "# bigram.train(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[('\\n', 39997), (',', 19778), (':', 10292), ('.', 7812), (' the', 5371)]\n", + "[('ominated', 1), (' regress', 1), (' Collider', 1), (' informants', 1), ('<|endoftext|>', 1)]\n" + ] + } + ], + "source": [ + "# 上位・下位5件を表示\n", + "unigram_info = unigram.ngram\n", + "unigram_info 
= dict(sorted(unigram_info.items(), key=lambda x: -x[1]))\n", + "top_unigram = list(unigram_info.items())[:5]\n", + "bottom_unigram = list(unigram_info.items())[-5:]\n", + "print([(enc.decode([token[0]]), count) for token, count in top_unigram])\n", + "print([(enc.decode([token[0]]), count) for token, count in bottom_unigram])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "50257" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "enc.n_vocab" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input\n", + "torch.Size([4, 8])\n", + "tensor([[ 198, 30313, 262, 22397, 282, 290, 884, 3790],\n", + " [ 4151, 438, 198, 10418, 329, 511, 11989, 11],\n", + " [ 3355, 322, 12105, 287, 3426, 6729, 198, 3886],\n", + " [ 290, 15581, 8636, 13, 198, 198, 35510, 4221]])\n", + "[['\\n', 'Except', ' the', ' marsh', 'al', ' and', ' such', ' officers'], [' eye', '--', '\\n', 'Men', ' for', ' their', ' sons', ','], [' wall', 'ow', ' naked', ' in', ' December', ' snow', '\\n', 'By'], [' and', ' noble', ' estimate', '.', '\\n', '\\n', 'NOR', 'TH']]\n", + "target\n", + "torch.Size([4, 8])\n", + "tensor([[30313, 262, 22397, 282, 290, 884, 3790, 198],\n", + " [ 438, 198, 10418, 329, 511, 11989, 11, 17743],\n", + " [ 322, 12105, 287, 3426, 6729, 198, 3886, 3612],\n", + " [15581, 8636, 13, 198, 198, 35510, 4221, 5883]])\n", + "[['Except', ' the', ' marsh', 'al', ' and', ' such', ' officers', '\\n'], ['--', '\\n', 'Men', ' for', ' their', ' sons', ',', ' wives'], ['ow', ' naked', ' in', ' December', ' snow', '\\n', 'By', ' thinking'], [' noble', ' estimate', '.', '\\n', '\\n', 'NOR', 'TH', 'UM']]\n", + "input: ['\\n'] target: 'Except'\n", + "input: ['\\n', 'Except'] target: ' the'\n", + "input: ['\\n', 'Except', ' the'] target: ' marsh'\n", + 
"input: ['\\n', 'Except', ' the', ' marsh'] target: 'al'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al'] target: ' and'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al', ' and'] target: ' such'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al', ' and', ' such'] target: ' officers'\n", + "input: ['\\n', 'Except', ' the', ' marsh', 'al', ' and', ' such', ' officers'] target: '\\n'\n" + ] + } + ], + "source": [ + "import torch\n", + "seed = 1337\n", + "torch.manual_seed(seed) \n", + "batch_size = 4\n", + "context_length = 8\n", + "data = torch.tensor(enc.encode(text), dtype=torch.long)\n", + "n = int(0.9*len(data)) # first 90% will be train, rest val\n", + "train_data = data[:n]\n", + "val_data = data[n:]\n", + "\n", + "\n", + "def get_batch(split):\n", + " data = train_data if split == 'train' else val_data\n", + " index = torch.randint(len(data) - context_length, (batch_size,))\n", + " x = torch.stack([data[i:i+context_length] for i in index])\n", + " y = torch.stack([data[i+1:i+1+context_length] for i in index])\n", + " return x, y\n", + "\n", + "\n", + "x, y = get_batch('train')\n", + "print('input')\n", + "print(x.shape)\n", + "print(x)\n", + "print([[enc.decode([token])for token in sequence] for sequence in x])\n", + "print('target')\n", + "print(y.shape)\n", + "print(y)\n", + "print([[enc.decode([token])for token in sequence] for sequence in y])\n", + "\n", + "for t in range(context_length):\n", + " context = x[0, :t+1]\n", + " target = y[0, t]\n", + " print('input: ', [enc.decode([token]) for token in context], 'target: ', repr(enc.decode([target])))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 
4, + "nbformat_minor": 2 +} diff --git a/playground/GPT DEV_2_e2epipeline_baseline.ipynb b/playground/GPT DEV_2_e2epipeline_baseline.ipynb new file mode 100644 index 0000000..e42834d --- /dev/null +++ b/playground/GPT DEV_2_e2epipeline_baseline.ipynb @@ -0,0 +1,417 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "- [A Recipe for Training Neural Networks\n", + "](https://karpathy.github.io/2019/04/25/recipe/)\n", + "- [Harvard CS197 AI Research Experiences](https://docs.google.com/document/d/1uvAbEhbgS_M-uDMTzmOWRlYxqCkogKRXdbKYYT98ooc/edit#heading=h.2z3yllpny6or)\n", + "- [Unit tests for machine learning research](https://semla.polymtl.ca/wp-content/uploads/2022/11/Pablo-Unit-tests-for-ML-code-SEMLA-talk.pdf)\n", + "- [CS 329S: Machine Learning Systems Design](https://stanford-cs329s.github.io/syllabus.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up the end-to-end training/evaluation skeleton + get dumb baselines" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.nn import functional as F\n", + "\n", + "torch.manual_seed(1337)\n", + "\n", + "class BigramLanguageModel(nn.Module):\n", + " def __init__(self, vocab_size):\n", + " super().__init__()\n", + " # self.bigram_table = nn.Embedding(vocab_size, vocab_size)\n", + " self.token_embedding_table = nn.Embedding(vocab_size, 16)\n", + " self.linear = nn.Linear(16, vocab_size)\n", + " print('number of parameters:', sum(p.numel() for p in self.parameters()))\n", + " \n", + " def forward(self, token_indexes):\n", + " # token_index: (batch_size, sequence_length)\n", + " # logits = self.bigram_table(token_indexes)\n", + "\n", + " embedding = self.token_embedding_table(token_indexes)\n", + " logits = self.linear(embedding)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " return 
logits\n", + "\n", + " def loss_per_token(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length),\n", + " reduction='none'\n", + " )\n", + " # loss: (batch_size*sequence_length)\n", + " return loss.view(batch_size, sequence_length)\n", + " \n", + " def loss(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length)\n", + " )\n", + " # loss: scalar\n", + " return loss\n", + " \n", + " def generate(self, token_indexes, max_new_tokens):\n", + " # token_indexes: (batch_size, sequence_length)\n", + " batch_size, sequence_length = token_indexes.shape\n", + " for _ in range(max_new_tokens):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " next_token_logits = logits[:, -1, :]\n", + " # next_token_logits: (batch_size, vocab_size)\n", + " next_token_probs = F.softmax(next_token_logits, dim=-1)\n", + " # next_token_probs: (batch_size, vocab_size)\n", + " next_token = torch.multinomial(next_token_probs, num_samples=1)\n", + " # next_token: (batch_size, 1)\n", + " token_indexes = torch.cat([token_indexes, next_token], dim=1)\n", + " # token_indexes: (batch_size, sequence_length+1)\n", + " return token_indexes\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + 
"metadata": {}, + "outputs": [], + "source": [ + "def rand_int_test(cls, low, high, shape, kwargs):\n", + " layer = cls(**kwargs).cuda()\n", + " random_input = torch.randint(low, high, shape).cuda()\n", + " print('input shape:', random_input.shape)\n", + " output = layer(random_input)\n", + " print('output shape:', output.shape)\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 8448\n", + "input shape: torch.Size([4, 1024])\n", + "output shape: torch.Size([4, 1024, 256])\n" + ] + } + ], + "source": [ + "test_cls = BigramLanguageModel\n", + "batch_size = 4\n", + "context_length = 1024\n", + "vocab_size = 256\n", + "\n", + "kwargs = {'vocab_size': vocab_size}\n", + "output = rand_int_test(test_cls, 0, vocab_size, (batch_size, context_length), kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 1658481\n", + "random guess loss: 10.82490511970208\n", + "tensor(10.9950, device='cuda:0', grad_fn=)\n", + "torch.Size([4, 1024]) tensor(10.9950, device='cuda:0', grad_fn=)\n", + "tensor([[11.4311, 11.6552, 10.2010, ..., 10.5417, 10.6344, 11.4137],\n", + " [11.0913, 10.3161, 10.8965, ..., 11.6884, 11.4491, 10.4440],\n", + " [12.3048, 10.9655, 10.6260, ..., 10.9756, 11.2433, 10.6060],\n", + " [10.5069, 10.6218, 11.0385, ..., 11.9397, 10.6035, 10.4034]],\n", + " device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "from data import get_batch, enc\n", + "import tiktoken\n", + "import math\n", + "\n", + "x, y = get_batch(batch_size, context_length, 'train')\n", + "vocab_size = tiktoken.get_encoding(\"gpt2\").n_vocab\n", + "model = BigramLanguageModel(vocab_size).cuda()\n", + "loss = model.loss(x.cuda(), y.cuda())\n", + "print('random guess loss:', -math.log(1/vocab_size))\n", + 
"print(loss)\n", + "loss_per_token = model.loss_per_token(x.cuda(), y.cuda())\n", + "print(loss_per_token.shape, loss_per_token.mean())\n", + "print(loss_per_token)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input [' must', '\\n', 'In', ' that']\n", + "output [' must', '\\n', 'In', ' that', ' Calculator', ' HPV', 'Empty', ' LW', ' Seconds', ' Infinite', ' payoff', 'ste']\n", + "Gold label [' must', '\\n', 'In', ' that', ' be', ' made', ' more', ' bitter', '.', ' Fear', ' o', \"'\", 'ers', 'h', 'ades', ' me', ':', '\\n', 'Good', ' expedition', ' be', ' my', ' friend', ',', ' and', ' comfort', '\\n', 'The', ' gracious', ' queen', ',', ' part', ' of', ' his', ' theme', ',', ' but', ' nothing', '\\n', 'Of', ' his', ' ill', '-', 'ta', \"'\", 'en', ' suspicion', '!', ' Come', ',', ' Cam', 'illo', ';', '\\n', 'I', ' will', ' respect', ' thee', ' as', ' a', ' father', ' if', '\\n', 'Th', 'ou', ' bear', \"'s\", 't', ' my', ' life', ' off', ' hence', ':', ' let', ' us', ' avoid', '.', '\\n', '\\n', 'C', 'AM', 'ILL', 'O', ':', '\\n', 'It', ' is', ' in', ' mine', ' authority', ' to', ' command', '\\n', 'The', ' keys', ' of', ' all', ' the', ' post', 'ern', 's', ':', ' please', ' your', ' high', 'ness', '\\n', 'To', ' take', ' the', ' urgent', ' hour', '.', ' Come', ',', ' sir', ',', ' away', '.', '\\n', '\\n', 'HER', 'M', 'ION', 'E', ':', '\\n', 'Take', ' the', ' boy', ' to', ' you', ':', ' he', ' so', ' troubles', ' me', ',', '\\n', \"'\", 'T', 'is', ' past', ' enduring', '.', '\\n', '\\n', 'First', ' Lady', ':', '\\n', 'Come', ',', ' my', ' gracious', ' lord', ',', '\\n', 'Sh', 'all', ' I', ' be', ' your', ' play', 'f', 'ellow', '?', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'No', ',', ' I', \"'ll\", ' none', ' of', ' you', '.', '\\n', '\\n', 'First', ' Lady', ':', '\\n', 'Why', ',', ' my', ' sweet', ' lord', '?', '\\n', '\\n', 'M', 'AM', 'ILL', 
'I', 'US', ':', '\\n', 'You', \"'ll\", ' kiss', ' me', ' hard', ' and', ' speak', ' to', ' me', ' as', ' if', '\\n', 'I', ' were', ' a', ' baby', ' still', '.', ' I', ' love', ' you', ' better', '.', '\\n', '\\n', 'Second', ' Lady', ':', '\\n', 'And', ' why', ' so', ',', ' my', ' lord', '?', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'Not', ' for', ' because', '\\n', 'Your', ' brow', 's', ' are', ' black', 'er', ';', ' yet', ' black', ' brow', 's', ',', ' they', ' say', ',', '\\n', 'Bec', 'ome', ' some', ' women', ' best', ',', ' so', ' that', ' there', ' be', ' not', '\\n', 'Too', ' much', ' hair', ' there', ',', ' but', ' in', ' a', ' semic', 'irc', 'le', '\\n', 'Or', ' a', ' half', '-', 'moon', ' made', ' with', ' a', ' pen', '.', '\\n', '\\n', 'Second', ' Lady', ':', '\\n', 'Who', ' taught', ' you', ' this', '?', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'I', ' learnt', ' it', ' out', ' of', ' women', \"'s\", ' faces', '.', ' Pr', 'ay', ' now', '\\n', 'What', ' colour', ' are', ' your', ' eyebrows', '?', '\\n', '\\n', 'First', ' Lady', ':', '\\n', 'Blue', ',', ' my', ' lord', '.', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'N', 'ay', ',', ' that', \"'s\", ' a', ' mock', ':', ' I', ' have', ' seen', ' a', ' lady', \"'s\", ' nose', '\\n', 'That', ' has', ' been', ' blue', ',', ' but', ' not', ' her', ' eyebrows', '.', '\\n', '\\n', 'First', ' Lady', ':', '\\n', 'H', 'ark', ' ye', ';', '\\n', 'The', ' queen', ' your', ' mother', ' rounds', ' ap', 'ace', ':', ' we', ' shall', '\\n', 'Present', ' our', ' services', ' to', ' a', ' fine', ' new', ' prince', '\\n', 'One', ' of', ' these', ' days', ';', ' and', ' then', ' you', \"'\", 'ld', ' want', 'on', ' with', ' us', ',', '\\n', 'If', ' we', ' would', ' have', ' you', '.', '\\n', '\\n', 'Second', ' Lady', ':', '\\n', 'She', ' is', ' spread', ' of', ' late', '\\n', 'Int', 'o', ' a', ' good', 'ly', ' bulk', ':', ' good', ' time', ' encounter', ' her', '!', '\\n', '\\n', 'HER', 'M', 
'ION', 'E', ':', '\\n', 'What', ' wisdom', ' stir', 's', ' amongst', ' you', '?', ' Come', ',', ' sir', ',', ' now', '\\n', 'I', ' am', ' for', ' you', ' again', ':', ' pray', ' you', ',', ' sit', ' by', ' us', ',', '\\n', 'And', ' tell', \" '\", 's', ' a', ' tale', '.', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'M', 'erry', ' or', ' sad', ' shall', \"'t\", ' be', '?', '\\n', '\\n', 'HER', 'M', 'ION', 'E', ':', '\\n', 'As', ' merry', ' as', ' you', ' will', '.', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'A', ' sad', ' tale', \"'s\", ' best', ' for', ' winter', ':', ' I', ' have', ' one', '\\n', 'Of', ' sprites', ' and', ' goblins', '.', '\\n', '\\n', 'HER', 'M', 'ION', 'E', ':', '\\n', 'Let', \"'s\", ' have', ' that', ',', ' good', ' sir', '.', '\\n', 'Come', ' on', ',', ' sit', ' down', ':', ' come', ' on', ',', ' and', ' do', ' your', ' best', '\\n', 'To', ' fright', ' me', ' with', ' your', ' sprites', ';', ' you', \"'re\", ' powerful', ' at', ' it', '.', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'There', ' was', ' a', ' man', '--', '\\n', '\\n', 'HER', 'M', 'ION', 'E', ':', '\\n', 'N', 'ay', ',', ' come', ',', ' sit', ' down', ';', ' then', ' on', '.', '\\n', '\\n', 'M', 'AM', 'ILL', 'I', 'US', ':', '\\n', 'D', 'w', 'elt', ' by', ' a', ' church', 'yard', ':', ' I', ' will', ' tell', ' it', ' softly', ';', '\\n', 'Y', 'ond', ' cr', 'ickets', ' shall', ' not', ' hear', ' it', '.', '\\n', '\\n', 'HER', 'M', 'ION', 'E', ':', '\\n', 'Come', ' on', ',', ' then', ',', '\\n', 'And', ' give', \"'t\", ' me', ' in', ' mine', ' ear', '.', '\\n', '\\n', 'LE', 'ONT', 'ES', ':', '\\n', 'Was', ' he', ' met', ' there', '?', ' his', ' train', '?', ' Cam', 'illo', ' with', ' him', '?', '\\n', '\\n', 'First', ' Lord', ':', '\\n', 'Behind', ' the', ' tu', 'ft', ' of', ' p', 'ines', ' I', ' met', ' them', ';', ' never', '\\n', 'S', 'aw', ' I', ' men', ' sc', 'our', ' so', ' on', ' their', ' way', ':', ' I', ' eyed', ' them', '\\n', 'Even', ' 
to', ' their', ' ships', '.', '\\n', '\\n', 'LE', 'ONT', 'ES', ':', '\\n', 'How', ' bl', 'est', ' am', ' I', '\\n', 'In', ' my', ' just', ' cens', 'ure', ',', ' in', ' my', ' true', ' opinion', '!', '\\n', 'Al', 'ack', ',', ' for', ' lesser', ' knowledge', '!', ' how', ' acc', 'ursed', '\\n', 'In', ' being', ' so', ' bl', 'est', '!', ' There', ' may', ' be', ' in', ' the', ' cup', '\\n', 'A', ' spider', ' steep', \"'d\", ',', ' and', ' one', ' may', ' drink', ',', ' depart', ',', '\\n', 'And', ' yet', ' partake', ' no', ' venom', ',', ' for', ' his', ' knowledge', '\\n', 'Is', ' not', ' infected', ':', ' but', ' if', ' one', ' present', '\\n', 'The', ' abhor', 'r', \"'d\", ' ingredient', ' to', ' his', ' eye', ',', ' make', ' known', '\\n', 'How', ' he', ' hath', ' drunk', ',', ' he', ' cracks', ' his', ' gorge', ',', ' his', ' sides', ',', '\\n', 'With', ' violent', ' he', 'fts', '.', ' I', ' have', ' drunk', ',', '\\n', 'and', ' seen', ' the', ' spider', '.', '\\n', 'Cam', 'illo', ' was', ' his', ' help', ' in', ' this', ',', ' his', ' p', 'ander', ':', '\\n', 'There', ' is', ' a', ' plot', ' against', ' my', ' life', ',', ' my', ' crown', ';', '\\n', 'All', \"'s\", ' true', ' that', ' is', ' mist', 'r', 'usted', ':', ' that', ' false', ' villain', '\\n', 'Wh', 'om', ' I', ' employ', \"'d\", ' was', ' pre', '-', 'employ', \"'d\", ' by', ' him', ':', '\\n', 'He', ' has', ' discover', \"'d\", ' my', ' design', ',', ' and', ' I', '\\n', 'Rem', 'ain', ' a', ' pinch', \"'d\", ' thing', ';', ' yea', ',', ' a', ' very', ' trick', '\\n', 'For', ' them', ' to', ' play', ' at', ' will', '.', ' How', ' came', ' the', ' post', 'ern', 's', '\\n', 'So', ' easily', ' open', '?', '\\n', '\\n', 'First', ' Lord', ':', '\\n', 'By', ' his', ' great', ' authority', ';', '\\n', 'Which', ' often', ' hath', ' no', ' less', ' prevail', \"'d\", ' than', ' so', '\\n', 'On', ' your', ' command', '.', '\\n', '\\n', 'LE', 'ONT', 'ES', ':', '\\n', 'I', ' know', \"'t\", ' too', ' well', '.', 
'\\n', 'Give', ' me', ' the', ' boy', ':', ' I', ' am', ' glad', ' you', ' did', ' not', ' nurse']\n" + ] + } + ], + "source": [ + "input_tokens = x[0, :4].unsqueeze(0).cuda()\n", + "max_new_token = 8\n", + "generated_tokens = model.generate(input_tokens, max_new_token)\n", + "print('input', [enc.decode([i.item()]) for i in input_tokens[0]])\n", + "print('output', [enc.decode([i.item()]) for i in generated_tokens[0]])\n", + "print('Gold label', [enc.decode([i.item()]) for i in x[0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "steps: 0 loss: 11.002180099487305\n", + "steps: 100 loss: 10.177762985229492\n", + "steps: 200 loss: 9.035687446594238\n", + "steps: 300 loss: 7.741647243499756\n", + "steps: 400 loss: 6.681490898132324\n", + "steps: 499 loss: 6.057015895843506\n" + ] + } + ], + "source": [ + "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n", + "batch_size = 32\n", + "context_length = 1024\n", + "iterations = 500\n", + "for steps in range(iterations):\n", + " x, y = get_batch(batch_size, context_length, 'train')\n", + " # print(x[0], y[0])\n", + " x, y = x.cuda(), y.cuda()\n", + " loss = model.loss(x, y)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " if steps % 100 == 0:\n", + " print('steps:', steps, 'loss:', loss.item())\n", + "print('steps:', steps, 'loss:', loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input ['I', ' have', ' a', ' motion']\n", + "output ['I', ' have', ' a', ' motion', ' we', ' Glou', ' depletion', 'Pont', 'ndra', ' vividly', 'object', 'Red']\n", + "Gold label ['I', ' have', ' a', ' motion', ' much', ' imports', ' your', ' good', ';', '\\n', 'Whe', 'ret', 'o', ' if', ' you', \"'ll\", ' a', ' willing', ' ear', ' incl', 'ine', ',', '\\n', 'What', 
\"'s\", ' mine', ' is', ' yours', ' and', ' what', ' is', ' yours', ' is', ' mine', '.', '\\n', 'So', ',', ' bring', ' us', ' to', ' our', ' palace', ';', ' where', ' we', \"'ll\", ' show', '\\n', 'What', \"'s\", ' yet', ' behind', ',', ' that', \"'s\", ' meet', ' you', ' all', ' should', ' know', '.', '\\n', '\\n', 'SL', 'Y', ':', '\\n', 'I', \"'ll\", ' p', 'hee', 'ze', ' you', ',', ' in', ' faith', '.', '\\n', '\\n', 'Host', 'ess', ':', '\\n', 'A', ' pair', ' of', ' stocks', ',', ' you', ' rogue', '!', '\\n', '\\n', 'SL', 'Y', ':', '\\n', 'Ye', ' are', ' a', ' baggage', ':', ' the', ' S', 'lys', ' are', ' no', ' rog', 'ues', ';', ' look', ' in', '\\n', 'the', ' chron', 'icles', ';', ' we', ' came', ' in', ' with', ' Richard', ' Conquer', 'or', '.', '\\n', 'Therefore', ' p', 'aucas', ' pall', 'ab', 'ris', ';', ' let', ' the', ' world', ' slide', ':', ' s', 'essa', '!', '\\n', '\\n', 'Host', 'ess', ':', '\\n', 'You', ' will', ' not', ' pay', ' for', ' the', ' glasses', ' you', ' have', ' burst', '?', '\\n', '\\n', 'SL', 'Y', ':', '\\n', 'No', ',', ' not', ' a', ' den', 'ier', '.', ' Go', ' by', ',', ' Jer', 'on', 'im', 'y', ':', ' go', ' to', ' thy', ' cold', '\\n', 'bed', ',', ' and', ' warm', ' thee', '.', '\\n', '\\n', 'Host', 'ess', ':', '\\n', 'I', ' know', ' my', ' remedy', ';', ' I', ' must', ' go', ' fetch', ' the', '\\n', 'third', '--', 'borough', '.', '\\n', '\\n', 'SL', 'Y', ':', '\\n', 'Third', ',', ' or', ' fourth', ',', ' or', ' fifth', ' borough', ',', ' I', \"'ll\", ' answer', ' him', '\\n', 'by', ' law', ':', ' I', \"'ll\", ' not', ' bud', 'ge', ' an', ' inch', ',', ' boy', ':', ' let', ' him', ' come', ',', '\\n', 'and', ' kindly', '.', '\\n', '\\n', 'Lord', ':', '\\n', 'Hun', 'ts', 'man', ',', ' I', ' charge', ' thee', ',', ' tender', ' well', ' my', ' h', 'ounds', ':', '\\n', 'Br', 'ach', ' Mer', 'rim', 'an', ',', ' the', ' poor', ' cur', ' is', ' emb', 'oss', \"'d\", ';', '\\n', 'And', ' couple', ' Cl', 'owder', ' with', ' the', ' deep', '--', 
'mouth', \"'d\", ' br', 'ach', '.', '\\n', 'S', 'aw', \"'s\", 't', ' thou', ' not', ',', ' boy', ',', ' how', ' Silver', ' made', ' it', ' good', '\\n', 'At', ' the', ' hedge', '-', 'cor', 'ner', ',', ' in', ' the', ' cold', 'est', ' fault', '?', '\\n', 'I', ' would', ' not', ' lose', ' the', ' dog', ' for', ' twenty', ' pound', '.', '\\n', '\\n', 'First', ' Hunts', 'man', ':', '\\n', 'Why', ',', ' Bel', 'man', ' is', ' as', ' good', ' as', ' he', ',', ' my', ' lord', ';', '\\n', 'He', ' cried', ' upon', ' it', ' at', ' the', ' me', 'rest', ' loss', '\\n', 'And', ' twice', ' to', '-', 'day', ' pick', \"'d\", ' out', ' the', ' dull', 'est', ' scent', ':', '\\n', 'Trust', ' me', ',', ' I', ' take', ' him', ' for', ' the', ' better', ' dog', '.', '\\n', '\\n', 'Lord', ':', '\\n', 'Th', 'ou', ' art', ' a', ' fool', ':', ' if', ' Echo', ' were', ' as', ' fleet', ',', '\\n', 'I', ' would', ' esteem', ' him', ' worth', ' a', ' dozen', ' such', '.', '\\n', 'But', ' sup', ' them', ' well', ' and', ' look', ' unto', ' them', ' all', ':', '\\n', 'To', '-', 'morrow', ' I', ' intend', ' to', ' hunt', ' again', '.', '\\n', '\\n', 'First', ' Hunts', 'man', ':', '\\n', 'I', ' will', ',', ' my', ' lord', '.', '\\n', '\\n', 'Lord', ':', '\\n', 'What', \"'s\", ' here', '?', ' one', ' dead', ',', ' or', ' drunk', '?', ' See', ',', ' d', 'oth', ' he', ' breathe', '?', '\\n', '\\n', 'Second', ' Hunts', 'man', ':', '\\n', 'He', ' breat', 'hes', ',', ' my', ' lord', '.', ' Were', ' he', ' not', ' warm', \"'d\", ' with', ' ale', ',', '\\n', 'This', ' were', ' a', ' bed', ' but', ' cold', ' to', ' sleep', ' so', ' sound', 'ly', '.', '\\n', '\\n', 'Lord', ':', '\\n', 'O', ' monstrous', ' beast', '!', ' how', ' like', ' a', ' sw', 'ine', ' he', ' lies', '!', '\\n', 'G', 'rim', ' death', ',', ' how', ' foul', ' and', ' lo', 'ath', 'some', ' is', ' th', 'ine', ' image', '!', '\\n', 'S', 'irs', ',', ' I', ' will', ' practise', ' on', ' this', ' drunken', ' man', '.', '\\n', 'What', ' think', ' 
you', ',', ' if', ' he', ' were', ' convey', \"'d\", ' to', ' bed', ',', '\\n', 'Wra', 'pp', \"'d\", ' in', ' sweet', ' clothes', ',', ' rings', ' put', ' upon', ' his', ' fingers', ',', '\\n', 'A', ' most', ' delicious', ' banquet', ' by', ' his', ' bed', ',', '\\n', 'And', ' brave', ' attendants', ' near', ' him', ' when', ' he', ' wakes', ',', '\\n', 'Would', ' not', ' the', ' begg', 'ar', ' then', ' forget', ' himself', '?', '\\n', '\\n', 'First', ' Hunts', 'man', ':', '\\n', 'Bel', 'ieve', ' me', ',', ' lord', ',', ' I', ' think', ' he', ' cannot', ' choose', '.', '\\n', '\\n', 'Second', ' Hunts', 'man', ':', '\\n', 'It', ' would', ' seem', ' strange', ' unto', ' him', ' when', ' he', ' w', 'aked', '.', '\\n', '\\n', 'Lord', ':', '\\n', 'Even', ' as', ' a', ' flattering', ' dream', ' or', ' worthless', ' fancy', '.', '\\n', 'Then', ' take', ' him', ' up', ' and', ' manage', ' well', ' the', ' j', 'est', ':', '\\n', 'C', 'arry', ' him', ' gently', ' to', ' my', ' faire', 'st', ' chamber', '\\n', 'And', ' hang', ' it', ' round', ' with', ' all', ' my', ' want', 'on', ' pictures', ':', '\\n', 'Bal', 'm', ' his', ' foul', ' head', ' in', ' warm', ' distilled', ' waters', '\\n', 'And', ' burn', ' sweet', ' wood', ' to', ' make', ' the', ' lodging', ' sweet', ':', '\\n', 'Pro', 'c', 'ure', ' me', ' music', ' ready', ' when', ' he', ' wakes', ',', '\\n', 'To', ' make', ' a', ' d', 'ul', 'c', 'et', ' and', ' a', ' heavenly', ' sound', ';', '\\n', 'And', ' if', ' he', ' chance', ' to', ' speak', ',', ' be', ' ready', ' straight', '\\n', 'And', ' with', ' a', ' low', ' sub', 'missive', ' reverence', '\\n', 'Say', \" '\", 'What', ' is', ' it', ' your', ' honour', ' will', ' command', \"?'\", '\\n', 'Let', ' one', ' attend', ' him', ' with', ' a', ' silver', ' basin', '\\n', 'Full', ' of', ' rose', '-', 'water', ' and', ' best', 'rew', \"'d\", ' with', ' flowers', ',', '\\n', 'Another', ' bear', ' the', ' e', 'wer', ',', ' the', ' third', ' a', ' diaper', ',', '\\n', 
'And', ' say', \" '\", 'Will', \"'t\", ' please', ' your', ' lords', 'hip', ' cool', ' your', ' hands', \"?'\", '\\n', 'Some', ' one', ' be', ' ready', ' with', ' a', ' costly', ' suit', '\\n', 'And', ' ask', ' him', ' what', ' apparel', ' he', ' will', ' wear', ';', '\\n', 'Another', ' tell', ' him', ' of', ' his', ' h', 'ounds', ' and', ' horse', ',', '\\n', 'And', ' that', ' his', ' lady', ' mourn', 's', ' at', ' his', ' disease', ':', '\\n', 'Pers', 'u', 'ade', ' him', ' that', ' he', ' hath', ' been', ' lun', 'atic', ';', '\\n', 'And', ' when', ' he', ' says', ' he', ' is', ',', ' say', ' that', ' he', ' dreams', ',', '\\n', 'For', ' he', ' is', ' nothing', ' but', ' a', ' mighty', ' lord', '.', '\\n', 'This', ' do', ' and', ' do', ' it', ' kindly', ',', ' gentle', ' sir', 's', ':', '\\n', 'It', ' will', ' be', ' past', 'ime', ' passing', ' excellent', ',', '\\n', 'If', ' it', ' be', ' husband', 'ed', ' with', ' modesty', '.', '\\n', '\\n', 'First', ' Hunts', 'man', ':', '\\n', 'My', ' lord', ',', ' I', ' warrant', ' you', ' we', ' will', ' play', ' our', ' part', ',', '\\n', 'As', ' he', ' shall', ' think', ' by', ' our', ' true', ' diligence', '\\n', 'He', ' is', ' no', ' less', ' than', ' what', ' we', ' say', ' he', ' is', '.', '\\n', '\\n', 'Lord', ':', '\\n', 'Take', ' him', ' up', ' gently', ' and', ' to', ' bed', ' with', ' him', ';', '\\n', 'And', ' each', ' one', ' to', ' his', ' office', ' when', ' he', ' wakes', '.', '\\n', 'Sir', 'rah', ',', ' go', ' see', ' what', ' trumpet', \" '\", 'tis', ' that', ' sounds', ':', '\\n', 'Bel', 'ike', ',', ' some', ' noble', ' gentleman', ' that', ' means', ',', '\\n', 'T', 'rave', 'lling', ' some']\n" + ] + } + ], + "source": [ + "input_tokens = x[0, :4].unsqueeze(0).cuda()\n", + "max_new_token = 8\n", + "generated_tokens = model.generate(input_tokens, max_new_token)\n", + "print('input', [enc.decode([i.item()]) for i in input_tokens[0]])\n", + "print('output', [enc.decode([i.item()]) for i in 
generated_tokens[0]])\n", + "print('Gold label', [enc.decode([i.item()]) for i in x[0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seen tokens: 16384000\n" + ] + } + ], + "source": [ + "print('seen tokens: ', batch_size * context_length * iterations)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(10.8249, device='cuda:0')\n", + "48\n", + "tensor(4.5572, device='cuda:0')\n" + ] + } + ], + "source": [ + "from ngram import Ngram\n", + "from data import text, enc\n", + "import torch\n", + "vocab = list(range(enc.n_vocab))\n", + "context_lengh = 16\n", + "ngram = Ngram(2, vocab)\n", + "inputs = [enc.encode(text)[:context_lengh]]\n", + "targets = torch.LongTensor([enc.encode(text)[1:context_lengh+1]]).cuda()\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)\n", + "epochs = (batch_size * context_length * iterations) // len(enc.encode(text))\n", + "ngram = Ngram(2, vocab)\n", + "print(epochs)\n", + "for epoch in range(epochs):\n", + " ngram.train(enc.encode(text))\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(10.8249, device='cuda:0')\n", + "tensor(3.4560, device='cuda:0')\n" + ] + } + ], + "source": [ + "ngram = Ngram(2, vocab, 1e-3)\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)\n", + "ngram.train(enc.encode(text))\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(10.8249, device='cuda:0')\n", + "tensor(2.3174, device='cuda:0')\n" + ] + } + ], + "source": [ 
+ "ngram = Ngram(4, vocab, 1e-5)\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)\n", + "ngram.train(enc.encode(text))\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.09853373047652528" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "math.exp(-loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# ngram.ngram" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/playground/GPT DEV_2_e2epipeline_baseline_char.ipynb b/playground/GPT DEV_2_e2epipeline_baseline_char.ipynb new file mode 100644 index 0000000..9f23a55 --- /dev/null +++ b/playground/GPT DEV_2_e2epipeline_baseline_char.ipynb @@ -0,0 +1,935 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "- [A Recipe for Training Neural Networks\n", + "](https://karpathy.github.io/2019/04/25/recipe/)\n", + "- [Harvard CS197 AI Research Experiences](https://docs.google.com/document/d/1uvAbEhbgS_M-uDMTzmOWRlYxqCkogKRXdbKYYT98ooc/edit#heading=h.2z3yllpny6or)\n", + "- [Unit tests for machine learning research](https://semla.polymtl.ca/wp-content/uploads/2022/11/Pablo-Unit-tests-for-ML-code-SEMLA-talk.pdf)\n", + "- [CS 329S: Machine Learning Systems 
Design](https://stanford-cs329s.github.io/syllabus.html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up the end-to-end training/evaluation skeleton + get dumb baselines" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.nn import functional as F\n", + "\n", + "torch.manual_seed(1337)\n", + "\n", + "class BigramLanguageModel(nn.Module):\n", + " def __init__(self, vocab_size):\n", + " super().__init__()\n", + " self.bigram_table = nn.Embedding(vocab_size, vocab_size)\n", + " # self.token_embedding_table = nn.Embedding(vocab_size, 16)\n", + " # self.linear = nn.Linear(16, vocab_size)\n", + " print('number of parameters:', sum(p.numel() for p in self.parameters()))\n", + " \n", + " def forward(self, token_indexes):\n", + " # token_index: (batch_size, sequence_length)\n", + " logits = self.bigram_table(token_indexes)\n", + "\n", + " # embedding = self.token_embedding_table(token_indexes)\n", + " # logits = self.linear(embedding)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " return logits\n", + "\n", + " def loss_per_token(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length),\n", + " reduction='none'\n", + " )\n", + " # loss: (batch_size*sequence_length)\n", + " return loss.view(batch_size, sequence_length)\n", + " \n", + " def loss(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = 
logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length)\n", + " )\n", + " # loss: scalar\n", + " return loss\n", + " \n", + " def generate(self, token_indexes, max_new_tokens):\n", + " # token_indexes: (batch_size, sequence_length)\n", + " batch_size, sequence_length = token_indexes.shape\n", + " for _ in range(max_new_tokens):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " next_token_logits = logits[:, -1, :]\n", + " # next_token_logits: (batch_size, vocab_size)\n", + " next_token_probs = F.softmax(next_token_logits, dim=-1)\n", + " # next_token_probs: (batch_size, vocab_size)\n", + " next_token = torch.multinomial(next_token_probs, num_samples=1)\n", + " # next_token: (batch_size, 1)\n", + " token_indexes = torch.cat([token_indexes, next_token], dim=1)\n", + " # token_indexes: (batch_size, sequence_length+1)\n", + " return token_indexes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "65\n", + "['\\n', ' ', '!', '$', '&', \"'\", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n", + "[20, 43, 50, 50, 53, 1, 51, 63, 1, 52, 39, 51, 43, 1, 47, 57, 1, 49, 43, 52, 53]\n", + "['H', 'e', 'l', 'l', 'o', ' ', 'm', 'y', ' ', 'n', 'a', 'm', 'e', ' ', 'i', 's', ' ', 'k', 'e', 'n', 'o']\n" + ] + } + ], + "source": [ + "from data_char import text, CharTokenizer\n", + "\n", + "tokenizer = CharTokenizer(text)\n", + "print(tokenizer.n_vocab)\n", + "print(tokenizer.vocab)\n", + "print(tokenizer.encode('Hello my name is keno'))\n", + 
"print(tokenizer.decode(tokenizer.encode('Hello my name is keno')))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def rand_int_test(cls, low, high, shape, kwargs):\n", + " layer = cls(**kwargs).cuda()\n", + " random_input = torch.randint(low, high, shape).cuda()\n", + " print('input shape:', random_input.shape)\n", + " output = layer(random_input)\n", + " print('output shape:', output.shape)\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 65536\n", + "input shape: torch.Size([4, 1024])\n", + "output shape: torch.Size([4, 1024, 256])\n" + ] + } + ], + "source": [ + "test_cls = BigramLanguageModel\n", + "batch_size = 4\n", + "context_length = 1024\n", + "vocab_size = 256\n", + "\n", + "kwargs = {'vocab_size': vocab_size}\n", + "output = rand_int_test(test_cls, 0, vocab_size, (batch_size, context_length), kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 4225\n", + "random guess loss: 4.174387269895637\n", + "tensor(4.7202, device='cuda:0', grad_fn=)\n", + "torch.Size([4, 1024]) tensor(4.7202, device='cuda:0', grad_fn=)\n", + "tensor([[4.6262, 5.8574, 5.4585, ..., 5.5064, 5.7386, 3.9034],\n", + " [5.4831, 3.7372, 4.9155, ..., 3.7848, 4.7020, 4.1312],\n", + " [4.7929, 3.7371, 5.3320, ..., 4.7322, 3.9989, 4.3654],\n", + " [4.2844, 5.6883, 4.1599, ..., 5.5120, 5.1844, 3.5611]],\n", + " device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "from data_char import get_batch, enc\n", + "import math\n", + "\n", + "x, y = get_batch(batch_size, context_length, 'train')\n", + "vocab_size = enc.n_vocab\n", + "model = BigramLanguageModel(vocab_size).cuda()\n", + "loss = model.loss(x.cuda(), y.cuda())\n", + 
"print('random guess loss:', -math.log(1/vocab_size))\n", + "print(loss)\n", + "loss_per_token = model.loss_per_token(x.cuda(), y.cuda())\n", + "print(loss_per_token.shape, loss_per_token.mean())\n", + "print(loss_per_token)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input [['y'], [' '], ['s'], ['t']]\n", + "output [['y'], [' '], ['s'], ['t'], ['?'], ['v'], ['d'], ['x'], ['x'], ['b'], ['D'], ['y']]\n", + "Gold label [['y'], [' '], ['s'], ['t'], ['i'], ['n'], ['g'], [' '], ['t'], ['o'], [' '], ['h'], ['u'], ['r'], ['t'], [','], ['\\n'], ['Y'], ['e'], ['t'], [' '], ['l'], ['o'], ['o'], ['k'], [' '], ['t'], ['o'], [' '], ['h'], ['a'], ['v'], ['e'], [' '], ['t'], ['h'], ['e'], ['m'], [' '], ['b'], ['u'], ['z'], ['z'], [' '], ['t'], ['o'], [' '], ['o'], ['f'], ['f'], ['e'], ['n'], ['d'], [' '], ['t'], ['h'], ['i'], ['n'], ['e'], [' '], ['e'], ['a'], ['r'], ['s'], ['.'], ['\\n'], ['F'], ['i'], ['r'], ['s'], ['t'], [' '], ['w'], ['i'], ['l'], ['l'], [' '], ['I'], [' '], ['s'], ['e'], ['e'], [' '], ['t'], ['h'], ['e'], [' '], ['c'], ['o'], ['r'], ['o'], ['n'], ['a'], ['t'], ['i'], ['o'], ['n'], [';'], ['\\n'], ['A'], ['n'], ['d'], [' '], ['t'], ['h'], ['e'], ['n'], [' '], ['t'], ['o'], [' '], ['B'], ['r'], ['i'], ['t'], ['t'], ['a'], ['n'], ['y'], [' '], ['I'], [\"'\"], ['l'], ['l'], [' '], ['c'], ['r'], ['o'], ['s'], ['s'], [' '], ['t'], ['h'], ['e'], [' '], ['s'], ['e'], ['a'], [','], ['\\n'], ['T'], ['o'], [' '], ['e'], ['f'], ['f'], ['e'], ['c'], ['t'], [' '], ['t'], ['h'], ['i'], ['s'], [' '], ['m'], ['a'], ['r'], ['r'], ['i'], ['a'], ['g'], ['e'], [','], [' '], ['s'], ['o'], [' '], ['i'], ['t'], [' '], ['p'], ['l'], ['e'], ['a'], ['s'], ['e'], [' '], ['m'], ['y'], [' '], ['l'], ['o'], ['r'], ['d'], ['.'], ['\\n'], ['\\n'], ['E'], ['D'], ['W'], ['A'], ['R'], ['D'], [':'], ['\\n'], ['E'], ['v'], ['e'], ['n'], [' '], ['a'], ['s'], [' '], ['t'], ['h'], 
['o'], ['u'], [' '], ['w'], ['i'], ['l'], ['t'], [','], [' '], ['s'], ['w'], ['e'], ['e'], ['t'], [' '], ['W'], ['a'], ['r'], ['w'], ['i'], ['c'], ['k'], [','], [' '], ['l'], ['e'], ['t'], [' '], ['i'], ['t'], [' '], ['b'], ['e'], [';'], ['\\n'], ['F'], ['o'], ['r'], [' '], ['i'], ['n'], [' '], ['t'], ['h'], ['y'], [' '], ['s'], ['h'], ['o'], ['u'], ['l'], ['d'], ['e'], ['r'], [' '], ['d'], ['o'], [' '], ['I'], [' '], ['b'], ['u'], ['i'], ['l'], ['d'], [' '], ['m'], ['y'], [' '], ['s'], ['e'], ['a'], ['t'], [','], ['\\n'], ['A'], ['n'], ['d'], [' '], ['n'], ['e'], ['v'], ['e'], ['r'], [' '], ['w'], ['i'], ['l'], ['l'], [' '], ['I'], [' '], ['u'], ['n'], ['d'], ['e'], ['r'], ['t'], ['a'], ['k'], ['e'], [' '], ['t'], ['h'], ['e'], [' '], ['t'], ['h'], ['i'], ['n'], ['g'], ['\\n'], ['W'], ['h'], ['e'], ['r'], ['e'], ['i'], ['n'], [' '], ['t'], ['h'], ['y'], [' '], ['c'], ['o'], ['u'], ['n'], ['s'], ['e'], ['l'], [' '], ['a'], ['n'], ['d'], [' '], ['c'], ['o'], ['n'], ['s'], ['e'], ['n'], ['t'], [' '], ['i'], ['s'], [' '], ['w'], ['a'], ['n'], ['t'], ['i'], ['n'], ['g'], ['.'], ['\\n'], ['R'], ['i'], ['c'], ['h'], ['a'], ['r'], ['d'], [','], [' '], ['I'], [' '], ['w'], ['i'], ['l'], ['l'], [' '], ['c'], ['r'], ['e'], ['a'], ['t'], ['e'], [' '], ['t'], ['h'], ['e'], ['e'], [' '], ['D'], ['u'], ['k'], ['e'], [' '], ['o'], ['f'], [' '], ['G'], ['l'], ['o'], ['u'], ['c'], ['e'], ['s'], ['t'], ['e'], ['r'], [','], ['\\n'], ['A'], ['n'], ['d'], [' '], ['G'], ['e'], ['o'], ['r'], ['g'], ['e'], [','], [' '], ['o'], ['f'], [' '], ['C'], ['l'], ['a'], ['r'], ['e'], ['n'], ['c'], ['e'], [':'], [' '], ['W'], ['a'], ['r'], ['w'], ['i'], ['c'], ['k'], [','], [' '], ['a'], ['s'], [' '], ['o'], ['u'], ['r'], ['s'], ['e'], ['l'], ['f'], [','], ['\\n'], ['S'], ['h'], ['a'], ['l'], ['l'], [' '], ['d'], ['o'], [' '], ['a'], ['n'], ['d'], [' '], ['u'], ['n'], ['d'], ['o'], [' '], ['a'], ['s'], [' '], ['h'], ['i'], ['m'], [' '], ['p'], ['l'], ['e'], ['a'], ['s'], ['e'], ['t'], ['h'], [' '], 
['b'], ['e'], ['s'], ['t'], ['.'], ['\\n'], ['\\n'], ['R'], ['I'], ['C'], ['H'], ['A'], ['R'], ['D'], [':'], ['\\n'], ['L'], ['e'], ['t'], [' '], ['m'], ['e'], [' '], ['b'], ['e'], [' '], ['D'], ['u'], ['k'], ['e'], [' '], ['o'], ['f'], [' '], ['C'], ['l'], ['a'], ['r'], ['e'], ['n'], ['c'], ['e'], [','], [' '], ['G'], ['e'], ['o'], ['r'], ['g'], ['e'], [' '], ['o'], ['f'], [' '], ['G'], ['l'], ['o'], ['u'], ['c'], ['e'], ['s'], ['t'], ['e'], ['r'], [';'], ['\\n'], ['F'], ['o'], ['r'], [' '], ['G'], ['l'], ['o'], ['u'], ['c'], ['e'], ['s'], ['t'], ['e'], ['r'], [\"'\"], ['s'], [' '], ['d'], ['u'], ['k'], ['e'], ['d'], ['o'], ['m'], [' '], ['i'], ['s'], [' '], ['t'], ['o'], ['o'], [' '], ['o'], ['m'], ['i'], ['n'], ['o'], ['u'], ['s'], ['.'], ['\\n'], ['\\n'], ['W'], ['A'], ['R'], ['W'], ['I'], ['C'], ['K'], [':'], ['\\n'], ['T'], ['u'], ['t'], [','], [' '], ['t'], ['h'], ['a'], ['t'], [\"'\"], ['s'], [' '], ['a'], [' '], ['f'], ['o'], ['o'], ['l'], ['i'], ['s'], ['h'], [' '], ['o'], ['b'], ['s'], ['e'], ['r'], ['v'], ['a'], ['t'], ['i'], ['o'], ['n'], [':'], ['\\n'], ['R'], ['i'], ['c'], ['h'], ['a'], ['r'], ['d'], [','], [' '], ['b'], ['e'], [' '], ['D'], ['u'], ['k'], ['e'], [' '], ['o'], ['f'], [' '], ['G'], ['l'], ['o'], ['u'], ['c'], ['e'], ['s'], ['t'], ['e'], ['r'], ['.'], [' '], ['N'], ['o'], ['w'], [' '], ['t'], ['o'], [' '], ['L'], ['o'], ['n'], ['d'], ['o'], ['n'], [','], ['\\n'], ['T'], ['o'], [' '], ['s'], ['e'], ['e'], [' '], ['t'], ['h'], ['e'], ['s'], ['e'], [' '], ['h'], ['o'], ['n'], ['o'], ['u'], ['r'], ['s'], [' '], ['i'], ['n'], [' '], ['p'], ['o'], ['s'], ['s'], ['e'], ['s'], ['s'], ['i'], ['o'], ['n'], ['.'], ['\\n'], ['3'], [' '], ['K'], ['I'], ['N'], ['G'], [' '], ['H'], ['E'], ['N'], ['R'], ['Y'], [' '], ['V'], ['I'], ['\\n'], ['\\n'], ['F'], ['i'], ['r'], ['s'], ['t'], [' '], ['K'], ['e'], ['e'], ['p'], ['e'], ['r'], [':'], ['\\n'], ['U'], ['n'], ['d'], ['e'], ['r'], [' '], ['t'], ['h'], ['i'], ['s'], [' '], ['t'], ['h'], ['i'], ['c'], 
['k'], ['-'], ['g'], ['r'], ['o'], ['w'], ['n'], [' '], ['b'], ['r'], ['a'], ['k'], ['e'], [' '], ['w'], ['e'], [\"'\"], ['l'], ['l'], [' '], ['s'], ['h'], ['r'], ['o'], ['u'], ['d'], [' '], ['o'], ['u'], ['r'], ['s'], ['e'], ['l'], ['v'], ['e'], ['s'], [';'], ['\\n'], ['F'], ['o'], ['r'], [' '], ['t'], ['h'], ['r'], ['o'], ['u'], ['g'], ['h'], [' '], ['t'], ['h'], ['i'], ['s'], [' '], ['l'], ['a'], ['u'], ['n'], ['d'], [' '], ['a'], ['n'], ['o'], ['n'], [' '], ['t'], ['h'], ['e'], [' '], ['d'], ['e'], ['e'], ['r'], [' '], ['w'], ['i'], ['l'], ['l'], [' '], ['c'], ['o'], ['m'], ['e'], [';'], ['\\n'], ['A'], ['n'], ['d'], [' '], ['i'], ['n'], [' '], ['t'], ['h'], ['i'], ['s'], [' '], ['c'], ['o'], ['v'], ['e'], ['r'], ['t'], [' '], ['w'], ['i'], ['l'], ['l'], [' '], ['w'], ['e'], [' '], ['m'], ['a'], ['k'], ['e'], [' '], ['o'], ['u'], ['r'], [' '], ['s'], ['t'], ['a'], ['n'], ['d'], [','], ['\\n'], ['C'], ['u'], ['l'], ['l'], ['i'], ['n'], ['g'], [' '], ['t'], ['h'], ['e'], [' '], ['p'], ['r'], ['i'], ['n'], ['c'], ['i'], ['p'], ['a'], ['l'], [' '], ['o'], ['f'], [' '], ['a'], ['l'], ['l'], [' '], ['t'], ['h'], ['e'], [' '], ['d'], ['e'], ['e'], ['r'], ['.'], ['\\n'], ['\\n'], ['S'], ['e'], ['c'], ['o'], ['n'], ['d'], [' '], ['K'], ['e'], ['e'], ['p'], ['e'], ['r'], [':'], ['\\n'], ['I'], [\"'\"], ['l'], ['l'], [' '], ['s'], ['t'], ['a'], ['y'], [' '], ['a'], ['b'], ['o'], ['v'], ['e'], [' '], ['t'], ['h'], ['e'], [' '], ['h'], ['i'], ['l'], ['l'], [','], [' '], ['s'], ['o'], [' '], ['b'], ['o'], ['t'], ['h'], [' '], ['m'], ['a'], ['y'], [' '], ['s'], ['h'], ['o'], ['o'], ['t'], ['.'], ['\\n'], ['\\n'], ['F'], ['i'], ['r'], ['s'], ['t'], [' '], ['K'], ['e'], ['e'], ['p'], ['e'], ['r'], [':'], ['\\n'], ['T'], ['h'], ['a'], ['t'], [' '], ['c'], ['a'], ['n'], ['n']]\n" + ] + } + ], + "source": [ + "input_tokens = x[0, :4].unsqueeze(0).cuda()\n", + "max_new_token = 8\n", + "generated_tokens = model.generate(input_tokens, max_new_token)\n", + "print('input', 
[enc.decode([i.item()]) for i in input_tokens[0]])\n", + "print('output', [enc.decode([i.item()]) for i in generated_tokens[0]])\n", + "print('Gold label', [enc.decode([i.item()]) for i in x[0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "steps: 0 loss: 4.718399524688721\n", + "steps: 100 loss: 4.5681352615356445\n", + "steps: 200 loss: 4.431594371795654\n", + "steps: 300 loss: 4.299322605133057\n", + "steps: 400 loss: 4.155538558959961\n", + "steps: 500 loss: 4.031984329223633\n", + "steps: 600 loss: 3.921776533126831\n", + "steps: 700 loss: 3.828066825866699\n", + "steps: 800 loss: 3.71785569190979\n", + "steps: 900 loss: 3.628220558166504\n", + "steps: 999 loss: 3.539994239807129\n" + ] + } + ], + "source": [ + "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n", + "batch_size = 32\n", + "context_length = 1024\n", + "iterations = 1000\n", + "for steps in range(iterations):\n", + " x, y = get_batch(batch_size, context_length, 'train')\n", + " # print(x[0], y[0])\n", + " x, y = x.cuda(), y.cuda()\n", + " loss = model.loss(x, y)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " if steps % 100 == 0:\n", + " print('steps:', steps, 'loss:', loss.item())\n", + "print('steps:', steps, 'loss:', loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input [['e'], ['s'], [' '], ['i']]\n", + "output [['e'], ['s'], [' '], ['i'], ['R'], ['Z'], [\"'\"], ['\\n'], ['d'], ['c'], [\"'\"], ['?']]\n", + "Gold label [['e'], ['s'], [' '], ['i'], ['n'], [' '], ['a'], [' '], ['m'], ['i'], ['l'], ['e'], ['-'], ['a'], ['.'], ['\\n'], ['\\n'], ['F'], ['L'], ['O'], ['R'], ['I'], ['Z'], ['E'], ['L'], [':'], ['\\n'], ['T'], ['h'], ['e'], ['s'], ['e'], [' '], ['y'], ['o'], ['u'], ['r'], [' '], ['u'], ['n'], ['u'], 
['s'], ['u'], ['a'], ['l'], [' '], ['w'], ['e'], ['e'], ['d'], ['s'], [' '], ['t'], ['o'], [' '], ['e'], ['a'], ['c'], ['h'], [' '], ['p'], ['a'], ['r'], ['t'], [' '], ['o'], ['f'], [' '], ['y'], ['o'], ['u'], ['\\n'], ['D'], ['o'], [' '], ['g'], ['i'], ['v'], ['e'], [' '], ['a'], [' '], ['l'], ['i'], ['f'], ['e'], [':'], [' '], ['n'], ['o'], [' '], ['s'], ['h'], ['e'], ['p'], ['h'], ['e'], ['r'], ['d'], ['e'], ['s'], ['s'], [','], [' '], ['b'], ['u'], ['t'], [' '], ['F'], ['l'], ['o'], ['r'], ['a'], ['\\n'], ['P'], ['e'], ['e'], ['r'], ['i'], ['n'], ['g'], [' '], ['i'], ['n'], [' '], ['A'], ['p'], ['r'], ['i'], ['l'], [\"'\"], ['s'], [' '], ['f'], ['r'], ['o'], ['n'], ['t'], ['.'], [' '], ['T'], ['h'], ['i'], ['s'], [' '], ['y'], ['o'], ['u'], ['r'], [' '], ['s'], ['h'], ['e'], ['e'], ['p'], ['-'], ['s'], ['h'], ['e'], ['a'], ['r'], ['i'], ['n'], ['g'], ['\\n'], ['I'], ['s'], [' '], ['a'], ['s'], [' '], ['a'], [' '], ['m'], ['e'], ['e'], ['t'], ['i'], ['n'], ['g'], [' '], ['o'], ['f'], [' '], ['t'], ['h'], ['e'], [' '], ['p'], ['e'], ['t'], ['t'], ['y'], [' '], ['g'], ['o'], ['d'], ['s'], [','], ['\\n'], ['A'], ['n'], ['d'], [' '], ['y'], ['o'], ['u'], [' '], ['t'], ['h'], ['e'], [' '], ['q'], ['u'], ['e'], ['e'], ['n'], [' '], ['o'], ['n'], [\"'\"], ['t'], ['.'], ['\\n'], ['\\n'], ['P'], ['E'], ['R'], ['D'], ['I'], ['T'], ['A'], [':'], ['\\n'], ['S'], ['i'], ['r'], [','], [' '], ['m'], ['y'], [' '], ['g'], ['r'], ['a'], ['c'], ['i'], ['o'], ['u'], ['s'], [' '], ['l'], ['o'], ['r'], ['d'], [','], ['\\n'], ['T'], ['o'], [' '], ['c'], ['h'], ['i'], ['d'], ['e'], [' '], ['a'], ['t'], [' '], ['y'], ['o'], ['u'], ['r'], [' '], ['e'], ['x'], ['t'], ['r'], ['e'], ['m'], ['e'], ['s'], [' '], ['i'], ['t'], [' '], ['n'], ['o'], ['t'], [' '], ['b'], ['e'], ['c'], ['o'], ['m'], ['e'], ['s'], [' '], ['m'], ['e'], [':'], ['\\n'], ['O'], [','], [' '], ['p'], ['a'], ['r'], ['d'], ['o'], ['n'], [','], [' '], ['t'], ['h'], ['a'], ['t'], [' '], ['I'], [' '], ['n'], ['a'], ['m'], 
['e'], [' '], ['t'], ['h'], ['e'], ['m'], ['!'], [' '], ['Y'], ['o'], ['u'], ['r'], [' '], ['h'], ['i'], ['g'], ['h'], [' '], ['s'], ['e'], ['l'], ['f'], [','], ['\\n'], ['T'], ['h'], ['e'], [' '], ['g'], ['r'], ['a'], ['c'], ['i'], ['o'], ['u'], ['s'], [' '], ['m'], ['a'], ['r'], ['k'], [' '], ['o'], [\"'\"], [' '], ['t'], ['h'], ['e'], [' '], ['l'], ['a'], ['n'], ['d'], [','], [' '], ['y'], ['o'], ['u'], [' '], ['h'], ['a'], ['v'], ['e'], [' '], ['o'], ['b'], ['s'], ['c'], ['u'], ['r'], ['e'], ['d'], ['\\n'], ['W'], ['i'], ['t'], ['h'], [' '], ['a'], [' '], ['s'], ['w'], ['a'], ['i'], ['n'], [\"'\"], ['s'], [' '], ['w'], ['e'], ['a'], ['r'], ['i'], ['n'], ['g'], [','], [' '], ['a'], ['n'], ['d'], [' '], ['m'], ['e'], [','], [' '], ['p'], ['o'], ['o'], ['r'], [' '], ['l'], ['o'], ['w'], ['l'], ['y'], [' '], ['m'], ['a'], ['i'], ['d'], [','], ['\\n'], ['M'], ['o'], ['s'], ['t'], [' '], ['g'], ['o'], ['d'], ['d'], ['e'], ['s'], ['s'], ['-'], ['l'], ['i'], ['k'], ['e'], [' '], ['p'], ['r'], ['a'], ['n'], ['k'], [\"'\"], ['d'], [' '], ['u'], ['p'], [':'], [' '], ['b'], ['u'], ['t'], [' '], ['t'], ['h'], ['a'], ['t'], [' '], ['o'], ['u'], ['r'], [' '], ['f'], ['e'], ['a'], ['s'], ['t'], ['s'], ['\\n'], ['I'], ['n'], [' '], ['e'], ['v'], ['e'], ['r'], ['y'], [' '], ['m'], ['e'], ['s'], ['s'], [' '], ['h'], ['a'], ['v'], ['e'], [' '], ['f'], ['o'], ['l'], ['l'], ['y'], [' '], ['a'], ['n'], ['d'], [' '], ['t'], ['h'], ['e'], [' '], ['f'], ['e'], ['e'], ['d'], ['e'], ['r'], ['s'], ['\\n'], ['D'], ['i'], ['g'], ['e'], ['s'], ['t'], [' '], ['i'], ['t'], [' '], ['w'], ['i'], ['t'], ['h'], [' '], ['a'], [' '], ['c'], ['u'], ['s'], ['t'], ['o'], ['m'], [','], [' '], ['I'], [' '], ['s'], ['h'], ['o'], ['u'], ['l'], ['d'], [' '], ['b'], ['l'], ['u'], ['s'], ['h'], ['\\n'], ['T'], ['o'], [' '], ['s'], ['e'], ['e'], [' '], ['y'], ['o'], ['u'], [' '], ['s'], ['o'], [' '], ['a'], ['t'], ['t'], ['i'], ['r'], ['e'], ['d'], [','], [' '], ['s'], ['w'], ['o'], ['r'], ['n'], [','], [' '], 
['I'], [' '], ['t'], ['h'], ['i'], ['n'], ['k'], [','], ['\\n'], ['T'], ['o'], [' '], ['s'], ['h'], ['o'], ['w'], [' '], ['m'], ['y'], ['s'], ['e'], ['l'], ['f'], [' '], ['a'], [' '], ['g'], ['l'], ['a'], ['s'], ['s'], ['.'], ['\\n'], ['\\n'], ['F'], ['L'], ['O'], ['R'], ['I'], ['Z'], ['E'], ['L'], [':'], ['\\n'], ['I'], [' '], ['b'], ['l'], ['e'], ['s'], ['s'], [' '], ['t'], ['h'], ['e'], [' '], ['t'], ['i'], ['m'], ['e'], ['\\n'], ['W'], ['h'], ['e'], ['n'], [' '], ['m'], ['y'], [' '], ['g'], ['o'], ['o'], ['d'], [' '], ['f'], ['a'], ['l'], ['c'], ['o'], ['n'], [' '], ['m'], ['a'], ['d'], ['e'], [' '], ['h'], ['e'], ['r'], [' '], ['f'], ['l'], ['i'], ['g'], ['h'], ['t'], [' '], ['a'], ['c'], ['r'], ['o'], ['s'], ['s'], ['\\n'], ['T'], ['h'], ['y'], [' '], ['f'], ['a'], ['t'], ['h'], ['e'], ['r'], [\"'\"], ['s'], [' '], ['g'], ['r'], ['o'], ['u'], ['n'], ['d'], ['.'], ['\\n'], ['\\n'], ['P'], ['E'], ['R'], ['D'], ['I'], ['T'], ['A'], [':'], ['\\n'], ['N'], ['o'], ['w'], [' '], ['J'], ['o'], ['v'], ['e'], [' '], ['a'], ['f'], ['f'], ['o'], ['r'], ['d'], [' '], ['y'], ['o'], ['u'], [' '], ['c'], ['a'], ['u'], ['s'], ['e'], ['!'], ['\\n'], ['T'], ['o'], [' '], ['m'], ['e'], [' '], ['t'], ['h'], ['e'], [' '], ['d'], ['i'], ['f'], ['f'], ['e'], ['r'], ['e'], ['n'], ['c'], ['e'], [' '], ['f'], ['o'], ['r'], ['g'], ['e'], ['s'], [' '], ['d'], ['r'], ['e'], ['a'], ['d'], [';'], [' '], ['y'], ['o'], ['u'], ['r'], [' '], ['g'], ['r'], ['e'], ['a'], ['t'], ['n'], ['e'], ['s'], ['s'], ['\\n'], ['H'], ['a'], ['t'], ['h'], [' '], ['n'], ['o'], ['t'], [' '], ['b'], ['e'], ['e'], ['n'], [' '], ['u'], ['s'], ['e'], ['d'], [' '], ['t'], ['o'], [' '], ['f'], ['e'], ['a'], ['r'], ['.'], [' '], ['E'], ['v'], ['e'], ['n'], [' '], ['n'], ['o'], ['w'], [' '], ['I'], [' '], ['t'], ['r'], ['e'], ['m'], ['b'], ['l'], ['e'], ['\\n'], ['T'], ['o'], [' '], ['t'], ['h'], ['i'], ['n'], ['k'], [' '], ['y'], ['o'], ['u'], ['r'], [' '], ['f'], ['a'], ['t'], ['h'], ['e'], ['r'], [','], [' '], ['b'], 
['y'], [' '], ['s'], ['o'], ['m'], ['e'], [' '], ['a'], ['c'], ['c'], ['i'], ['d'], ['e'], ['n'], ['t'], [','], ['\\n'], ['S'], ['h'], ['o'], ['u'], ['l'], ['d'], [' '], ['p'], ['a'], ['s'], ['s'], [' '], ['t'], ['h'], ['i'], ['s'], [' '], ['w'], ['a'], ['y'], [' '], ['a'], ['s'], [' '], ['y'], ['o'], ['u'], [' '], ['d'], ['i'], ['d'], [':'], [' '], ['O'], [','], [' '], ['t'], ['h'], ['e'], [' '], ['F'], ['a'], ['t'], ['e'], ['s'], ['!'], ['\\n'], ['H'], ['o'], ['w'], [' '], ['w'], ['o'], ['u'], ['l'], ['d'], [' '], ['h'], ['e'], [' '], ['l'], ['o'], ['o'], ['k'], [','], [' '], ['t'], ['o'], [' '], ['s'], ['e'], ['e'], [' '], ['h'], ['i'], ['s'], [' '], ['w'], ['o'], ['r'], ['k'], [' '], ['s'], ['o'], [' '], ['n'], ['o'], ['b'], ['l'], ['e'], ['\\n'], ['V'], ['i'], ['l'], ['e'], ['l'], ['y'], [' '], ['b'], ['o'], ['u'], ['n'], ['d'], [' '], ['u'], ['p'], ['?'], [' '], ['W'], ['h'], ['a'], ['t'], [' '], ['w'], ['o'], ['u'], ['l'], ['d'], [' ']]\n" + ] + } + ], + "source": [ + "input_tokens = x[0, :4].unsqueeze(0).cuda()\n", + "max_new_token = 8\n", + "generated_tokens = model.generate(input_tokens, max_new_token)\n", + "print('input', [enc.decode([i.item()]) for i in input_tokens[0]])\n", + "print('output', [enc.decode([i.item()]) for i in generated_tokens[0]])\n", + "print('Gold label', [enc.decode([i.item()]) for i in x[0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seen tokens: 32768000\n" + ] + } + ], + "source": [ + "print('seen tokens: ', batch_size * context_length * iterations)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(4.1744, device='cuda:0')\n", + "29\n", + "after 1 epoch: tensor(2.6940, device='cuda:0')\n" + ] + } + ], + "source": [ + "from ngram import Ngram\n", + "from data_char import text, enc\n", + "import 
torch\n", + "vocab = list(range(enc.n_vocab))\n", + "context_lengh = 16\n", + "ngram = Ngram(2, vocab)\n", + "inputs = [enc.encode(text)[:context_lengh]]\n", + "targets = torch.LongTensor([enc.encode(text)[1:context_lengh+1]]).cuda()\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)\n", + "epochs = (batch_size * context_length * iterations) // len(enc.encode(text))\n", + "print(epochs)\n", + "ngram.train(enc.encode(text))\n", + "loss = ngram.loss(inputs, targets)\n", + "print('after 1 epoch:', loss)\n", + "# for epoch in range(epochs-1):\n", + "# ngram.train(enc.encode(text))\n", + "# loss = ngram.loss(inputs, targets)\n", + "# print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1115394" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(text)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'First Citi'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n", + "{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5}\n", + "{0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f'}\n", + "tensor(4.1744, device='cuda:0')\n", + "tensor(2.9666, device='cuda:0')\n" + ] + } + ], + "source": [ + "train_text = 'abcdefabcdedfabcdedf'\n", + "check_enc = CharTokenizer(train_text)\n", + "print(check_enc.n_vocab)\n", + "print(check_enc.encoder)\n", + "print(check_enc.decoder)\n", + "ngram = Ngram(2, list(range(enc.n_vocab)))\n", + "loss = ngram.loss([check_enc.encode(train_text)[:-1]], torch.LongTensor([check_enc.encode(train_text)[1:]]).cuda())\n", + "print(loss)\n", + "ngram.train(check_enc.encode(train_text))\n", + 
"loss = ngram.loss([check_enc.encode(train_text)[:-1]], torch.LongTensor([check_enc.encode(train_text)[1:]]).cuda())\n", + "print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 3, 5, 0, 1, 2, 3, 4, 3, 5]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "check_enc.encode(train_text)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "defaultdict(.()>,\n", + " {'0-0': 1,\n", + " '0-1': 4,\n", + " '0-2': 1,\n", + " '0-3': 1,\n", + " '0-4': 1,\n", + " '0-5': 1,\n", + " '0-6': 1,\n", + " '0-7': 1,\n", + " '0-8': 1,\n", + " '0-9': 1,\n", + " '0-10': 1,\n", + " '0-11': 1,\n", + " '0-12': 1,\n", + " '0-13': 1,\n", + " '0-14': 1,\n", + " '0-15': 1,\n", + " '0-16': 1,\n", + " '0-17': 1,\n", + " '0-18': 1,\n", + " '0-19': 1,\n", + " '0-20': 1,\n", + " '0-21': 1,\n", + " '0-22': 1,\n", + " '0-23': 1,\n", + " '0-24': 1,\n", + " '0-25': 1,\n", + " '0-26': 1,\n", + " '0-27': 1,\n", + " '0-28': 1,\n", + " '0-29': 1,\n", + " '0-30': 1,\n", + " '0-31': 1,\n", + " '0-32': 1,\n", + " '0-33': 1,\n", + " '0-34': 1,\n", + " '0-35': 1,\n", + " '0-36': 1,\n", + " '0-37': 1,\n", + " '0-38': 1,\n", + " '0-39': 1,\n", + " '0-40': 1,\n", + " '0-41': 1,\n", + " '0-42': 1,\n", + " '0-43': 1,\n", + " '0-44': 1,\n", + " '0-45': 1,\n", + " '0-46': 1,\n", + " '0-47': 1,\n", + " '0-48': 1,\n", + " '0-49': 1,\n", + " '0-50': 1,\n", + " '0-51': 1,\n", + " '0-52': 1,\n", + " '0-53': 1,\n", + " '0-54': 1,\n", + " '0-55': 1,\n", + " '0-56': 1,\n", + " '0-57': 1,\n", + " '0-58': 1,\n", + " '0-59': 1,\n", + " '0-60': 1,\n", + " '0-61': 1,\n", + " '0-62': 1,\n", + " '0-63': 1,\n", + " '0-64': 1,\n", + " '1-0': 1,\n", + " '1-1': 1,\n", + " '1-2': 4,\n", + " '1-3': 1,\n", + " '1-4': 1,\n", + " '1-5': 1,\n", + " '1-6': 1,\n", + " '1-7': 1,\n", + " 
'1-8': 1,\n", + " '1-9': 1,\n", + " '1-10': 1,\n", + " '1-11': 1,\n", + " '1-12': 1,\n", + " '1-13': 1,\n", + " '1-14': 1,\n", + " '1-15': 1,\n", + " '1-16': 1,\n", + " '1-17': 1,\n", + " '1-18': 1,\n", + " '1-19': 1,\n", + " '1-20': 1,\n", + " '1-21': 1,\n", + " '1-22': 1,\n", + " '1-23': 1,\n", + " '1-24': 1,\n", + " '1-25': 1,\n", + " '1-26': 1,\n", + " '1-27': 1,\n", + " '1-28': 1,\n", + " '1-29': 1,\n", + " '1-30': 1,\n", + " '1-31': 1,\n", + " '1-32': 1,\n", + " '1-33': 1,\n", + " '1-34': 1,\n", + " '1-35': 1,\n", + " '1-36': 1,\n", + " '1-37': 1,\n", + " '1-38': 1,\n", + " '1-39': 1,\n", + " '1-40': 1,\n", + " '1-41': 1,\n", + " '1-42': 1,\n", + " '1-43': 1,\n", + " '1-44': 1,\n", + " '1-45': 1,\n", + " '1-46': 1,\n", + " '1-47': 1,\n", + " '1-48': 1,\n", + " '1-49': 1,\n", + " '1-50': 1,\n", + " '1-51': 1,\n", + " '1-52': 1,\n", + " '1-53': 1,\n", + " '1-54': 1,\n", + " '1-55': 1,\n", + " '1-56': 1,\n", + " '1-57': 1,\n", + " '1-58': 1,\n", + " '1-59': 1,\n", + " '1-60': 1,\n", + " '1-61': 1,\n", + " '1-62': 1,\n", + " '1-63': 1,\n", + " '1-64': 1,\n", + " '2-0': 1,\n", + " '2-1': 1,\n", + " '2-2': 1,\n", + " '2-3': 4,\n", + " '2-4': 1,\n", + " '2-5': 1,\n", + " '2-6': 1,\n", + " '2-7': 1,\n", + " '2-8': 1,\n", + " '2-9': 1,\n", + " '2-10': 1,\n", + " '2-11': 1,\n", + " '2-12': 1,\n", + " '2-13': 1,\n", + " '2-14': 1,\n", + " '2-15': 1,\n", + " '2-16': 1,\n", + " '2-17': 1,\n", + " '2-18': 1,\n", + " '2-19': 1,\n", + " '2-20': 1,\n", + " '2-21': 1,\n", + " '2-22': 1,\n", + " '2-23': 1,\n", + " '2-24': 1,\n", + " '2-25': 1,\n", + " '2-26': 1,\n", + " '2-27': 1,\n", + " '2-28': 1,\n", + " '2-29': 1,\n", + " '2-30': 1,\n", + " '2-31': 1,\n", + " '2-32': 1,\n", + " '2-33': 1,\n", + " '2-34': 1,\n", + " '2-35': 1,\n", + " '2-36': 1,\n", + " '2-37': 1,\n", + " '2-38': 1,\n", + " '2-39': 1,\n", + " '2-40': 1,\n", + " '2-41': 1,\n", + " '2-42': 1,\n", + " '2-43': 1,\n", + " '2-44': 1,\n", + " '2-45': 1,\n", + " '2-46': 1,\n", + " '2-47': 1,\n", + " '2-48': 1,\n", + 
" '2-49': 1,\n", + " '2-50': 1,\n", + " '2-51': 1,\n", + " '2-52': 1,\n", + " '2-53': 1,\n", + " '2-54': 1,\n", + " '2-55': 1,\n", + " '2-56': 1,\n", + " '2-57': 1,\n", + " '2-58': 1,\n", + " '2-59': 1,\n", + " '2-60': 1,\n", + " '2-61': 1,\n", + " '2-62': 1,\n", + " '2-63': 1,\n", + " '2-64': 1,\n", + " '3-0': 1,\n", + " '3-1': 1,\n", + " '3-2': 1,\n", + " '3-3': 1,\n", + " '3-4': 4,\n", + " '3-5': 3,\n", + " '3-6': 1,\n", + " '3-7': 1,\n", + " '3-8': 1,\n", + " '3-9': 1,\n", + " '3-10': 1,\n", + " '3-11': 1,\n", + " '3-12': 1,\n", + " '3-13': 1,\n", + " '3-14': 1,\n", + " '3-15': 1,\n", + " '3-16': 1,\n", + " '3-17': 1,\n", + " '3-18': 1,\n", + " '3-19': 1,\n", + " '3-20': 1,\n", + " '3-21': 1,\n", + " '3-22': 1,\n", + " '3-23': 1,\n", + " '3-24': 1,\n", + " '3-25': 1,\n", + " '3-26': 1,\n", + " '3-27': 1,\n", + " '3-28': 1,\n", + " '3-29': 1,\n", + " '3-30': 1,\n", + " '3-31': 1,\n", + " '3-32': 1,\n", + " '3-33': 1,\n", + " '3-34': 1,\n", + " '3-35': 1,\n", + " '3-36': 1,\n", + " '3-37': 1,\n", + " '3-38': 1,\n", + " '3-39': 1,\n", + " '3-40': 1,\n", + " '3-41': 1,\n", + " '3-42': 1,\n", + " '3-43': 1,\n", + " '3-44': 1,\n", + " '3-45': 1,\n", + " '3-46': 1,\n", + " '3-47': 1,\n", + " '3-48': 1,\n", + " '3-49': 1,\n", + " '3-50': 1,\n", + " '3-51': 1,\n", + " '3-52': 1,\n", + " '3-53': 1,\n", + " '3-54': 1,\n", + " '3-55': 1,\n", + " '3-56': 1,\n", + " '3-57': 1,\n", + " '3-58': 1,\n", + " '3-59': 1,\n", + " '3-60': 1,\n", + " '3-61': 1,\n", + " '3-62': 1,\n", + " '3-63': 1,\n", + " '3-64': 1,\n", + " '4-0': 1,\n", + " '4-1': 1,\n", + " '4-2': 1,\n", + " '4-3': 3,\n", + " '4-4': 1,\n", + " '4-5': 2,\n", + " '4-6': 1,\n", + " '4-7': 1,\n", + " '4-8': 1,\n", + " '4-9': 1,\n", + " '4-10': 1,\n", + " '4-11': 1,\n", + " '4-12': 1,\n", + " '4-13': 1,\n", + " '4-14': 1,\n", + " '4-15': 1,\n", + " '4-16': 1,\n", + " '4-17': 1,\n", + " '4-18': 1,\n", + " '4-19': 1,\n", + " '4-20': 1,\n", + " '4-21': 1,\n", + " '4-22': 1,\n", + " '4-23': 1,\n", + " '4-24': 1,\n", + " 
'4-25': 1,\n", + " '4-26': 1,\n", + " '4-27': 1,\n", + " '4-28': 1,\n", + " '4-29': 1,\n", + " '4-30': 1,\n", + " '4-31': 1,\n", + " '4-32': 1,\n", + " '4-33': 1,\n", + " '4-34': 1,\n", + " '4-35': 1,\n", + " '4-36': 1,\n", + " '4-37': 1,\n", + " '4-38': 1,\n", + " '4-39': 1,\n", + " '4-40': 1,\n", + " '4-41': 1,\n", + " '4-42': 1,\n", + " '4-43': 1,\n", + " '4-44': 1,\n", + " '4-45': 1,\n", + " '4-46': 1,\n", + " '4-47': 1,\n", + " '4-48': 1,\n", + " '4-49': 1,\n", + " '4-50': 1,\n", + " '4-51': 1,\n", + " '4-52': 1,\n", + " '4-53': 1,\n", + " '4-54': 1,\n", + " '4-55': 1,\n", + " '4-56': 1,\n", + " '4-57': 1,\n", + " '4-58': 1,\n", + " '4-59': 1,\n", + " '4-60': 1,\n", + " '4-61': 1,\n", + " '4-62': 1,\n", + " '4-63': 1,\n", + " '4-64': 1,\n", + " '5-0': 3,\n", + " '5-1': 1,\n", + " '5-2': 1,\n", + " '5-3': 1,\n", + " '5-4': 1,\n", + " '5-5': 1,\n", + " '5-6': 1,\n", + " '5-7': 1,\n", + " '5-8': 1,\n", + " '5-9': 1,\n", + " '5-10': 1,\n", + " '5-11': 1,\n", + " '5-12': 1,\n", + " '5-13': 1,\n", + " '5-14': 1,\n", + " '5-15': 1,\n", + " '5-16': 1,\n", + " '5-17': 1,\n", + " '5-18': 1,\n", + " '5-19': 1,\n", + " '5-20': 1,\n", + " '5-21': 1,\n", + " '5-22': 1,\n", + " '5-23': 1,\n", + " '5-24': 1,\n", + " '5-25': 1,\n", + " '5-26': 1,\n", + " '5-27': 1,\n", + " '5-28': 1,\n", + " '5-29': 1,\n", + " '5-30': 1,\n", + " '5-31': 1,\n", + " '5-32': 1,\n", + " '5-33': 1,\n", + " '5-34': 1,\n", + " '5-35': 1,\n", + " '5-36': 1,\n", + " '5-37': 1,\n", + " '5-38': 1,\n", + " '5-39': 1,\n", + " '5-40': 1,\n", + " '5-41': 1,\n", + " '5-42': 1,\n", + " '5-43': 1,\n", + " '5-44': 1,\n", + " '5-45': 1,\n", + " '5-46': 1,\n", + " '5-47': 1,\n", + " '5-48': 1,\n", + " '5-49': 1,\n", + " '5-50': 1,\n", + " '5-51': 1,\n", + " '5-52': 1,\n", + " '5-53': 1,\n", + " '5-54': 1,\n", + " '5-55': 1,\n", + " '5-56': 1,\n", + " '5-57': 1,\n", + " '5-58': 1,\n", + " '5-59': 1,\n", + " '5-60': 1,\n", + " '5-61': 1,\n", + " '5-62': 1,\n", + " '5-63': 1,\n", + " '5-64': 1})" + ] + }, + 
"execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ngram.ngram" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(4.1744, device='cuda:0')\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(2.6793, device='cuda:0')\n" + ] + } + ], + "source": [ + "ngram = Ngram(2, vocab, 1e-3)\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)\n", + "ngram.train(enc.encode(text))\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(4.1744, device='cuda:0')\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(1.2322, device='cuda:0')\n" + ] + } + ], + "source": [ + "ngram = Ngram(4, vocab, 1e-3)\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)\n", + "ngram.train(enc.encode(text))\n", + "loss = ngram.loss(inputs, targets)\n", + "print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# ngram.ngram" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/playground/GPT DEV_3_attention_check_overfit.ipynb b/playground/GPT DEV_3_attention_check_overfit.ipynb new file mode 100644 index 0000000..c633d3b --- /dev/null 
+++ b/playground/GPT DEV_3_attention_check_overfit.ipynb @@ -0,0 +1,989 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[15., 57., 92.],\n", + " [ 0., 95., 53.],\n", + " [15., 10., 34.],\n", + " [90., 12., 20.]],\n", + "\n", + " [[97., 86., 90.],\n", + " [38., 51., 64.],\n", + " [ 9., 15., 13.],\n", + " [46., 22., 50.]]])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "torch.manual_seed(1337)\n", + "\n", + "batch_size = 2\n", + "sequence_length = 4\n", + "d_model = 3\n", + "representations = torch.randint(0, 100, (batch_size, sequence_length, d_model)).float()\n", + "representations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generalize a special case\n", + "> This is a bit more of a general coding tip but I’ve often seen people create bugs when they bite off more than they can chew, writing a relatively general functionality from scratch. I like to write a very specific function to what I’m doing right now, get that to work, and then generalize it later making sure that I get the same result. 
Often this applies to vectorizing code, where I almost always write out the fully loopy version first and only then transform it to vectorized code one loop at a time.\n", + "\n", + "https://karpathy.github.io/2019/04/25/recipe/" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[15.0000, 57.0000, 92.0000],\n", + " [ 7.5000, 76.0000, 72.5000],\n", + " [10.0000, 54.0000, 59.6667],\n", + " [30.0000, 43.5000, 49.7500]],\n", + "\n", + " [[97.0000, 86.0000, 90.0000],\n", + " [67.5000, 68.5000, 77.0000],\n", + " [48.0000, 50.6667, 55.6667],\n", + " [47.5000, 43.5000, 54.2500]]])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "aggregated_representations = torch.empty((batch_size, sequence_length, d_model))\n", + "\n", + "for batch_idx in range(batch_size):\n", + " for sequence_idx in range(sequence_length):\n", + " aggregated_representations[batch_idx, sequence_idx] = torch.mean(representations[batch_idx, :sequence_idx+1], dim=0)\n", + "aggregated_representations" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n", + " [0.5000, 0.5000, 0.0000, 0.0000],\n", + " [0.3333, 0.3333, 0.3333, 0.0000],\n", + " [0.2500, 0.2500, 0.2500, 0.2500]],\n", + "\n", + " [[1.0000, 0.0000, 0.0000, 0.0000],\n", + " [0.5000, 0.5000, 0.0000, 0.0000],\n", + " [0.3333, 0.3333, 0.3333, 0.0000],\n", + " [0.2500, 0.2500, 0.2500, 0.2500]]])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_score = torch.tril(torch.ones((batch_size, sequence_length, sequence_length)))\n", + "attention_score = attention_score / torch.sum(attention_score, dim=2, keepdim=True)\n", + "attention_score" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": 
{}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[15.0000, 57.0000, 92.0000],\n", + " [ 7.5000, 76.0000, 72.5000],\n", + " [10.0000, 54.0000, 59.6667],\n", + " [30.0000, 43.5000, 49.7500]],\n", + "\n", + " [[97.0000, 86.0000, 90.0000],\n", + " [67.5000, 68.5000, 77.0000],\n", + " [48.0000, 50.6667, 55.6667],\n", + " [47.5000, 43.5000, 54.2500]]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "attention_score @ representations" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-0.0476, -inf, -inf, -inf],\n", + " [-1.1081, -1.8002, -inf, -inf],\n", + " [ 0.1662, 1.2055, 0.1883, -inf],\n", + " [-0.1585, -0.6300, -0.2221, 0.6924]],\n", + "\n", + " [[ 1.1490, -inf, -inf, -inf],\n", + " [ 0.1526, 0.3843, -inf, -inf],\n", + " [-0.7296, -1.5580, -0.3950, -inf],\n", + " [-1.7097, -0.0826, -0.0495, -1.4480]]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tril = torch.tril(torch.ones((sequence_length, sequence_length)))\n", + "qk_dot_product = torch.randn((batch_size, sequence_length, sequence_length))\n", + "qk_dot_product = qk_dot_product.masked_fill(tril == 0, float('-inf'))\n", + "qk_dot_product" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n", + " [0.6664, 0.3336, 0.0000, 0.0000],\n", + " [0.2062, 0.5830, 0.2108, 0.0000],\n", + " [0.2039, 0.1273, 0.1913, 0.4775]],\n", + "\n", + " [[1.0000, 0.0000, 0.0000, 0.0000],\n", + " [0.4424, 0.5576, 0.0000, 0.0000],\n", + " [0.3528, 0.1541, 0.4931, 0.0000],\n", + " [0.0791, 0.4024, 0.4159, 0.1027]]])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.softmax(qk_dot_product, dim=2)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "d_head = 16\n", + "k = torch.randn(batch_size, sequence_length, d_head)\n", + "q = torch.randn(batch_size, sequence_length, d_head)\n", + "v = torch.randn(batch_size, sequence_length, d_head)\n", + "\n", + "qk_dot_product = q @ k.transpose(-2, -1)\n", + "scaled_qk_dot_product = qk_dot_product / (d_head ** 0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.8750)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "k.var()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.1169)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "q.var()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor(13.3547), tensor(0.8347))" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qk_dot_product.var(), scaled_qk_dot_product.var()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])* 8, dim=-1)" + 
] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.nn import functional as F\n", + "\n", + "torch.manual_seed(1337)\n", + "\n", + "class Head(nn.Module):\n", + " def __init__(self, d_model, d_head):\n", + " super().__init__()\n", + " self.key = nn.Linear(d_model, d_head, bias=False)\n", + " self.query = nn.Linear(d_model, d_head, bias=False)\n", + " self.value = nn.Linear(d_model, d_head, bias=False)\n", + " self.back_to_d_model = nn.Linear(d_head, d_model)\n", + "\n", + " self.register_buffer('mask', torch.tril(torch.ones((sequence_length, sequence_length))))\n", + " \n", + " \n", + " def forward(self, x):\n", + " # x: (batch_size, sequence_length, d_model)\n", + " k = self.key(x)\n", + " q = self.query(x)\n", + " v = self.value(x)\n", + "\n", + " qk_dot_product = q @ k.transpose(-2, -1) / (d_head ** 0.5)\n", + " qk_dot_product = qk_dot_product.masked_fill(self.mask == 0, float('-inf'))\n", + " attention_score = torch.softmax(qk_dot_product, dim=-1)\n", + " out = attention_score @ v\n", + " out = self.back_to_d_model(out)\n", + " return out\n", + "\n", + "\n", + "class AttentionLM(nn.Module):\n", + " def __init__(self, vocab_size, sequence_length, d_model, d_head):\n", + " super().__init__()\n", + " self.embed = nn.Embedding(vocab_size, d_model)\n", + " self.pos_embed = nn.Embedding(sequence_length, d_model)\n", + " self.head = Head(d_model, d_head)\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.unembed = nn.Linear(d_model, vocab_size)\n", + " print('number of parameters:', sum(p.numel() for p in self.parameters()))\n", + " \n", + " \n", + " def forward(self, token_indexes):\n", + " # token_indexes: (batch_size, sequence_length)\n", + " batch_size, sequence_length = token_indexes.size()\n", + " token_embed = self.embed(token_indexes)\n", + " pos_embed = self.pos_embed(torch.arange(sequence_length).to(token_embed.device))\n", + " x = 
token_embed + pos_embed\n", + " x = self.head(x)\n", + " x = self.ln(x)\n", + " logits = self.unembed(x)\n", + "\n", + " return logits\n", + " \n", + " def loss_per_token(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length),\n", + " reduction='none'\n", + " )\n", + " # loss: (batch_size*sequence_length)\n", + " return loss.view(batch_size, sequence_length)\n", + " \n", + " def loss(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length)\n", + " )\n", + " # loss: scalar\n", + " return loss\n", + " \n", + " def generate(self, token_indexes, max_new_tokens):\n", + " # token_indexes: (batch_size, sequence_length)\n", + " batch_size, sequence_length = token_indexes.shape\n", + " for _ in range(max_new_tokens):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " next_token_logits = logits[:, -1, :]\n", + " # next_token_logits: (batch_size, vocab_size)\n", + " next_token_probs = F.softmax(next_token_logits, dim=-1)\n", + " # next_token_probs: (batch_size, vocab_size)\n", + " next_token = torch.multinomial(next_token_probs, num_samples=1)\n", + " # next_token: (batch_size, 1)\n", + " token_indexes = torch.cat([token_indexes, next_token], dim=1)\n", + " # token_indexes: (batch_size, sequence_length+1)\n", + " return token_indexes\n", + "\n", + "\n", + "class 
BigramLanguageModel(nn.Module):\n", + " def __init__(self, vocab_size, d_model):\n", + " super().__init__()\n", + " # self.bigram_table = nn.Embedding(vocab_size, vocab_size)\n", + " self.token_embedding_table = nn.Embedding(vocab_size, d_model)\n", + " self.ln = nn.LayerNorm(d_model)\n", + " self.linear = nn.Linear(d_model, vocab_size)\n", + " print('number of parameters:', sum(p.numel() for p in self.parameters()))\n", + " \n", + " def forward(self, token_indexes):\n", + " # token_index: (batch_size, sequence_length)\n", + " # logits = self.bigram_table(token_indexes)\n", + "\n", + " embedding = self.token_embedding_table(token_indexes)\n", + " embedding = self.ln(embedding)\n", + " logits = self.linear(embedding)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " return logits\n", + "\n", + " def loss_per_token(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length),\n", + " reduction='none'\n", + " )\n", + " # loss: (batch_size*sequence_length)\n", + " return loss.view(batch_size, sequence_length)\n", + " \n", + " \n", + " def loss(self, token_indexes, targets):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " # targets: (batch_size, sequence_length)\n", + " batch_size, sequence_length, vocab_size = logits.shape\n", + " loss = F.cross_entropy(\n", + " logits.view(batch_size*sequence_length, vocab_size),\n", + " targets.view(batch_size*sequence_length)\n", + " )\n", + " # loss: scalar\n", + " return loss\n", + " \n", + " def generate(self, token_indexes, max_new_tokens):\n", + " # token_indexes: (batch_size, sequence_length)\n", + " batch_size, sequence_length = 
token_indexes.shape\n", + " for _ in range(max_new_tokens):\n", + " logits = self(token_indexes)\n", + " # logits: (batch_size, sequence_length, vocab_size)\n", + " next_token_logits = logits[:, -1, :]\n", + " # next_token_logits: (batch_size, vocab_size)\n", + " next_token_probs = F.softmax(next_token_logits, dim=-1)\n", + " # next_token_probs: (batch_size, vocab_size)\n", + " next_token = torch.multinomial(next_token_probs, num_samples=1)\n", + " # next_token: (batch_size, 1)\n", + " token_indexes = torch.cat([token_indexes, next_token], dim=1)\n", + " # token_indexes: (batch_size, sequence_length+1)\n", + " return token_indexes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 1085249\n", + "steps: 0 loss: 4.328359603881836\n", + "steps: 100 loss: 1.9020140171051025\n", + "steps: 200 loss: 0.9388523697853088\n", + "steps: 299 loss: 0.45989322662353516\n", + "validation loss: 8.844975471496582\n" + ] + } + ], + "source": [ + "from data_char import enc, get_batch\n", + "vocab_size = enc.n_vocab\n", + "sequence_length = 1024\n", + "d_model = 768\n", + "d_head = 64\n", + "\n", + "model = AttentionLM(vocab_size, sequence_length, d_model, d_head).cuda()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n", + "batch_size = 32\n", + "context_length = 1024\n", + "iterations = 300\n", + "x, y = get_batch(batch_size, context_length, 'train')\n", + "for steps in range(iterations):\n", + " # print(x[0], y[0])\n", + " x, y = x.cuda(), y.cuda()\n", + " loss = model.loss(x, y)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " if steps % 100 == 0:\n", + " print('steps:', steps, 'loss:', loss.item())\n", + " # break\n", + "print('steps:', steps, 'loss:', loss.item())\n", + "\n", + "with torch.no_grad():\n", + " val_x, val_y = get_batch(1, context_length, 'val')\n", + " val_x, val_y = 
val_x.cuda(), val_y.cuda()\n", + " loss = model.loss(val_x, val_y)\n", + " print('validation loss:', loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'loss per token')" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from matplotlib import pyplot as plt\n", + "loss_per_token = model.loss_per_token(x,y)\n", + "loss = model.loss(x,y)\n", + "loss_per_token.mean().item(), loss.item()\n", + "# print(loss_per_token.shape)\n", + "\n", + "# plot by points\n", + "plt.plot(loss_per_token.mean(dim=0).detach().cpu().numpy())\n", + "# x-axis 0 ~ 1024\n", + "plt.xlim(0, 1024)\n", + "plt.ylim(0, 5)\n", + "plt.xlabel('token index')\n", + "plt.ylabel('loss')\n", + "plt.title('loss per token')" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of parameters: 101441\n", + "steps: 0 loss: 4.318515300750732\n", + "steps: 100 loss: 2.431689500808716\n", + "steps: 200 loss: 2.423933744430542\n", + "steps: 299 loss: 2.4220685958862305\n", + "validation loss: 2.4799563884735107\n" + ] + } + ], + "source": [ + "model = BigramLanguageModel(vocab_size, d_model).cuda()\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n", + "batch_size = 32\n", + "context_length = 1024\n", + "iterations = 300\n", + "x, y = get_batch(batch_size, context_length, 'train')\n", + "\n", + "for steps in range(iterations):\n", + " # print(x[0], y[0])\n", + " x, y = x.cuda(), y.cuda()\n", + " loss = model.loss(x, y)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " if steps % 100 == 0:\n", + " print('steps:', steps, 'loss:', loss.item())\n", + "print('steps:', steps, 'loss:', loss.item())\n", + "with torch.no_grad():\n", + " val_x, val_y = get_batch(1, context_length, 'val')\n", + " val_x, val_y = val_x.cuda(), val_y.cuda()\n", + " loss = model.loss(val_x, val_y)\n", + " print('validation loss:', loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'loss per token')" + ] + 
}, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Args:
        d_model: Size of the last dimension of the input.
        d_head: Output size of the key, query and value projections.
    """

    def __init__(self, d_model, d_head):
        super().__init__()
        self.key = nn.Linear(d_model, d_head, bias=False)
        self.query = nn.Linear(d_model, d_head, bias=False)
        self.value = nn.Linear(d_model, d_head, bias=False)
        # The causal mask is built in forward() from the actual input length,
        # so the module no longer depends on a global `sequence_length`
        # existing at construction time.

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, d_model); return (B, T, d_head)."""
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, d_head)
        q = self.query(x)  # (B, T, d_head)
        v = self.value(x)  # (B, T, d_head)
        # Scaled dot-product attention divides by sqrt of the key width
        # (d_head), not the width of the input (d_model).
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # Position t may only attend to positions <= t.
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        wei = wei.masked_fill(~causal, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ v  # (B, T, T) @ (B, T, d_head) -> (B, T, d_head)


class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, concatenated and projected.

    Args:
        num_heads: Number of ``Head`` modules.
        d_model: Input and output width.
        d_head: Width of each head.
    """

    def __init__(self, num_heads, d_model, d_head):
        super().__init__()
        self.heads = nn.ModuleList([Head(d_model, d_head) for _ in range(num_heads)])
        # The concatenated heads are num_heads * d_head wide, which is only
        # equal to d_model when d_head == d_model // num_heads.
        self.proj = nn.Linear(num_heads * d_head, d_model)

    def forward(self, x):
        """Map (B, T, d_model) to (B, T, d_model)."""
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.proj(concatenated)


class FeedFoward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear with a 4x hidden width."""

    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then the feed-forward MLP.

    Args:
        d_model: Embedding width.
        d_head: Width of each attention head. ``None`` selects
            ``d_model // num_heads``. The value used to be overwritten
            unconditionally; it is now honoured.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, d_head, num_heads):
        super().__init__()
        if d_head is None:
            d_head = d_model // num_heads
        self.sa = MultiHeadAttention(num_heads, d_model, d_head)
        self.ffwd = FeedFoward(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Residual connections around both sub-layers.
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class AttentionLM(nn.Module):
    """Decoder-only language model built from a stack of ``Block`` modules.

    Args:
        vocab_size: Number of distinct token ids.
        sequence_length: Number of rows in the position-embedding table, i.e.
            the longest input ``forward`` accepts.
        d_model: Embedding width.
        d_head: Width of each attention head.
        num_heads: Attention heads per block.
        n_layer: Number of blocks.
    """

    def __init__(self, vocab_size, sequence_length, d_model, d_head, num_heads, n_layer):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, d_model)
        self.position_embedding_table = nn.Embedding(sequence_length, d_model)
        self.blocks = nn.Sequential(
            *[Block(d_model, d_head=d_head, num_heads=num_heads) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(d_model)  # final layer norm
        self.lm_head = nn.Linear(d_model, vocab_size)
        print('number of parameters:', sum(p.numel() for p in self.parameters()))

    def forward(self, idx):
        """Return logits of shape (B, T, vocab_size) for token ids ``idx`` (B, T).

        Raises:
            ValueError: If T exceeds the position-embedding table size.
        """
        _, T = idx.shape
        max_context = self.position_embedding_table.num_embeddings
        if T > max_context:
            # Fail with a readable message instead of an out-of-range
            # embedding lookup.
            raise ValueError(
                f"sequence length {T} exceeds the maximum context {max_context}"
            )
        tok_emb = self.token_embedding_table(idx)  # (B, T, d_model)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=tok_emb.device)
        )  # (T, d_model)
        x = tok_emb + pos_emb   # broadcast over the batch dimension
        x = self.blocks(x)      # (B, T, d_model)
        x = self.ln_f(x)        # (B, T, d_model)
        return self.lm_head(x)  # (B, T, vocab_size)

    def _token_losses(self, token_indexes, targets, reduction):
        """Cross-entropy between the logits for ``token_indexes`` and ``targets``."""
        logits = self(token_indexes)
        vocab_size = logits.shape[-1]
        # reshape (rather than view) also accepts non-contiguous targets.
        return F.cross_entropy(
            logits.reshape(-1, vocab_size),
            targets.reshape(-1),
            reduction=reduction,
        )

    def loss(self, token_indexes, targets):
        """Return the scalar mean cross-entropy over every position."""
        return self._token_losses(token_indexes, targets, 'mean')

    def generate(self, token_indexes, max_new_tokens):
        """Sample ``max_new_tokens`` tokens and append them to ``token_indexes``.

        Args:
            token_indexes: (B, T) tensor of prompt token ids.
            max_new_tokens: Number of tokens to sample.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        max_context = self.position_embedding_table.num_embeddings
        with torch.no_grad():  # sampling never needs gradients
            for _ in range(max_new_tokens):
                # Feed only the most recent tokens that fit in the position
                # table; without the crop, generation fails as soon as the
                # running sequence grows past `sequence_length`.
                logits = self(token_indexes[:, -max_context:])
                next_token_logits = logits[:, -1, :]                     # (B, vocab_size)
                next_token_probs = F.softmax(next_token_logits, dim=-1)  # (B, vocab_size)
                next_token = torch.multinomial(next_token_probs, num_samples=1)  # (B, 1)
                token_indexes = torch.cat([token_indexes, next_token], dim=1)
        return token_indexes

    def loss_per_token(self, token_indexes, targets):
        """Return the unreduced cross-entropy with shape (B, T)."""
        losses = self._token_losses(token_indexes, targets, 'none')
        return losses.view(token_indexes.shape)
from data_char import enc, get_batch

# Overfit a single batch with AttentionLM, then report the loss on one
# sequence drawn from the 'val' split.

vocab_size = enc.n_vocab
sequence_length = 1024
d_model = 768
d_head = 64
n_layer = 2
num_heads = 12

# Fall back to the CPU when no GPU is available instead of crashing on .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AttentionLM(vocab_size, sequence_length, d_model, d_head, num_heads, n_layer).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
batch_size = 32
context_length = 1024
iterations = 1000

# One batch is drawn up front and reused for every step (that is the point of
# the overfitting check), so it is moved to the device once, not per step.
x, y = get_batch(batch_size, context_length, 'train')
x, y = x.to(device), y.to(device)

for steps in range(iterations):
    loss = model.loss(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Log every 100 steps and always on the final step. Folding the final
    # report into the loop avoids a duplicated line when the last step is a
    # multiple of 100 and a NameError when `iterations` is 0.
    if steps % 100 == 0 or steps == iterations - 1:
        print('steps:', steps, 'loss:', loss.item())

# No gradients are needed for evaluation.
with torch.no_grad():
    val_x, val_y = get_batch(1, context_length, 'val')
    val_x, val_y = val_x.to(device), val_y.to(device)
    loss = model.loss(val_x, val_y)
    print('validation loss:', loss.item())
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from matplotlib import pyplot as plt\n", + "loss_per_token = model.loss_per_token(x,y)\n", + "loss = model.loss(x,y)\n", + "loss_per_token.mean().item(), loss.item()\n", + "# print(loss_per_token.shape)\n", + "\n", + "plt.plot(loss_per_token.mean(dim=0).detach().cpu().numpy())\n", + "plt.xlim(0, 1024)\n", + "plt.ylim(-0.1, 5)\n", + "plt.xlabel('token index')\n", + "plt.ylabel('loss')\n", + "plt.title('loss per token')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References\n", + "- https://www.youtube.com/watch?v=kCc8FmEb1nY \n", + "- https://github.com/karpathy/ng-video-lecture \n", + "- https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/playground/data.py b/playground/data.py new file mode 100644 index 0000000..d6b43ab --- /dev/null +++ b/playground/data.py @@ -0,0 +1,21 @@ +import torch +import tiktoken + +with open('input.txt', 'r', encoding='utf-8') as f: + text = f.read() + +enc = tiktoken.get_encoding("gpt2") +seed = 1337 +torch.manual_seed(seed) +data = torch.tensor(enc.encode(text), dtype=torch.long) +n = int(0.9*len(data)) # first 90% will be train, rest val +train_data = data[:n] +val_data = data[n:] + + +def get_batch(batch_size, context_length, split='train'): + data = train_data if split == 'train' else val_data + index = torch.randint(len(data) - context_length, (batch_size,)) + x = 
from collections import defaultdict
import torch
from torch.nn import functional as F
import math
import copy


class Ngram:
    """Count-based n-gram language model with additive (Laplace) smoothing.

    N-grams and their (n-1)-token contexts are stored under string keys made
    by joining the stringified tokens with ``'-'``.

    Args:
        n: Order of the model (1 = unigram, 2 = bigram, ...).
        vocab: Sized, re-iterable collection of every token the model can
            emit.
        laplace: Pseudo-count given to every n-gram.
    """

    def __init__(self, n, vocab, laplace=1):
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        self.vocab = vocab
        self.laplace = laplace
        # Unseen n-grams start from the pseudo-count, unseen contexts from the
        # pseudo-count of all their possible continuations.
        self.ngram = defaultdict(lambda: laplace)
        self.context_count = defaultdict(lambda: laplace * len(self.vocab))

    def train(self, token_list):
        """Add the n-gram counts of one token sequence to the model.

        Args:
            token_list: Sequence of tokens (a list or a 1-D tensor).
        """
        if isinstance(token_list, torch.Tensor):
            token_list = token_list.tolist()
        # Stringify once; slicing already copies, so no deepcopy is needed.
        tokens = [str(token) for token in token_list]
        for start in range(len(tokens) - self.n + 1):
            window = tokens[start:start + self.n]
            self.ngram['-'.join(window)] += 1
            self.context_count['-'.join(window[:-1])] += 1

    def train_batch(self, token_list):
        """Call ``train`` on every sequence in ``token_list``."""
        for tokens in token_list:
            self.train(tokens)

    def get_prob(self, ngram):
        """Return the smoothed probability of the last token given its context.

        Args:
            ngram: ``'-'``-joined key, e.g. ``'3-7'`` for a bigram.

        Returns:
            ``count(ngram) / count(context)`` with the pseudo-counts included.
            A unigram key has the empty context ``''``, whose count is the
            total number of trained tokens, so no special case is needed (the
            former ``n == 1`` branch divided by the vocabulary size, which
            does not sum to one).
        """
        context = ngram.rpartition('-')[0]
        # .get() reads the same defaults as the defaultdicts would supply, but
        # without inserting a key on every lookup.
        numerator = self.ngram.get(ngram, self.laplace)
        denominator = self.context_count.get(context, self.laplace * len(self.vocab))
        if denominator == 0:
            # Only reachable with laplace == 0 and an unseen context: nothing
            # is known, so fall back to a uniform distribution.
            return 1 / len(self.vocab)
        return numerator / denominator

    def get_prob_distribution(self, n_minus_1_gram):
        """Return the next-token distribution for a context.

        Args:
            n_minus_1_gram: Context tokens; any iterable (list, tuple, ...).

        Returns:
            ``(distribution, distribution_dict)``: the probabilities as a list
            in ``vocab`` iteration order, and the same values keyed by token.
        """
        prefix = [str(token) for token in n_minus_1_gram]
        distribution = [
            self.get_prob('-'.join(prefix + [str(word)])) for word in self.vocab
        ]
        distribution_dict = dict(zip(self.vocab, distribution))
        return distribution, distribution_dict

    def forward(self, token_indexes):
        """Return next-token distributions for every position of a batch.

        Args:
            token_indexes: (batch_size, sequence_length) tensor or nested
                list of tokens.

        Returns:
            Float tensor of shape (batch_size, sequence_length, vocab_size).
            The last axis follows ``vocab`` iteration order; NOTE(review):
            ``loss`` indexes it with target ids, which presumes ``vocab``
            iterates as 0..V-1 - confirm against callers.
        """
        if isinstance(token_indexes, torch.Tensor):
            token_indexes = token_indexes.tolist()
        batch_size = len(token_indexes)
        sequence_length = len(token_indexes[0]) if batch_size else 0
        vocab_size = len(self.vocab)
        distributions = torch.ones(batch_size, sequence_length, vocab_size)
        distributions /= vocab_size
        # Counts do not change during a forward pass, so each distinct
        # context only needs to be evaluated once.
        cache = {}
        for batch, row in enumerate(token_indexes):
            for i in range(sequence_length):
                # Up to n-1 most recent tokens, fewer near the sequence start.
                context = tuple(row[max(0, i - self.n + 2):i + 1])
                if context not in cache:
                    cache[context] = torch.tensor(self.get_prob_distribution(context)[0])
                distributions[batch, i] = cache[context]
        return distributions

    def loss(self, token_indexes, targets):
        """Return the mean negative log-likelihood of ``targets``.

        Args:
            token_indexes: (batch_size, sequence_length) inputs for ``forward``.
            targets: (batch_size, sequence_length) tensor of target ids.
        """
        distributions = self.forward(token_indexes).to(targets.device)
        log_distributions = torch.log(distributions)
        vocab_size = log_distributions.shape[-1]
        return F.nll_loss(
            log_distributions.reshape(-1, vocab_size),
            targets.reshape(-1),
        )


def _demo():
    """Train a bigram model on one toy sentence and print its losses."""
    vocab_str = ["I", "am", "an", "NLPer", "a", "student", "in", "Tokyo", "University"]
    tokenizer = {word: index for index, word in enumerate(vocab_str)}
    text = "I am an NLPer"
    words = text.split()
    words_token = [tokenizer[word] for word in words]

    ngram = Ngram(2, tokenizer.values())
    x = [[tokenizer[word] for word in words[:-1]]]
    y = torch.tensor([[tokenizer[word] for word in words[1:]]])
    print(x, y)
    # Untrained model: every distribution is uniform.
    loss = ngram.loss(x, y)
    print(loss)

    ngram.train(words_token)
    distribution, distribution_dict = ngram.get_prob_distribution((tokenizer["I"],))
    print(distribution)
    print(distribution_dict)
    loss = ngram.loss(x, y)
    print(loss)

    # A much smaller pseudo-count lets the observed bigrams dominate.
    ngram = Ngram(2, tokenizer.values(), 1e-5)
    ngram.train(words_token)
    loss = ngram.loss(x, y)
    print(loss)


if __name__ == "__main__":
    _demo()