Update cli.py
thomasantony committed Mar 28, 2023
1 parent 08ebc69 commit fd778f0
Showing 1 changed file with 46 additions and 33 deletions.
src/llamacpp/cli.py: 46 additions & 33 deletions
@@ -2,9 +2,10 @@
 import sys
 import argparse
 import llamacpp
+from typing import Dict
 
 
-def parse_args_into_params(argv) -> llamacpp.gpt_params:
+def parse_args_into_params(argv) -> Dict[str, str]:
     """Parses arguments using argparse based on usage information above"""
     parser = argparse.ArgumentParser(description="llama.cpp CLI")
     parser.add_argument("-i", "--interactive", action="store_true", help="run in interactive mode")
@@ -52,20 +53,20 @@ def parse_args_into_params(argv) -> llamacpp.gpt_params:
         "--repeat_last_n",
         type=int,
         default=64,
-        help="last n tokens to consider for penalize (default: 0)",
+        help="last n tokens to consider for penalize (default: 64)",
     )
     parser.add_argument(
         "--repeat_penalty",
         type=float,
         default=1.30,
-        help="penalize repeat sequence of tokens (default: 0.0)",
+        help="penalize repeat sequence of tokens (default: 1.30)",
     )
     parser.add_argument(
         "-c",
-        "--ctx_size",
+        "--n_ctx",
         type=int,
-        default=4096,
-        help="size of the prompt context (default: 4096)",
+        default=512,
+        help="size of the prompt context (default: 512)",
     )
     parser.add_argument("--temp", type=float, default=0.8, help="temperature (default: 0.7)")
     parser.add_argument(
@@ -76,14 +77,21 @@ def parse_args_into_params(argv) -> llamacpp.gpt_params:
         help="batch size for prompt processing (default: 2)",
     )
     parser.add_argument("-m", "--model", type=str, default="./models/7B/ggml-model-q4_0.bin", help="model path (default: )")
+    parser.add_argument("--use_mlock", action="store_true", help="use mlock to lock memory")
+    parser.add_argument("--memory_f16", action="store_true", help="use half-precision memory")
+    parser.add_argument("--n_batch", type=int, default=8, help="number of tokens per batch")
+    parser.add_argument("--n_threads", type=int, default=4, help="number of threads to use")
+
     parser.usage = parser.format_help()
 
     args = parser.parse_args(argv[1:])
 
+    if args.interactive or args.instruct:
+        print("WARNING: interactive mode and instruct mode are currently broken")
     return args
 
 
-def process_interactive_input(model: llamacpp.PyLLAMA):
+def process_interactive_input(model: llamacpp.LlamaInference):
     """Process interactive input similar to the C++ version"""
 
     # Read lines as long as user is entering "\" at the end of the line
@@ -108,39 +116,37 @@ def main(args):
 
     # Add a space in front of the first character to match OG llama tokenizer behavior
     args.prompt = " " + args.prompt
 
-    # Initialize the gpt_params object
-    params = llamacpp.gpt_params(
-        args.model,
-        args.ctx_size,
-        args.n_predict,
-        args.top_k,
-        args.top_p,
-        args.temp,
-        args.repeat_penalty,
-        args.seed,
-        args.threads,
-        args.repeat_last_n,
-        args.batch_size,
-    )
-
-    model = llamacpp.PyLLAMA(params)
-    model.add_bos()
+    params = llamacpp.InferenceParams()
+    params.path_model = args.model
+    params.seed = args.seed
+    params.n_threads = args.n_threads
+
+    params.repeat_last_n = args.repeat_last_n
+    params.n_batch = args.n_batch
+    params.top_k = args.top_k
+    params.top_p = args.top_p
+    params.temp = args.temp
+    params.repeat_penalty = args.repeat_penalty
+    params.use_mlock = args.use_mlock
+    params.memory_f16 = args.memory_f16
+    params.n_ctx = args.n_ctx
+
+    model = llamacpp.LlamaInference(params)
+    model.update_input([model.token_bos()])
     model.update_input(args.prompt)
-    model.print_startup_stats()
-    model.prepare_context()
+    print(model.system_info())
 
     inp_pfx = model.tokenize("\n\n### Instruction:\n\n", True)
     inp_sfx = model.tokenize("\n\n### Response:\n\n", False)
 
     if args.instruct:
         args.interactive = True
-        args.antiprompt = "### Instruction:\n\n"
+        args.reverse_prompt = "### Instruction:\n\n"
 
     # Set antiprompt if we are in interactive mode
-    if args.antiprompt:
+    if args.reverse_prompt:
         args.interactive = True
-        model.set_antiprompt(args.antiprompt)
 
     if args.interactive:
         print("== Running in interactive mode. ==")
@@ -153,16 +159,22 @@ def main(args):
     input_noecho = False
     is_finished = False
 
-    while not model.is_finished():
+    print(args.prompt, end="")
+
+    n_output = 0
+    while n_output < args.n_predict:
         if model.has_unconsumed_input():
-            model.ingest_all_pending_input(not input_noecho)
+            model.ingest_all_pending_input()
             # # reset color to default if we there is no pending user input
             # if (!input_noecho && args.use_color) {
            #     printf(ANSI_COLOR_RESET);
             # }
         else:
-            text, is_finished = model.infer_text()
+            token = model.sample()
+            text = model.token_to_str(token)
             print(text, end="")
+            n_output += 1
+            is_finished = token == model.token_eos()
             input_noecho = False
 
     if args.interactive:
@@ -202,5 +214,6 @@ def run():
     args = parse_args_into_params(sys.argv)
     return main(args)
 
+
 if __name__ == "__main__":
     sys.exit(run())
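
For context, this commit moves the CLI from the old PyLLAMA/gpt_params bindings to the new InferenceParams/LlamaInference API, with parameters set as attributes rather than passed positionally to a constructor. Below is a minimal sketch of the new generation loop, using only the calls that appear in the diff above; the model path, prompt, and 64-token cap are illustrative placeholders, not values from the commit.

    import llamacpp

    params = llamacpp.InferenceParams()
    params.path_model = "./models/7B/ggml-model-q4_0.bin"  # illustrative path

    model = llamacpp.LlamaInference(params)
    # Seed the input with a BOS token, then the prompt; the leading space
    # matches the tokenizer behavior handled in main() above
    model.update_input([model.token_bos()])
    model.update_input(" Hello, my name is")

    n_output = 0
    while n_output < 64:  # cap output length, as args.n_predict does in the CLI
        if model.has_unconsumed_input():
            # Prompt tokens are ingested before any sampling happens
            model.ingest_all_pending_input()
        else:
            token = model.sample()
            if token == model.token_eos():
                break
            print(model.token_to_str(token), end="", flush=True)
            n_output += 1

Assuming the package's console entry point is unchanged (the executable name here is an assumption), a run exercising the flags added in this commit might look like: llamacpp-cli -m ./models/7B/ggml-model-q4_0.bin --n_ctx 512 --n_threads 8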
