
loading Pali Gemma 2 pretrained weights #1

Open
Esaada opened this issue Nov 9, 2024 · 11 comments

@Esaada commented Nov 9, 2024

Hi, thanks for your work, and sorry for the ignorant question: is there a way for us to load PaliGemma weights? Is the code supposed to be compatible with loading the pretrained weights, or does it just use some abstract LLM as an example?
Thanks a lot!!

@lucidrains (Owner) commented:

hey Ellen, yes i do intend to add that bit of logic by end of month

let's keep this open

@ramkumarkoppu commented:

Hi Phil,

Just tried a rudimentary method to use this code with the PaliGemma weights. It seems to work, but I am not sure whether the generated action predictions make any sense:

```python
import torch
import time
from pi_zero_pytorch import PiZero
from transformers import PaliGemmaForConditionalGeneration

# Load PaliGemma model
model_id = "google/paligemma-3b-pt-224"
paligemma_model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

# Initialize π₀ model
model = PiZero(
    dim = 512,             # ensure this matches the embedding dimension of PaliGemma
    dim_action_input = 6,
    dim_joint_state = 12,
    num_tokens = 20_000
)

# Transfer weights from PaliGemma to π₀ model
paligemma_state_dict = paligemma_model.state_dict()
pi_zero_state_dict = model.state_dict()

# Map weights (this is a simplified example; actual mapping depends on the specific layers and architecture)
for name, param in paligemma_state_dict.items():
    if name in pi_zero_state_dict:
        pi_zero_state_dict[name].data.copy_(param.data)

# Load the updated state dictionary into π₀ model
model.load_state_dict(pi_zero_state_dict)

# Step 2: save the updated model
torch.save(model.state_dict(), 'pi_zero_with_paligemma_weights.pth')
print("Weights saved as pi_zero_with_paligemma_weights.pth")

model.eval()  # switch the model to evaluation mode if needed

# Testing example
vision = torch.randn(1, 1024, 512)
commands = torch.randint(0, 20_000, (1, 1024))
joint_state = torch.randn(1, 12)
actions = torch.randn(1, 32, 6)

start = time.time()

# Forward pass to compute loss
loss, _ = model(vision, commands, joint_state, actions)
loss.backward()

# After much training, sample actions from the model
with torch.no_grad():  # avoid computing gradients for sampling
    sampled_actions = model(vision, commands, joint_state, trajectory_length = 32)  # (1, 32, 6)

end = time.time()

print(f"sampled_actions = {sampled_actions}")
print(f"type(sampled_actions) = {type(sampled_actions)}")
print(f"sampled_actions.shape = {sampled_actions.shape}")
print(f"Took {round(end - start)} secs to run")
```

outputs:

```
python example2.py
config.hidden_act is ignored, you should use config.hidden_activation instead. Gemma's activation function will be set to gelu_pytorch_tanh. Please, use config.hidden_activation if you want to override this behaviour. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|█████████████████| 3/3 [00:00<00:00, 10.10it/s]
Weights saved as pi_zero_with_paligemma_weights.pth
sampling action trajectory: 34it [00:02, 12.07it/s]
sampled_actions = tensor([[[-0.1723, -0.7755,  0.6327,  0.7349, -1.6040,  1.2686],
         [ 1.4642, -0.5775, -0.3487,  0.7441, -0.3147, -1.4054],
         [-1.8057, -1.4600, -0.2791,  0.7410, -0.4771, -1.0375],
         [-1.3005, -1.2288, -0.1237,  1.6539,  0.4367,  1.3989],
         [-1.8131, -0.8878,  1.1737, -0.1739, -1.9175,  0.4355],
         [-0.6799,  1.7663,  0.1488, -0.7766, -2.5154,  0.1274],
         [-0.5450,  1.2790,  0.3630,  0.3508, -1.4404,  1.1594],
         [-0.0078, -0.2722, -0.7075,  0.1469, -1.8668,  0.1883],
         [-0.1165,  0.7380, -0.2884,  0.1066, -0.2059, -1.1648],
         [-0.3130, -0.0812,  0.3078,  0.3247, -2.7376,  0.6835],
         [ 0.2080, -0.6108,  0.8127,  0.2690, -1.2925,  1.1082],
         [-0.7570, -0.0700, -0.1674, -1.3282, -1.0847, -0.7158],
         [-0.5536, -0.6959,  0.1424,  1.2095,  0.0700,  1.0461],
         [-1.8139, -0.0976,  1.0885, -0.0374, -1.0290, -0.1475],
         [-0.9862, -0.8747,  1.6191,  1.2776, -0.4705,  0.5143],
         [ 0.2869,  1.4620, -0.5515, -0.4454, -0.6797, -2.1260],
         [-0.5725, -1.6888, -0.1353, -0.9105, -1.2810,  1.2304],
         [-0.2260, -1.5302,  1.7306,  1.6854, -0.6908,  1.5657],
         [-0.8331, -1.8935,  1.9204, -1.8620,  0.3710, -0.6320],
         [ 1.1644, -2.3685, -1.0124, -0.0197, -1.4202, -0.9295],
         [ 0.7564, -0.2496,  1.3316,  0.9873, -1.7427,  0.5434],
         [-0.0870, -1.1866, -0.4180, -0.4715, -0.6236,  1.8536],
         [ 0.2101,  0.0397, -0.6447, -2.3438,  0.1627, -1.5687],
         [-0.5759, -0.6482,  0.5726, -0.8485, -0.5982,  0.1980],
         [-0.0710, -0.4460,  1.6080,  0.6518,  0.7863,  0.3910],
         [ 1.8377, -0.9537, -0.4827, -0.8834, -2.7847,  0.4176],
         [ 0.1236, -2.0394, -0.8019, -0.9687, -0.4023,  0.9074],
         [-1.0085,  0.5188,  0.7453, -0.8071,  0.0059, -1.5597],
         [-0.7450, -1.2613, -0.4698,  1.2632,  1.0082,  0.5577],
         [-0.4213,  0.4462,  2.6775, -0.6689,  0.1974,  0.2159],
         [ 0.3620, -1.9819,  0.0166, -0.8738, -0.3965,  0.4933],
         [ 0.1106, -0.3210, -0.7233,  0.0897,  1.0666,  0.0946]]])
type(sampled_actions) = <class 'torch.Tensor'>
sampled_actions.shape = torch.Size([1, 32, 6])
Took 6 secs to run
```

@lucidrains (Owner) commented:

@ramkumarkoppu hey Ram! thanks for tackling this issue! you'll still need to train on a vision / language / action dataset, but the pretrained PaliGemma weights give you the starting point

would you like to attempt to submit a pull request with what you have?

that would be a great contribution

@ramkumarkoppu commented:

just preparing the PR

@lucidrains (Owner) commented:

@ramkumarkoppu amazing, thank you Ram! 🙏

@ramkumarkoppu commented:

Hi Phil,

I don't want to take up too much of your time, but if you get some free time...

after closing the PR, I ran some experiments with manually loading PaliGemma weights into the Pi Zero model using my branch code: https://github.com/ramkumarkoppu/pi-zero-pytorch/tree/load-pretrained-Pali-Gemma-weights/experiments

and I see some mismatches like this:
```
Size mismatch: language_model.model.embed_tokens.weight -> token_emb.weight
Size mismatch: language_model.model.layers.0.self_attn.q_proj.weight -> layers.0.0.to_qkv.weight
Size mismatch: language_model.model.layers.0.self_attn.k_proj.weight -> layers.0.0.to_qkv.weight
Size mismatch: language_model.model.layers.0.self_attn.v_proj.weight -> layers.0.0.to_qkv.weight
Size mismatch: language_model.model.layers.0.self_attn.o_proj.weight -> layers.0.0.to_out.weight
Size mismatch: language_model.model.layers.0.mlp.gate_proj.weight -> layers.0.1.proj_in.weight
Size mismatch: language_model.model.layers.0.mlp.down_proj.weight -> layers.0.1.proj_out.weight
Size mismatch: language_model.model.layers.0.mlp.up_proj.weight -> layers.0.2.proj_out.weight
Size mismatch: language_model.model.norm.weight -> final_norm.weight
Mismatched keys (9): [('language_model.model.embed_tokens.weight', 'token_emb.weight'), ('language_model.model.layers.0.self_attn.q_proj.weight', 'layers.0.0.to_qkv.weight'), ('language_model.model.layers.0.self_attn.k_proj.weight', 'layers.0.0.to_qkv.weight'), ('language_model.model.layers.0.self_attn.v_proj.weight', 'layers.0.0.to_qkv.weight'), ('language_model.model.layers.0.self_attn.o_proj.weight', 'layers.0.0.to_out.weight'), ('language_model.model.layers.0.mlp.gate_proj.weight', 'layers.0.1.proj_in.weight'), ('language_model.model.layers.0.mlp.down_proj.weight', 'layers.0.1.proj_out.weight'), ('language_model.model.layers.0.mlp.up_proj.weight', 'layers.0.2.proj_out.weight'), ('language_model.model.norm.weight', 'final_norm.weight')]
```

the full output is here: https://github.com/ramkumarkoppu/pi-zero-pytorch/blob/load-pretrained-Pali-Gemma-weights/experiments/load_PaliGemma_weights_output.txt

Do you have insights on this issue?

@lucidrains (Owner) commented:

what you are looking at is the task at hand

you need to match up the parameter weights one by one. complications may arise; for instance, here I am doing one fused projection for qkv. also, the paper claimed that actions needed their own parameters, so you will need to ignore the action-specific attention parameters
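
For anyone picking this up, here is a minimal sketch of what that fused-qkv mapping could look like. It assumes the π₀ `to_qkv.weight` is simply the separate `q_proj` / `k_proj` / `v_proj` weights stacked along the output dimension (the shape assertion will fail if that assumption does not hold, e.g. because of grouped-query attention or the action expert's separate parameters). The key names come from the mismatch log above; `copy_attention_block` itself is a hypothetical helper, not part of either library.

```python
import torch

@torch.no_grad()
def copy_attention_block(paligemma_sd, pi_zero_sd, layer):
    # hypothetical helper: copies one PaliGemma attention block into the
    # corresponding pi-zero block, fusing q/k/v into the single to_qkv weight
    prefix = f'language_model.model.layers.{layer}.self_attn.'

    q = paligemma_sd[prefix + 'q_proj.weight']
    k = paligemma_sd[prefix + 'k_proj.weight']
    v = paligemma_sd[prefix + 'v_proj.weight']

    fused = torch.cat((q, k, v), dim = 0)  # assumed layout: rows of q, k, v stacked
    to_qkv = pi_zero_sd[f'layers.{layer}.0.to_qkv.weight']

    assert to_qkv.shape == fused.shape, f'expected {tuple(to_qkv.shape)}, built {tuple(fused.shape)}'
    to_qkv.copy_(fused)

    # the output projection would map across one-to-one, if the shapes agree
    to_out = pi_zero_sd[f'layers.{layer}.0.to_out.weight']
    to_out.copy_(paligemma_sd[prefix + 'o_proj.weight'])

# usage sketch: copy every language-model block, leaving the action expert's
# attention parameters at their random initialization, as noted above
# for layer in range(num_language_layers):
#     copy_attention_block(paligemma_state_dict, pi_zero_state_dict, layer)
```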

@lucidrains (Owner) commented:

@ramkumarkoppu able to figure it out or make some progress?

@ramkumarkoppu commented:

Hi Phil,

Sorry I've been a bit ghostly lately! 🕵️‍♂️ After our last discussion, I've been busy prepping for a job change as my current contract is wrapping up. Fingers crossed I can escape the hiring maze soon and get back to working on this. Stay tuned!

@lucidrains (Owner) commented:

@ramkumarkoppu ah ok, no problem

i'll make some headway on this next week

@lucidrains lucidrains added the good first issue Good for newcomers label Nov 24, 2024
@lucidrains lucidrains changed the title loading Pali Gemma pretrained weights loading Pali Gemma 2 pretrained weights Dec 6, 2024
@kevin-eai2 commented:

> (quoting @ramkumarkoppu's comment above about the size mismatches seen when manually loading PaliGemma weights from the load-pretrained-Pali-Gemma-weights branch)

The branch links in that comment now return 404 not found.
