Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gemma 2 quickstart #14

Merged
merged 4 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 75 additions & 75 deletions Gemma/Keras_Gemma_2_Quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,13 @@
"id": "PXNm5_p_oxMF"
},
"source": [
"This is a quick demo of Gemma running on KerasNLP. To run this you will need:\n",
"- To be added to a private github repo for Gemma.\n",
"- To be added to a private Kaggle model for weights.\n",
"This is a quick demo of Gemma running on KerasNLP.\n",
"\n",
"Note that you will need a large GPU (e.g. A100) to run this as well.\n",
"\n",
"General Keras reading:\n",
"- [Getting started with Keras](https://keras.io/getting_started/)\n",
"- [Getting started with KerasNLP](https://keras.io/guides/keras_nlp/getting_started/)\n",
"- [Generation and fine-tuning guide for GPT2](https://keras.io/examples/generative/gpt2_text_generation_with_kerasnlp/)\n",
"\n",
"<table align=\"left\">\n",
" <td>\n",
Expand Down Expand Up @@ -76,7 +73,9 @@
"from google.colab import userdata\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"tensorflow\" or \"torch\"."
]
},
{
Expand All @@ -90,36 +89,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"id": "bMboT70Xop8G"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m21.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m72.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m311.2/311.2 kB\u001b[0m \u001b[31m35.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m589.8/589.8 MB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m95.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m76.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m107.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Building wheel for keras-nlp (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.16.1 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
],
"outputs": [],
"source": [
"# Install all deps\n",
"!pip install keras\n",
"!pip install keras-nlp"
"!pip install -U keras-nlp\n",
"!pip install -U keras==3.3.3"
]
},
{
Expand All @@ -137,42 +115,29 @@
"metadata": {
"id": "ww83zI9ToPso"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # Or \"tensorflow\" or \"torch\".\n",
"\n",
"import keras_nlp\n",
"import keras\n",
"\n",
"# Run at half precision.\n",
"keras.config.set_floatx(\"bfloat16\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "yygIK9DEIldp"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/metadata.json...\n",
"100%|██████████| 143/143 [00:00<00:00, 179kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/task.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/config.json...\n",
"100%|██████████| 780/780 [00:00<00:00, 895kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/model.weights.h5...\n",
"100%|██████████| 17.2G/17.2G [18:34<00:00, 16.6MB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/preprocessor.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/tokenizer.json...\n",
"100%|██████████| 315/315 [00:00<00:00, 431kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/3/download/assets/tokenizer/vocabulary.spm...\n",
"100%|██████████| 4.04M/4.04M [00:01<00:00, 2.41MB/s]\n"
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors.index.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/metadata.json...\n",
"100%|██████████| 143/143 [00:00<00:00, 153kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/task.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/config.json...\n",
"100%|██████████| 780/780 [00:00<00:00, 884kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors.index.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.weights.h5...\n",
"100%|██████████| 17.2G/17.2G [04:22<00:00, 70.5MB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/model.safetensors.index.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/preprocessor.json...\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/tokenizer.json...\n",
"100%|██████████| 315/315 [00:00<00:00, 434kB/s]\n",
"Downloading from https://www.kaggle.com/api/v1/models/keras/gemma2/keras/gemma2_9b_en/1/download/assets/tokenizer/vocabulary.spm...\n",
"100%|██████████| 4.04M/4.04M [00:00<00:00, 14.6MB/s]\n"
]
},
{
Expand Down Expand Up @@ -300,34 +265,69 @@
}
],
"source": [
"# Load the default `gemma2_9b_en` preset, or load the Hugging Face weights via `hf://google/gemma-2-9b-keras`.\n",
"import keras_nlp\n",
"import keras\n",
"\n",
"# Run at half precision.\n",
"keras.config.set_floatx(\"bfloat16\")\n",
"\n",
"# using 9B base model\n",
"gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(\"gemma2_9b_en\")\n",
"gemma_lm.summary()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {
"id": "aae5GHrdpj2_"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'What is the meaning of life?\\n\\n[Answer 1]\\n\\nThe meaning of life is to live it.\\n\\n[Answer 2]\\n\\nThe'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"It was a dark and stormy night.\n",
"\n",
"The wind was howling, the rain was pouring, and the thunder was rumbling.\n",
"\n",
"I was sitting in my living room, watching the storm rage outside.\n",
"\n",
"Suddenly, I heard a knock at the door.\n",
"\n",
"I got up and opened it, and there stood a man in a black cloak.\n",
"\n",
"He had a strange look in his eyes, and he was holding a lantern.\n",
"\n",
"\"Who are you?\" I asked.\n",
"\n",
"\"I am the storm,\" he replied.\n",
"\n",
"\"And I have come to take you away.\"\n",
"\n",
"I was terrified, but I couldn't move.\n",
"\n",
"The man in the black cloak grabbed my arm and pulled me out into the storm.\n",
"\n",
"We walked for what seemed like hours, until we came to a clearing in the woods.\n",
"\n",
"There, the man in the black cloak stopped and turned to me.\n",
"\n",
"\"You are mine now,\" he said.\n",
"\n",
"\"And I will take you to my castle.\"\n",
"\n",
"I tried to fight him off, but he was too strong.\n",
"\n",
"He dragged me into the castle, and I was never seen again.\n",
"\n",
"The end.\n"
]
}
],
"source": [
"gemma_lm.generate(\"What is the meaning of life?\", max_length=32)"
"result = gemma_lm.generate(\"It was a dark and stormy night.\", max_length=256)\n",
"print(result)"
]
}
],
Expand Down
Loading
Loading