diff --git a/passkey.py b/passkey.py
index eb20d7a..71236aa 100644
--- a/passkey.py
+++ b/passkey.py
@@ -79,7 +79,7 @@ def main(args):
     loaded = load_model(model, args.load_in_8bit, args.load_in_4bit,
                         args.max_tokens + args.tokens_step)
     apply_patches(loaded, args.max_tokens + args.tokens_step, args.dynamic_ntk,
-                  args.dynamic_linear, args.ntk, args.linear)
+                  args.dynamic_linear, args.ntk, args.linear, args.part_ntk)
 
     pipe = pipeline("text-generation", model=loaded, tokenizer=tokenizer,
                     pad_token_id=tokenizer.eos_token_id)
@@ -120,6 +120,7 @@ def main(args):
     parser.add_argument("--dynamic-ntk", type=float)
     parser.add_argument("--ntk", type=float)
     parser.add_argument("--linear", type=float)
+    parser.add_argument("--part-ntk", type=float)
     parser.add_argument("--load-in-8bit", action="store_true")
     parser.add_argument("--load-in-4bit", action="store_true")
     main(parser.parse_args())
diff --git a/quality.py b/quality.py
index 1668953..8dbb180 100644
--- a/quality.py
+++ b/quality.py
@@ -25,7 +25,7 @@ def main(args):
     model = load_model(args.model, args.load_in_8bit, args.load_in_4bit,
                        args.max_tokens)
     apply_patches(model, args.max_tokens, args.dynamic_ntk,
-                  args.dynamic_linear, args.ntk, args.linear)
+                  args.dynamic_linear, args.ntk, args.linear, args.part_ntk)
 
     choice_tokens = [x[0] for x in tokenizer(CHOICES, add_special_tokens=False).input_ids]
     decoded_choice = tokenizer.decode(choice_tokens, clean_up_tokenization_spaces=True)
@@ -70,6 +70,7 @@ def main(args):
     parser.add_argument("--dynamic-ntk", type=float)
     parser.add_argument("--ntk", type=float)
     parser.add_argument("--linear", type=float)
+    parser.add_argument("--part-ntk", type=float)
     parser.add_argument("--load-in-8bit", action="store_true")
     parser.add_argument("--load-in-4bit", action="store_true")
     parser.add_argument("--limit", type=int)
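
The pattern in both files is the same: each RoPE-scaling option gets its own CLI flag and is forwarded positionally to `apply_patches`, so `--part-ntk` simply extends that list. The sketch below is a minimal, self-contained illustration of that plumbing, not the repo's actual implementation: the body of `apply_patches` is a hypothetical stand-in, the `--max-tokens` default is illustrative, and the type of `--dynamic-linear` is not shown in this diff and is assumed boolean here. Only the argument order (`dynamic_ntk`, `dynamic_linear`, `ntk`, `linear`, `part_ntk`) mirrors the change.

```python
import argparse


def apply_patches(model, max_tokens, dynamic_ntk, dynamic_linear, ntk, linear, part_ntk):
    # Hypothetical stand-in body: the real repo patches the model's rotary
    # embeddings; here we only report which scaling mode was selected so the
    # argument plumbing can be exercised end to end.
    if part_ntk is not None:
        print(f"would apply partial NTK scaling (factor {part_ntk}) up to {max_tokens} tokens")
    elif dynamic_ntk is not None:
        print(f"would apply dynamic NTK scaling (factor {dynamic_ntk})")
    elif dynamic_linear:
        print("would apply dynamic linear scaling")
    elif ntk is not None:
        print(f"would apply NTK-aware scaling (alpha {ntk})")
    elif linear is not None:
        print(f"would apply linear interpolation (scale {linear})")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-tokens", type=int, default=4096)  # illustrative default
    parser.add_argument("--dynamic-ntk", type=float)
    parser.add_argument("--dynamic-linear", action="store_true")  # type assumed; not shown in this diff
    parser.add_argument("--ntk", type=float)
    parser.add_argument("--linear", type=float)
    parser.add_argument("--part-ntk", type=float)
    args = parser.parse_args()
    # Same forwarding order as passkey.py and quality.py after this change.
    apply_patches(None, args.max_tokens, args.dynamic_ntk,
                  args.dynamic_linear, args.ntk, args.linear, args.part_ntk)
```

For example, `python sketch.py --part-ntk 2.0` would route the value into the `part_ntk` parameter, just as `--part-ntk` does for `passkey.py` and `quality.py` after this diff.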