diff --git a/precondition/datamix_gemma/Pretokenizing_wikipedia.ipynb b/precondition/datamix_gemma/Pretokenizing_wikipedia.ipynb new file mode 100644 index 0000000..6cf2943 --- /dev/null +++ b/precondition/datamix_gemma/Pretokenizing_wikipedia.ipynb @@ -0,0 +1,409 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ybS4pSQElymu" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}\n", + "# weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", + "ckpt_path = '/g_mini/2b_it_v1p1_orbax/1'\n", + "vocab_path = 'home/mriviere/g_mini/tokenizer/gemini_bpe_256k_v5_no_tags_cleared_v1.model'" + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Python imports\n", + "\n", + "import enum\n", + "import re\n", + "import string\n", + "\n", + "# We import JAX and some related packages.\n", + "import chex\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax\n", + "\n", + "# We will use tensorflow to handle the dataset\n", + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "\n", + "# Finally, we import Gemma.\n", + "from colabtools import adhoc_import\n", + "from gemma import params as params_lib\n", + "from gemma import sampler as sampler_lib\n", + "from gemma import transformer as transformer_lib\n", + "from sentencepiece.src.python import sentencepiece_processor as spm" + ], + "metadata": { + "id": "ac9DJWz9m4AP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "vocab = spm.SentencePieceProcessor()\n", + "vocab.Load(vocab_path)\n", + "\n", + "vocab_list = [(id, vocab.IdToPiece(id)) for id in range(vocab.GetPieceSize())]\n", + "letters = ['A', 'B', 'C', 'D']\n", + "res_dict = {}\n", + "for id, piece in vocab_list:\n", + " try:\n", + " letter = piece[piece.find(next(filter(str.isalpha, piece)))]\n", + " if letter in letters:\n", + " res_dict[id] = letter\n", + " except:\n", + " pass\n" + ], + "metadata": { + "id": "gRpGqxBTm7O0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class GemmaTokenizer:\n", + " \"\"\"Custom wrapper around a SentencePieceProcessor for tensorflow.\"\"\"\n", + "\n", + " def __init__(self,\n", + " spm_processor: spm.SentencePieceProcessor):\n", + " self._spm_processor = spm_processor\n", + "\n", + " @property\n", + " def pad_id(self) -> int:\n", + " \"\"\"Fast access to the pad id.\"\"\"\n", + " return self._spm_processor.pad_id()\n", + "\n", + " def tokenize(self,\n", + " example: str | bytes,\n", + " prefix: str = '',\n", + " suffix: str = '',\n", + " add_eos: bool = True) -> jax.Array:\n", + " \"\"\"\n", + " Tokenization function.\n", + "\n", + " Args:\n", + " example: input string to tokenize.\n", + " prefix: prefix to add to the input string.\n", + " suffix: suffix to add to the input string.\n", + " add_eos: if True, add an end of sentence token at the end of the output\n", + " sequence.\n", + " Returns:\n", + " Tokens corresponding to the input string.\n", + " \"\"\"\n", + " int_list = [self._spm_processor.bos_id()]\n", + " int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))\n", + " if add_eos:\n", + " int_list.append(self._spm_processor.eos_id())\n", + "\n", + " return jnp.array(int_list, dtype=jnp.int32)\n", + "\n", + " def tokenize_tf_op(self,\n", + " str_tensor: tf.Tensor,\n", + " prefix: str = '',\n", + " suffix: str = '',\n", + " add_eos: bool = True) -> tf.Tensor:\n", + " \"\"\"Tensforflow operator for the tokenize function.\"\"\"\n", + " encoded = tf.numpy_function(\n", + " self.tokenize,\n", + " [str_tensor, prefix, suffix, add_eos],\n", + " tf.int32)\n", + " encoded.set_shape([None])\n", + " return encoded\n", + "\n", + " def to_string(self, tokens: jax.Array) -> str:\n", + " \"\"\"Convert an array of tokens to a string.\"\"\"\n", + " return self._spm_processor.EncodeIds(tokens.tolist())" + ], + "metadata": { + "id": "avEjIH3bn4hX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "tokenizer = GemmaTokenizer(vocab)\n" + ], + "metadata": { + "id": "Bz1RPkw9oBmr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\"\"\"Base class for dataset builders.\"\"\"\n", + "\n", + "import chex\n", + "import jax\n", + "import tensorflow as tf\n", + "\n", + "@chex.dataclass(frozen=True)\n", + "class TrainingInput:\n", + " # Input tokens provided to model\n", + " input_tokens: jax.Array\n", + "\n", + " # A mask that determines which tokens contribute to the target loss\n", + " # calculation\n", + " target_mask: jax.Array\n", + "\n", + "\n", + "class DatasetBuilder:\n", + " \"\"\"Base class for dataset builders.\n", + "\n", + " This class provides the interface for dataset builders.\n", + " \"\"\"\n", + "\n", + " def __init__(self, tokenizer: GemmaTokenizer,\n", + " max_seq_len: int):\n", + " \"\"\"Constructor.\n", + "\n", + " Args:\n", + " tokenizer: Gemma tokenizer to use.\n", + " max_seq_len: size of each sequence in a given batch.\n", + " \"\"\"\n", + " self._tokenizer = tokenizer\n", + " self._max_seq_len = max_seq_len\n", + "\n", + " def _pad_up_to_max_len(\n", + " self, input_tensor: tf.Tensor, pad_value: int | bool\n", + " ) -> tf.Tensor:\n", + " \"\"\"Pads the given tensor up to max_seq_len.\"\"\"\n", + " seq_len = tf.shape(input_tensor)[0]\n", + " to_pad = tf.maximum(0, self._max_seq_len - seq_len)\n", + " return tf.pad(\n", + " input_tensor,\n", + " [[0, to_pad]],\n", + " mode='CONSTANT',\n", + " constant_values=pad_value\n", + " )\n", + "\n", + " def get_train_dataset(self):\n", + " raise NotImplementedError()\n", + "\n", + " def get_validation_dataset(self, batch_size: int):\n", + " raise NotImplementedError()\n" + ], + "metadata": { + "id": "BbZ11EuHoCBK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\"\"\"Dataset builder for the Wikipedia datasets.\"\"\"\n", + "\n", + "import enum as Enum\n", + "import random\n", + "\n", + "from absl import logging\n", + "import jax.dlpack\n", + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds\n", + "\n", + "\n", + "topic_wise_save_path = 'home/shivguptashi/wikidata/topic_wise_tfds'\n", + "\n", + "class DatasetSplit(Enum.Enum):\n", + " TRAIN = 'train'\n", + "\n", + "\n", + "class WikipediaDatasetBuilder(DatasetBuilder):\n", + " \"\"\"Dataset builder for the Open Orca dataset.\"\"\"\n", + "\n", + " N_ITEMS = {DatasetSplit.TRAIN: 2914896}\n", + "\n", + " #BUFFER_SIZE_SHUFFLE = 10_000\n", + " BUFFER_SIZE_SHUFFLE = 100\n", + " TEXT_PREFIX = 'Text: \\n'\n", + " TEXT_SUFFIX = '\\n'\n", + " TITLE_PREFIX = 'Title: \\n'\n", + " TITLE_SUFFIX = '\\n'\n", + " #TRANSLATION_PREFIX = 'Translate this into French:\\n'\n", + " #TRANSLATION_SUFFIX = '\\n'\n", + "\n", + " def __init__(\n", + " self, tokenizer: GemmaTokenizer, max_seq_len: int, topic_index: int\n", + " ):\n", + " \"\"\"Constructor.\n", + "\n", + " Args:\n", + " tokenizer: Gemma tokenizer to use.\n", + " max_seq_len: size of each sequence in a given batch.\n", + " \"\"\"\n", + " self._tokenizer = tokenizer\n", + " self._base_data = {\n", + " DatasetSplit.TRAIN: tf.data.Dataset.load(topic_wise_save_path + '_topic_' + str(topic_index)),\n", + " }\n", + " print(f'Topic {topic_index} size: {self._base_data[DatasetSplit.TRAIN].cardinality().numpy()}')\n", + " self._max_seq_len = max_seq_len\n", + "\n", + " sample_ds = self._base_data[DatasetSplit.TRAIN].take(2)\n", + " for x in sample_ds:\n", + " print(x[0])\n", + "\n", + " def _tokenize_title(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Question.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.TITLE_PREFIX,\n", + " suffix=self.TITLE_SUFFIX,\n", + " add_eos=False,\n", + " )\n", + "\n", + " def _tokenize_text(self, example: tf.Tensor):\n", + " \"\"\"Tokenization function for the Response.\"\"\"\n", + " return self._tokenizer.tokenize_tf_op(\n", + " example,\n", + " prefix=self.TEXT_PREFIX,\n", + " suffix=self.TEXT_SUFFIX,\n", + " add_eos=True,\n", + " )\n", + "\n", + " def _to_training_input(\n", + " self,\n", + " title_tokens: jax.Array,\n", + " text_tokens: jax.Array,\n", + " ):\n", + " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n", + "\n", + " # The input sequence fed to the model is simply the concatenation of the\n", + " # source and the destination.\n", + " tokens = tf.concat(\n", + " [title_tokens, text_tokens], axis=0\n", + " )\n", + "\n", + " # To prevent the model from updating based on the source (input)\n", + " # tokens, add a target mask to each input.\n", + " title_mask = tf.ones_like(title_tokens, dtype=tf.bool)\n", + " text_mask = tf.ones_like(text_tokens, dtype=tf.bool)\n", + "\n", + " mask = tf.concat([title_mask, text_mask], axis=0)\n", + " # If the output tokens sequence is smaller than the target sequence size,\n", + " # then pad it with pad tokens.\n", + " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n", + "\n", + " # Don't want to perform the backward pass on the pad tokens.\n", + " mask = self._pad_up_to_max_len(mask, False)\n", + " return TrainingInput( #type: ignore\n", + " input_tokens=tokens, #type:ignore\n", + " target_mask=mask, #type:ignore\n", + " )# type: ignore\n", + "\n", + " def get_train_dataset(self):\n", + " \"\"\"Build the training dataset.\"\"\"\n", + "\n", + " ds = self._base_data[DatasetSplit.TRAIN].map(\n", + " lambda x, y, z: (\n", + " self._tokenize_title(x),\n", + " self._tokenize_text(y),\n", + " ),\n", + " num_parallel_calls=tf.data.AUTOTUNE,\n", + " )\n", + " ds = ds.map(lambda x, y: self._to_training_input(x, y),\n", + " num_parallel_calls=tf.data.AUTOTUNE)\n", + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", + " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", + " #ds = ds.repeat(num_epochs)\n", + " #ds = ds.batch(batch_size, drop_remainder=True)\n", + " return ds\n", + "\n", + " def get_validation_dataset(self, batch_size: int):\n", + " \"\"\"Build the validation dataset.\"\"\"\n", + "\n", + " # Same steps as in `get_train_dataset`, but without shuffling and\n", + " # repetition.\n", + " # ds = self._base_data[DatasetSplit.VALIDATION].map(\n", + " # lambda x: (self._tokenize_source(x['src']),\n", + " # self._tokenize_destination(x['dst'])))\n", + " ds = self._base_data[DatasetSplit.TRAIN].map(\n", + " lambda x: (\n", + " self._tokenize_title(x[0]),\n", + " self._tokenize_text(x[1]),\n", + " ),\n", + " num_parallel_calls=tf.data.AUTOTUNE,\n", + " )\n", + " ds = ds.map(\n", + " lambda x, y: self._to_training_input(x, y),\n", + " num_parallel_calls=tf.data.AUTOTUNE,\n", + " )\n", + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", + " # ds = ds.batch(batch_size, drop_remainder=True)\n", + " return ds\n", + " # ds = [self._to_training_input(x, y) for x, y in ds]\n", + " # print('here3:', ds)\n", + " # ds = [x for x in ds if tf.shape(x.input_tokens)[0] <= self._max_seq_len]\n", + " # ds = [ds[i : i + batch_size] for i in range(0, len(ds), batch_size)]\n" + ], + "metadata": { + "id": "Bg4vJtuvoJJm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "wiki_tokenized_path = 'home/shivguptashi/open_orca/wiki_tokenized'\n", + "tokenizer = GemmaTokenizer(vocab)\n", + "for topic in range(54, 64):\n", + " wikipedia_dataset_builder = WikipediaDatasetBuilder(tokenizer, max_seq_len=1000, topic_index=topic)\n", + " train_ds = wikipedia_dataset_builder.get_train_dataset()\n", + " train_ds = train_ds.as_numpy_iterator()\n", + " it = 0\n", + " cur_tokenized_path = wiki_tokenized_path + '_topic_' + str(topic) + '.tfrecord'\n", + " with tf.io.TFRecordWriter(cur_tokenized_path) as writer:\n", + " for train_record in train_ds:\n", + " record_bytes = tf.train.Example( features=tf.train.Features(feature={'input_tokens': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.input_tokens.tobytes()])), \"target_mask\": tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_record.target_mask.tobytes()]))})).SerializeToString()\n", + " writer.write(record_bytes)\n", + " print(f'it: {it}')\n", + " it += 1" + ], + "metadata": { + "id": "eDxlHK4hvlpY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "PFfl0nrlw176" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/precondition/datamix_gemma/Wikipedia_processing.ipynb b/precondition/datamix_gemma/Wikipedia_processing.ipynb new file mode 100644 index 0000000..c7faff7 --- /dev/null +++ b/precondition/datamix_gemma/Wikipedia_processing.ipynb @@ -0,0 +1,222 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PE4ZRoyy4qGp" + }, + "outputs": [], + "source": [ + "import tensorflow_datasets as tfds\n", + "import csv\n", + "import sys\n", + "import tensorflow as tf\n", + "\n", + "wiki_tfds = tfds.load('wikipedia/20190301.en', split='train')" + ] + }, + { + "cell_type": "code", + "source": [ + "sample = wiki_tfds.take(2)\n", + "for x in sample:\n", + " print(x)" + ], + "metadata": { + "id": "UXNdsgd55QZs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "csv.field_size_limit(sys.maxsize)\n", + "\n", + "joined_csv_file_path = 'home/shivguptashi/wikidata/joined_table.csv'\n", + "id_to_info = {}\n", + "title_to_info = {}\n", + "with gfile.Open(joined_csv_file_path, 'r') as f:\n", + " csvreader = csv.reader(f, delimiter=',')\n", + " it = 0\n", + " for row in csvreader:\n", + " title = row[-7].strip('|')\n", + " #print(title)\n", + " title_to_info[title] = row\n", + " it += 1\n", + " if it % 1000 == 0:\n", + " print(it)\n" + ], + "metadata": { + "id": "J4WTWdSk6Wg3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "num_rows_in_info = len(title_to_info['Tatrapan'])\n", + "print(num_rows_in_info)\n", + "print(num_rows_in_info - 7 - 4)" + ], + "metadata": { + "id": "ebUptuXE83TD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from operator import itemgetter\n", + "\n", + "def map_to_features(title, text):\n", + " str_title = title.numpy().decode('utf-8')\n", + " #try:\n", + " if str_title not in title_to_info:\n", + " return (title, text, -1)\n", + " all_entries = [float(y) for y in title_to_info[str_title][4:num_rows_in_info-7]]\n", + " index, _ = max(enumerate(all_entries), key=itemgetter(1))\n", + " return (title, text, index)\n", + " #except:\n", + " # return (title, text, -1)\n", + "\n", + "wiki_tfds = wiki_tfds.map(lambda x: (x['title'], x['text']), num_parallel_calls=16)\n", + "wiki_tfds = wiki_tfds.map(lambda x, y: tf.py_function(map_to_features, [x, y], [tf.string, tf.string, tf.int32]), num_parallel_calls=16)\n" + ], + "metadata": { + "id": "Rc8U4zgE6iub" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "save_path = 'home/shivguptashi/wikidata/wiki_tfds_with_topic'\n", + "tf.data.Dataset.save(wiki_tfds, save_path)\n", + "#from functools import partial\n", + "#def filter_f(x, i):\n", + "# try:\n", + "# str_title = x.numpy().decode('utf-8')\n", + "# all_entries = [int(y) for y in title_to_info[str_title][4:num_rows_in_info-7]]\n", + "# print(all_entries)\n", + "# if int(title_to_info[str_title][i]) >= max(all_entries):\n", + "# return True\n", + "# return False\n", + "# except:\n", + "# return False\n", + "#\n", + "#categories_tfds = []\n", + "#for i in range(4, num_rows_in_info - 7):\n", + " #categories_tfds.append(wiki_tfds.filter(lambda x, y: tf.py_function(partial(filter_f, i=i), [x,], [tf.bool,])[0]))" + ], + "metadata": { + "id": "KJI0Cq499oYv" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sample_tfds = wiki_tfds.take(2)\n", + "for x in sample_tfds:\n", + " print(x[0])" + ], + "metadata": { + "id": "TDlr690KwFLk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "sample_category_tfds = categories_tfds[0].take(2)\n", + "for x in sample_category_tfds:\n", + " print(x)" + ], + "metadata": { + "id": "5ARFZVx00J4U" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "wiki_tfds_with_topic = tf.data.Dataset.load(save_path)\n", + "sample_tfds_with_topic = wiki_tfds_with_topic.take(10)\n", + "sample_wiki_tfds_with_topic = sample_tfds_with_topic.filter(lambda x, y, z: tf.py_function(filter_f, [z], [tf.bool])[0])\n", + "for x in sample_wiki_tfds_with_topic:\n", + " print(x)" + ], + "metadata": { + "id": "iim5buMZIfiU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from functools import partial\n", + "def filter_f(topic_ind, i):\n", + " if topic_ind == i:\n", + " return True\n", + " return False\n", + "\n", + "topic_wise_save_path = 'home/shivguptashi/wikidata/topic_wise_tfds'\n", + "for i in range(58, num_rows_in_info - 7):\n", + " ind = i - 4\n", + " cur_topic_save_path = topic_wise_save_path + '_topic_' + str(ind)\n", + " filtered_wiki_tfds = wiki_tfds_with_topic.filter(lambda x, y, z: tf.py_function(partial(filter_f, i=ind), [z], [tf.bool])[0])\n", + " print('index:', ind)\n", + " tf.data.Dataset.save(filtered_wiki_tfds, cur_topic_save_path)" + ], + "metadata": { + "id": "6248Q9YpH_Em" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(num_rows_in_info)" + ], + "metadata": { + "id": "1PItow2MGpIo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "SEyLBVkDo-FM" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file