Skip to content

Commit

Permalink
use safetensors for the latest plamo-13b repo
Browse files (browse the repository at this point in the history)
  • Loading branch information
okdshin committed Oct 11, 2023
1 parent 074bd14 commit fccb147
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions convert-plamo-hf-to-gguf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,11 +8,18 @@
import gguf
from sentencepiece import SentencePieceProcessor # type: ignore[import]

try:
from safetensors import safe_open
except ImportError:
print("Please install `safetensors` python package")
sys.exit(1)


def count_model_parts(dir_model: Path) -> int:
# get number of model parts
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
if filename.startswith("model-00"):
num_parts += 1

if num_parts > 0:
Expand Down Expand Up @@ -161,22 +168,22 @@ def parse_args() -> argparse.Namespace:
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
part_names = iter(("model.safetensors",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
)

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(dir_model / part_name, map_location="cpu")
model_part = safe_open(dir_model / part_name, framework="pt")

for name in model_part.keys():
if "self_attn.rotary_emb.inv_freq" in name:
continue
data = model_part[name]
data = model_part.get_tensor(name)

old_dtype = data.dtype

Expand Down

0 comments on commit fccb147

Please sign in to comment.