Skip to content

Commit

Permalink
tool-bench: add --ctk, --ctv, --fa flags
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Feb 27, 2025
1 parent fc19192 commit 60f28ef
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions examples/server/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class ServerProcess:
id_slot: int | None = None
cache_prompt: bool | None = None
n_slots: int | None = None
ctk: str | None = None
ctv: str | None = None
fa: bool | None = None
server_continuous_batching: bool | None = False
server_embeddings: bool | None = False
server_reranking: bool | None = False
Expand Down Expand Up @@ -151,6 +154,12 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
server_args.extend(["--ctx-size", self.n_ctx])
if self.n_slots:
server_args.extend(["--parallel", self.n_slots])
if self.ctk:
server_args.extend(["-ctk", self.ctk])
if self.ctv:
server_args.extend(["-ctv", self.ctv])
if self.fa is not None:
server_args.append("-fa")
if self.n_predict:
server_args.extend(["--n-predict", self.n_predict])
if self.slot_save_path:
Expand Down
9 changes: 9 additions & 0 deletions scripts/tool_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ def run(
temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None,
top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None,
top_k: Annotated[Optional[int], typer.Option(help="top_k")] = None,
ctk: Annotated[Optional[str], typer.Option(help="ctk")] = None,
ctv: Annotated[Optional[str], typer.Option(help="ctv")] = None,
fa: Annotated[Optional[bool], typer.Option(help="fa")] = None,
seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
Expand Down Expand Up @@ -284,6 +287,9 @@ def elapsed():
temp=t,
top_p=top_p,
top_k=top_k,
ctk=ctk,
ctv=ctv,
seed=seed,
success_ratio=float(success_count) / n,
avg_time=mean(success_times + failure_times),
median_time=median(success_times + failure_times),
Expand All @@ -307,6 +313,9 @@ def elapsed():
server.n_ctx = n_ctx
server.n_slots = 1
server.jinja = True
server.ctk = ctk
server.ctv = ctv
server.fa = fa
server.n_predict = n_predict
server.model_hf_repo = hf
server.model_hf_file = None
Expand Down

0 comments on commit 60f28ef

Please sign in to comment.