Skip to content

Commit

Permalink
more changes to support chat models
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Dec 9, 2024
1 parent 7c57a5a commit a838911
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
4 changes: 3 additions & 1 deletion run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def run_inference():
'-ngl', '0',
'-c', str(args.ctx_size),
'--temp', str(args.temperature),
"-b", "1"
"-b", "1",
    *(["-cnv"] if args.conversation else [])
]
run_command(command)

Expand All @@ -48,6 +49,7 @@ def signal_handler(sig, frame):
parser.add_argument("-t", "--threads", type=int, help="Number of threads to use", required=False, default=2)
parser.add_argument("-c", "--ctx-size", type=int, help="Size of the prompt context", required=False, default=2048)
parser.add_argument("-temp", "--temperature", type=float, help="Temperature, a hyperparameter that controls the randomness of the generated text", required=False, default=0.8)
parser.add_argument("-cnv", "--conversation", action='store_true', help="Whether to enable chat mode or not (for instruct models.)")

args = parser.parse_args()
run_inference()
3 changes: 3 additions & 0 deletions setup_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
"HF1BitLLM/Llama3-8B-1.58-100B-tokens": {
"model_name": "Llama3-8B-1.58-100B-tokens",
},
"tiiuae/falcon3-7b-instruct-1.58bit": {
"model_name": "falcon3-7b-1.58bit",
},
"tiiuae/falcon3-7b-1.58bit": {
"model_name": "falcon3-7b-1.58bit",
}
Expand Down
2 changes: 1 addition & 1 deletion utils/convert-hf-to-gguf-bitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def read_model_config(model_dir: str) -> dict[str, Any]:
with open(config, "r") as f:
return json.load(f)

@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "Falcon3ForCausalLM")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

Expand Down

0 comments on commit a838911

Please sign in to comment.