From f3470911c3ebc2e7b7452672efb4125e2cf17768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Susano=20Pinto?= Date: Fri, 6 Dec 2024 15:06:51 +0000 Subject: [PATCH] Fix PaliGemma instructions and big_vision deps. --- big_vision/configs/proj/paligemma/README.md | 5 ++- big_vision/datasets/sequence_packing.py | 38 +-------------------- 2 files changed, 3 insertions(+), 40 deletions(-) diff --git a/big_vision/configs/proj/paligemma/README.md b/big_vision/configs/proj/paligemma/README.md index 7f32696..b7f3b37 100644 --- a/big_vision/configs/proj/paligemma/README.md +++ b/big_vision/configs/proj/paligemma/README.md @@ -255,8 +255,7 @@ export KAGGLE_USERNAME= export KAGGLE_KEY= # See https://www.kaggle.com/models/google/paligemma-2 for a full list of models. -export MODEL_NAME=paligemma-2 -export CKPT_FILE=paligemma2-3b-pt-224.npz.b16 +export MODEL_NAME=paligemma2-3b-pt-224 mkdir ckpts/ cd ckpts/ @@ -264,7 +263,7 @@ cd ckpts/ # Store as a "vanity name" from models/proj/paligemma/paligemma.py curl -L -u $KAGGLE_USERNAME:$KAGGLE_KEY\ -o pt_3b_224.bf16.npz \ - https://www.kaggle.com/api/v1/models/google/paligemma-2/jax/$MODEL_NAME/1/download/$CKPT_FILE + https://www.kaggle.com/api/v1/models/google/paligemma-2/jax/$MODEL_NAME/1/download/$MODEL_NAME.b16.npz ``` As an example, we provide the `forkme.py` config that is based on the easily-adjustable jsonl data source: diff --git a/big_vision/datasets/sequence_packing.py b/big_vision/datasets/sequence_packing.py index d28696b..48966d3 100644 --- a/big_vision/datasets/sequence_packing.py +++ b/big_vision/datasets/sequence_packing.py @@ -21,7 +21,6 @@ from typing import Dict, Optional, List, Union from flax import traverse_util -import grain.tensorflow as tf_grain import tensorflow as tf AUTOTUNE = tf.data.experimental.AUTOTUNE @@ -75,39 +74,4 @@ def pack_dataset( Returns: A `tf.data.Dataset`. 
""" - def _maybe_join(k): - if isinstance(k, int): - k = (k,) - return FLATTEN_SEPARATOR.join(k) - - if isinstance(key2length, int): - key2length = {_maybe_join(k): key2length for k in keys} - else: - key2length = dict(key2length) # Make new dict, we'll edit in-place. - - def _add_fake_index(x): - x = dict(x) - x[tf_grain.INDEX] = -1 - return x - - def _flatten_dict(x): - return traverse_util.flatten_dict(x, sep=FLATTEN_SEPARATOR) - - def _unflatten_dict(x): - return traverse_util.unflatten_dict(x, sep=FLATTEN_SEPARATOR) - - def _remove_index(x): - x = dict(x) - x.pop(tf_grain.INDEX) - return x - - pack_op = tf_grain.TfBatchAndPack( - batch_size=batch_size or 1, - sequence_lengths=_flatten_dict(key2length)) - - dataset = dataset.map(_add_fake_index, num_parallel_calls=AUTOTUNE) - dataset = dataset.map(_flatten_dict, num_parallel_calls=AUTOTUNE) - dataset = pack_op.apply_to_dataset(dataset) - dataset = dataset.map(_unflatten_dict, num_parallel_calls=AUTOTUNE) - dataset = dataset.map(_remove_index, num_parallel_calls=AUTOTUNE) - return dataset.unbatch() + raise ValueError("Not implemented in OSS yet.")