Update agile_classifiers.ipynb (#481)
owahltinez authored Jul 10, 2024
1 parent 3a9e817 · commit 8f66c93
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions site/en/gemma/docs/agile_classifiers.ipynb
@@ -646,37 +646,45 @@
"import numpy as np\n",
"\n",
"\n",
- "def softmax_normalization(arr: np.ndarray) -> np.ndarray:\n",
- "  \"\"\"Normalizes logits values into probabilities summing to one.\"\"\"\n",
- "  arr_exp = np.exp(arr - np.max(arr))\n",
- "  return arr_exp / arr_exp.sum()\n",
- "\n",
- "\n",
- "def compute_token_probability(\n",
+ "def compute_output_probability(\n",
"    model: keras_nlp.models.GemmaCausalLM,\n",
"    prompt: str,\n",
- "    target_tokens: list[str],\n",
+ "    target_classes: list[str],\n",
") -> dict[str, float]:\n",
"  # Shorthands.\n",
"  preprocessor = model.preprocessor\n",
"  tokenizer = preprocessor.tokenizer\n",
"\n",
- "  # Identify output token offset.\n",
- "  (padding_mask,) = preprocessor.generate_preprocess([prompt])['padding_mask']\n",
- "  token_offset = sum(padding_mask.numpy()) - 1\n",
- "\n",
- "  # Compute prediction, extract only the next token's logits.\n",
- "  (logits,) = model.predict([prompt], verbose=0)\n",
- "  token_logits = logits[token_offset]\n",
+ "  # NOTE: If a token is not found, it will be considered same as \"<unk>\".\n",
+ "  token_unk = tokenizer.token_to_id('<unk>')\n",
"\n",
"  # Identify the token indices, which is the same as the ID for this tokenizer.\n",
- "  # NOTE: If a token is not found, it will be considered same as \"<unk>\".\n",
- "  token_ids = [tokenizer.token_to_id(token) for token in target_tokens]\n",
+ "  token_ids = [tokenizer.token_to_id(word) for word in target_classes]\n",
"\n",
+ "  # Throw an error if one of the classes maps to a token outside the vocabulary.\n",
+ "  if any(token_id == token_unk for token_id in token_ids):\n",
+ "    raise ValueError('One of the target classes is not in the vocabulary.')\n",
+ "\n",
+ "  # Preprocess the prompt in a single batch. This is done one sample at a time\n",
+ "  # for illustration purposes, but it would be more efficient to batch prompts.\n",
+ "  preprocessed = model.preprocessor.generate_preprocess([prompt])\n",
+ "\n",
+ "  # Identify output token offset.\n",
+ "  padding_mask = preprocessed[\"padding_mask\"]\n",
+ "  token_offset = keras.ops.sum(padding_mask) - 1\n",
+ "\n",
+ "  # Score outputs, extract only the next token's logits.\n",
+ "  vocab_logits = model.score(\n",
+ "      token_ids=preprocessed[\"token_ids\"],\n",
+ "      padding_mask=padding_mask,\n",
+ "  )[0][token_offset]\n",
+ "\n",
"  # Compute the relative probability of each of the requested tokens.\n",
- "  probabilities = softmax_normalization([token_logits[ix] for ix in token_ids])\n",
+ "  token_logits = [vocab_logits[ix] for ix in token_ids]\n",
+ "  logits_tensor = keras.ops.convert_to_tensor(token_logits)\n",
+ "  probabilities = keras.activations.softmax(logits_tensor)\n",
"\n",
- "  return dict(zip(target_tokens, probabilities))"
+ "  return dict(zip(target_classes, probabilities.numpy()))"
]
},
{
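The rewritten function reads the class logits at the position of the last non-padding prompt token. A minimal sketch of that offset arithmetic, using made-up mask values rather than real tokenizer output:

```python
import keras

# Suppose the prompt tokenizes to 4 real tokens followed by 2 padding slots
# (illustrative values; a real mask comes from generate_preprocess).
padding_mask = keras.ops.convert_to_tensor([1, 1, 1, 1, 0, 0])

# Summing the mask counts the real tokens; subtracting 1 yields the index of
# the last real token. In a causal LM, the logits at that position predict
# the first generated token, which is where the class probabilities are read.
token_offset = keras.ops.sum(padding_mask) - 1
print(int(token_offset))  # -> 3
```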
@@ -707,10 +715,10 @@
}
],
"source": [
- "compute_token_probability(\n",
+ "compute_output_probability(\n",
"    model=model,\n",
"    prompt=prompt,\n",
- "    target_tokens=['Positive', 'Negative'],\n",
+ "    target_classes=['Positive', 'Negative'],\n",
")"
]
},
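The commit also drops the hand-rolled softmax_normalization helper in favor of keras.activations.softmax, which computes the same normalization. A quick equivalence check, not part of the commit:

```python
import numpy as np
import keras


def softmax_normalization(arr: np.ndarray) -> np.ndarray:
    """Normalizes logits values into probabilities summing to one."""
    arr_exp = np.exp(arr - np.max(arr))  # subtract max for numerical stability
    return arr_exp / arr_exp.sum()


logits = np.array([2.0, -1.0], dtype=np.float32)
manual = softmax_normalization(logits)
via_keras = keras.activations.softmax(keras.ops.convert_to_tensor(logits))
print(manual)                 # ~[0.9526 0.0474]
print(np.asarray(via_keras))  # same values
```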
@@ -743,7 +751,7 @@
"  \"\"\"Agile classifier to be wrapped around a LLM.\"\"\"\n",
"\n",
"  # The classes whose probability will be predicted.\n",
- "  labels: tuple\n",
+ "  labels: tuple[str, ...]\n",
"\n",
"  # Provide default instructions and control tokens, can be overridden by user.\n",
"  instructions: str = 'Classify the following text into one of the following classes'\n",
@@ -771,10 +779,10 @@
"      x_text: str,\n",
"  ) -> list[float]:\n",
"    prompt = self.encode_for_prediction(x_text)\n",
- "    token_probabilities = compute_token_probability(\n",
+ "    token_probabilities = compute_output_probability(\n",
"        model=model,\n",
"        prompt=prompt,\n",
- "        target_tokens=self.labels,\n",
+ "        target_classes=self.labels,\n",
"    )\n",
"    return [token_probabilities[token] for token in self.labels]\n",
"\n",
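A hypothetical end-to-end call of the updated classifier; the constructor arguments assume the dataclass defaults shown above, `model` is the Gemma model loaded earlier in the notebook, and the sample text is made up:

```python
classifier = AgileClassifier(labels=('Positive', 'Negative'))
scores = classifier.predict_score(
    model=model,
    x_text='The movie was a fantastic ride from start to finish.',
)
print(dict(zip(classifier.labels, scores)))  # e.g. {'Positive': 0.98, 'Negative': 0.02}
```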
