loading Pali Gemma 2 pretrained weights #1
hey Ellen, yes I do intend to add that bit of logic by end of month. let's keep this open
Hi Phil, I just tried a rudimentary method to use this code with the PaliGemma weights. It seems to work, but I am not sure if the generated action predictions make any sense:

```python
# Load PaliGemma model
model_id = "google/paligemma-3b-pt-224"
# ...

# Initialize π₀ model
model = PiZero(
    # ...
)

# Transfer weights from PaliGemma to π₀ model
paligemma_state_dict = paligemma_model.state_dict()

# Map weights (this is a simplified example; actual mapping depends on the specific layers and architecture)
for name, param in paligemma_state_dict.items():
    # ...

# Load the updated state dictionary into π₀ model
model.load_state_dict(pi_zero_state_dict)

# Step 2: Save the updated model
torch.save(model.state_dict(), 'pi_zero_with_paligemma_weights.pth')
print("Weights saved as pi_zero_with_paligemma_weights.pth")

model.eval()  # Switch the model to evaluation mode if needed

# Testing example
vision = torch.randn(1, 1024, 512)
start = time.time()

# Forward pass to compute loss
loss, _ = model(vision, commands, joint_state, actions)

# After much training, sample actions from the model
with torch.no_grad():  # To avoid computing gradients for sampling
    # ...

print(f"sampled_actions = {sampled_actions}")
```

outputs
@ramkumarkoppu hey Ram! thanks for tackling this issue! you'll still need to train on a vision / language / action dataset, but having pretrained PaliGemma weights as the starting point would help. would you like to attempt to submit a pull request with what you have? that would be a great contribution
just preparing the PR
@ramkumarkoppu amazing, thank you Ram! 🙏
Hi Phil, I don't want to consume too much of your time, but if you get some free time... after closing the PR, I ran some experiments manually loading the PaliGemma weights into the π₀ model using my branch code: https://github.com/ramkumarkoppu/pi-zero-pytorch/tree/load-pretrained-Pali-Gemma-weights/experiments and I see some mismatches like this. The full output is here: https://github.com/ramkumarkoppu/pi-zero-pytorch/blob/load-pretrained-Pali-Gemma-weights/experiments/load_PaliGemma_weights_output.txt Do you have any insights on this issue?
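Before mapping anything, it can help to enumerate exactly which parameter names and shapes disagree between the two state dicts. Below is a dependency-free sketch of one way to do that; the parameter names are hypothetical examples, not the actual pi-zero-pytorch or PaliGemma keys, and real state dicts map names to tensors, so you would first build the shape dicts via `{k: tuple(v.shape) for k, v in sd.items()}`:

```python
def diff_state_dicts(src_shapes, dst_shapes):
    """Compare two {param_name: shape} mappings and report the mismatches.

    Returns three sorted lists: keys the destination expects but the source
    lacks, keys the source has but the destination doesn't, and keys present
    in both but with different shapes.
    """
    missing = sorted(k for k in dst_shapes if k not in src_shapes)
    unexpected = sorted(k for k in src_shapes if k not in dst_shapes)
    mismatched = sorted(
        k for k in src_shapes
        if k in dst_shapes and src_shapes[k] != dst_shapes[k]
    )
    return missing, unexpected, mismatched

# Toy example with made-up parameter names
src = {"vision.proj.weight": (512, 1024), "attn.to_q.weight": (512, 512)}
dst = {"vision.proj.weight": (512, 1024), "attn.to_qkv.weight": (1536, 512)}

missing, unexpected, mismatched = diff_state_dicts(src, dst)
print(missing)      # ['attn.to_qkv.weight']
print(unexpected)   # ['attn.to_q.weight']
print(mismatched)   # []
```

Running this over the real dicts would give a checklist of exactly which names need a manual mapping rule.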
what you are looking at is the task at hand. you need to match up the parameter weights one by one. complications may arise, such as that here I am doing a single fused projection for qkv. also, the paper claimed that actions needed their own parameters, so you will need to ignore the action-specific attention parameters
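To illustrate the fused-qkv complication above: when the target model uses one combined qkv projection but the source checkpoint has separate q, k, and v matrices, one common approach is to concatenate the three weight matrices along the output dimension. This is only a sketch under assumptions — the function name is hypothetical, and the [q; k; v] row ordering must be verified against the actual implementation:

```python
import numpy as np

def fuse_qkv(w_q, w_k, w_v):
    """Stack separate q/k/v projection weights into one fused qkv weight.

    Assumes each weight has shape (out_dim, in_dim), as in a PyTorch
    nn.Linear, and that the fused layer expects rows ordered [q; k; v].
    Both assumptions need to be checked against the target architecture.
    """
    return np.concatenate([w_q, w_k, w_v], axis=0)

# Toy weights so the ordering is easy to verify by eye
dim = 4
w_q = np.full((dim, dim), 1.0)
w_k = np.full((dim, dim), 2.0)
w_v = np.full((dim, dim), 3.0)

w_qkv = fuse_qkv(w_q, w_k, w_v)
print(w_qkv.shape)  # (12, 4): three (4, 4) blocks stacked row-wise
```

The reverse direction (splitting a fused checkpoint into separate projections) would slice the same axis into three equal chunks.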
@ramkumarkoppu able to figure it out or make some progress?
Hi Phil, Sorry I've been a bit ghostly lately! 🕵️♂️ After our last discussion, I've been busy prepping for a job change, as my current contract is wrapping up. Fingers crossed I can escape the hiring maze soon and get back to working on this. Stay tuned!
@ramkumarkoppu ah ok, no problem. i'll make some headway on this next week
404 not found |
Hi, thanks for your work. Sorry for the ignorant question, but is there a way for us to load the PaliGemma weights? Is the code supposed to be compatible with loading the weights, or does it just use some abstract LLM as an example?
Thanks a lot!!