diff --git a/token_benchmark_ray.py b/token_benchmark_ray.py
index a5909aa..db9ac85 100644
--- a/token_benchmark_ray.py
+++ b/token_benchmark_ray.py
@@ -40,8 +40,9 @@ def get_token_throughput_latencies(
     additional_sampling_params: Optional[Dict[str, Any]] = None,
     num_concurrent_requests: int = 1,
     max_num_completed_requests: int = 500,
-    test_timeout_s=90,
-    llm_api="openai",
+    test_timeout_s: int = 90,
+    llm_api: str = "openai",
+    log_prompts: bool = False,
 ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
     """Get the token throughput and latencies for the given model.
 
@@ -90,6 +91,15 @@ def get_token_throughput_latencies(
             prompt_tokens_stddev=stddev_input_tokens,
             expect_output_tokens=num_output_tokens,
         ))
+
+    if log_prompts:
+        print("Sending the following prompts:")
+        print(prompts)
+    else:
+        # 'prompts' is a list of (prompt, token_length) tuples
+        print("Sending the following prompt sizes:")
+        print([token_length for _, token_length in prompts])
+
     start_time = time.monotonic()
     iter = 0
     pbar = tqdm(total=max_num_completed_requests)
@@ -289,6 +299,7 @@ def run_token_benchmark(
     additional_sampling_params: str,
     results_dir: str,
     user_metadata: Dict[str, Any],
+    log_prompts: bool,
 ):
     """
     Args:
@@ -324,6 +335,7 @@ def run_token_benchmark(
         stddev_output_tokens=stddev_output_tokens,
         num_concurrent_requests=num_concurrent_requests,
         additional_sampling_params=json.loads(additional_sampling_params),
+        log_prompts=log_prompts,
     )
 
     if results_dir:
@@ -459,6 +471,15 @@ def run_token_benchmark(
         "name=foo,bar=1. These will be added to the metadata field of the results. "
     ),
 )
+args.add_argument(
+    "--log-prompts",
+    action="store_true",
+    default=False,
+    help=(
+        "If set, log all prompts sent to the model."
+    ),
+)
+
 
 if __name__ == "__main__":
     env_vars = dict(os.environ)
@@ -485,4 +506,5 @@ def run_token_benchmark(
         additional_sampling_params=args.additional_sampling_params,
         results_dir=args.results_dir,
         user_metadata=user_metadata,
+        log_prompts=args.log_prompts,
     )
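
As a side note on the flag definition above: a minimal sketch (not part of the patch) of how the --log-prompts option parses under argparse's store_true action. The attribute defaults to False and becomes True only when the flag is present; type=bool would not behave as a toggle, since argparse would call bool() on the raw string and any non-empty value, including "False", evaluates to True.

import argparse

# Sketch of the --log-prompts flag behavior; mirrors the parser in token_benchmark_ray.py.
parser = argparse.ArgumentParser()
parser.add_argument("--log-prompts", action="store_true", default=False)

print(parser.parse_args([]).log_prompts)                 # False: flag omitted
print(parser.parse_args(["--log-prompts"]).log_prompts)  # True: flag present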