-
Notifications
You must be signed in to change notification settings - Fork 509
/
hf_release.py
27 lines (19 loc) · 934 Bytes
/
hf_release.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from moondream.hf import Moondream
from moondream.hf.configuration_moondream import MoondreamConfig
# Settings for this release run.
OUT_MODEL = "vikhyatk/moondream-next"  # repo id handed to push_to_hub below
CKPT_DIRS = []  # checkpoint directories read by get_ckpt; fill in before running

# Register the custom config/model classes with the Hugging Face auto classes
# (the model is registered under "AutoModelForCausalLM").
MoondreamConfig.register_for_auto_class()
Moondream.register_for_auto_class("AutoModelForCausalLM")
def get_ckpt(filename):
    """Return the key-wise average of ``filename`` across all ``CKPT_DIRS``.

    ``filename`` is loaded (onto the CPU) from every directory listed in the
    module-level ``CKPT_DIRS``; the values stored under each key are summed
    and divided by the number of checkpoints.

    Args:
        filename: Name of the checkpoint file expected inside each directory.

    Returns:
        A dict with the same keys as the checkpoints, holding the averaged
        values. Values are combined with ``+`` and ``/``, so they are
        presumably tensors of matching shape -- TODO confirm for non-float
        entries.

    Raises:
        ValueError: If ``CKPT_DIRS`` is empty, or if the checkpoints do not
            all contain exactly the same keys.
    """
    if not CKPT_DIRS:
        # Without this guard an empty list surfaces as an opaque
        # "list index out of range" on ckpts[0] below.
        raise ValueError(
            f"CKPT_DIRS is empty; add at least one checkpoint directory "
            f"before loading {filename!r}"
        )

    # NOTE(security): torch.load may unpickle arbitrary objects (depending on
    # the installed torch version's ``weights_only`` default), so CKPT_DIRS
    # must only point at trusted checkpoints.
    ckpts = [
        torch.load(f"{ckpt_dir}/{filename}", map_location="cpu")
        for ckpt_dir in CKPT_DIRS
    ]

    # Averaging only makes sense when every checkpoint has the same entries;
    # otherwise extra keys would be dropped silently.
    expected_keys = set(ckpts[0])
    for ckpt_dir, ckpt in zip(CKPT_DIRS, ckpts):
        if set(ckpt) != expected_keys:
            differing = sorted(expected_keys.symmetric_difference(ckpt), key=str)
            raise ValueError(
                f"{ckpt_dir}/{filename} does not have the same keys as "
                f"{CKPT_DIRS[0]}/{filename}; differing keys: {differing}"
            )

    avg_ckpt = {
        key: sum(ckpt[key] for ckpt in ckpts) / len(ckpts) for key in ckpts[0]
    }
    return avg_ckpt
# Build a model from the default config, then fill each sub-module with the
# averaged weights returned by get_ckpt.
config = MoondreamConfig()
model = Moondream(config)

for submodule, ckpt_name in (
    (model.vision_encoder.encoder, "vision_encoder.final.pt"),
    (model.vision_encoder.projection, "vision_projection.final.pt"),
    (model.text_model, "text_model.final.pt"),
    (model.region_model, "region_model.final.pt"),
):
    submodule.load_state_dict(get_ckpt(ckpt_name))

# Cast to float16 before uploading the result to the Hub under OUT_MODEL.
model = model.to(dtype=torch.float16)
model.push_to_hub(OUT_MODEL, config=config)