Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference UX, accept input data #1285

Merged
merged 24 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit. Hold shift + click to select a range.
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _setup_entry_points() -> Dict:
"console_scripts": [
f"deepsparse.transformers.run_inference={data_api_entrypoint}",
f"deepsparse.transformers.eval_downstream={eval_downstream}",
"deepsparse.infer=deepsparse.transformers.infer:main",
"deepsparse.infer=deepsparse.transformers.inference.infer:main",
"deepsparse.debug_analysis=deepsparse.debug_analysis:main",
"deepsparse.analyze=deepsparse.analyze:main",
"deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability",
Expand Down
13 changes: 13 additions & 0 deletions src/deepsparse/transformers/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@
deepsparse.infer models/llama/deployment \
--task text-generation
"""

from typing import Optional

import click

from deepsparse import Pipeline
from deepsparse.tasks import SupportedTasks
from deepsparse.transformers.inference.prompt_parser import PromptParser


@click.command(
Expand All @@ -75,6 +79,14 @@
)
)
@click.argument("model_path", type=str)
@click.option(
"--data",
type=str,
default=None,
help="Path to .txt, .csv, .json, or .jsonl file to load data from"
"If provided, runs inference over the entire dataset. If not provided "
"runs an interactive inference session in the console. Default None.",
)
@click.option(
"--sequence_length",
type=int,
Expand Down Expand Up @@ -112,6 +124,7 @@
)
def main(
model_path: str,
data: Optional[str],
sequence_length: int,
sampling_temperature: float,
prompt_sequence_length: int,
Expand All @@ -128,34 +141,76 @@ def main(
session_ids = "chatbot_cli_session"

pipeline = Pipeline.create(
task=task, # let pipeline determine if task is supported
task=task, # let the pipeline determine if task is supported
model_path=model_path,
sequence_length=sequence_length,
sampling_temperature=sampling_temperature,
prompt_sequence_length=prompt_sequence_length,
)

# continue prompts until a keyboard interrupt
while True:
input_text = input("User: ")
pipeline_inputs = {"prompt": [input_text]}

if SupportedTasks.is_chat(task):
pipeline_inputs["session_ids"] = session_ids

response = pipeline(**pipeline_inputs)
print("Bot: ", response.generations[0].text)
if show_tokens_per_sec:
times = pipeline.timer_manager.times
prefill_speed = (
1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
)
generation_speed = 1.0 / times["engine_token_generation_single"]
print(
f"[prefill: {prefill_speed:.2f} tokens/sec]",
f"[decode: {generation_speed:.2f} tokens/sec]",
sep="\n",
if data:
prompt_parser = PromptParser(data)
default_prompt_kwargs = {
"sequence_length": sequence_length,
"sampling_temperature": sampling_temperature,
"prompt_sequence_length": prompt_sequence_length,
"show_tokens_per_sec": show_tokens_per_sec,
}

for prompt_kwargs in prompt_parser.parse_as_iterable(**default_prompt_kwargs):
_run_inference(
task=task,
pipeline=pipeline,
session_ids=session_ids,
**prompt_kwargs,
)
return

# continue prompts until a keyboard interrupt
while data is None: # always True in interactive Mode
prompt = input(">>> ")
_run_inference(
pipeline,
sampling_temperature,
task,
session_ids,
show_tokens_per_sec,
prompt_sequence_length,
prompt,
)


def _run_inference(
    pipeline,
    sampling_temperature,
    task,
    session_ids,
    show_tokens_per_sec,
    prompt_sequence_length,
    prompt,
    **kwargs,
):
    """Run a single prompt through the pipeline and print the generation.

    Builds the pipeline inputs from the prompt plus any extra kwargs,
    tags chat tasks with the session id, prints the generated text, and
    optionally prints prefill/decode throughput taken from the pipeline's
    timer manager.
    """
    pipeline_inputs = dict(
        prompt=[prompt],
        temperature=sampling_temperature,
        **kwargs,
    )
    # chat tasks additionally carry a session id so consecutive prompts
    # belong to the same conversation
    if SupportedTasks.is_chat(task):
        pipeline_inputs["session_ids"] = session_ids

    response = pipeline(**pipeline_inputs)
    print("\n", response.generations[0].text)

    if not show_tokens_per_sec:
        return

    # throughput figures come from the pipeline's own timers
    timings = pipeline.timer_manager.times
    prefill_speed = (
        1.0 * prompt_sequence_length / timings["engine_prompt_prefill_single"]
    )
    generation_speed = 1.0 / timings["engine_token_generation_single"]
    print(
        f"[prefill: {prefill_speed:.2f} tokens/sec]",
        f"[decode: {generation_speed:.2f} tokens/sec]",
        sep="\n",
    )


if __name__ == "__main__":
Expand Down
108 changes: 108 additions & 0 deletions src/deepsparse/transformers/inference/prompt_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import csv
import json
import os
from enum import Enum
from typing import Iterator


class InvalidPromptSourceDirectoryException(Exception):
    """Raised when the prompt source file does not have a supported extension."""

    pass


class UnableToParseExtentsonException(Exception):
    """Raised when no parser is implemented for a recognized file extension."""

    # NOTE(review): "Extentson" is a typo of "Extension", but the name is
    # part of the public interface, so it is kept as-is.
    pass


def parse_value_to_appropriate_type(value: str):
    """Best-effort coercion of a raw string field to ``int``, ``float``, or ``bool``.

    Integer-looking strings (optionally signed) become ``int``, decimal
    numbers become ``float``, case-insensitive "true"/"false" become
    ``bool``; anything else is returned unchanged.

    :param value: raw field value, e.g. a CSV cell
    :return: the coerced value, or ``value`` itself when no coercion applies
    """
    # csv.DictReader yields None for missing columns; pass non-strings
    # through instead of crashing on None.isdigit()
    if not isinstance(value, str):
        return value
    # allow an optional leading sign so "-1" and "+2.5" parse as numbers
    unsigned = value[1:] if value[:1] in ("+", "-") else value
    if unsigned.isdigit():
        return int(value)
    if "." in unsigned and all(part.isdigit() for part in unsigned.split(".", 1)):
        return float(value)
    if value.lower() == "true":
        return True
    if value.lower() == "false":
        return False
    return value


class PromptParser:
    """Iterate prompt kwargs out of a ``.txt``, ``.csv``, ``.json``, or ``.jsonl`` file.

    Each item yielded by :meth:`parse_as_iterable` is a *fresh* dict of
    keyword arguments: the caller-supplied defaults, overridden by whatever
    the source file provides for that entry.
    """

    class Extensions(Enum):
        # Supported prompt-source formats, keyed by filename suffix
        TEXT = ".txt"
        CSV = ".csv"
        JSON = ".json"
        JSONL = ".jsonl"

    def __init__(self, filename: str):
        # NOTE(review): attribute name "extention" (sic) kept for
        # backward compatibility with existing callers
        self.extention: self.Extensions = self._validate_and_return_extention(filename)
        self.filename: str = filename

    def parse_as_iterable(self, **kwargs) -> Iterator:
        """Yield one kwargs dict per prompt entry in the source file.

        :param kwargs: default values merged into every yielded dict;
            per-entry values from the file take precedence
        :raises UnableToParseExtentsonException: if no parser exists for
            the detected extension
        """
        if self.extention == self.Extensions.TEXT:
            return self._parse_text(**kwargs)
        if self.extention == self.Extensions.CSV:
            return self._parse_csv(**kwargs)
        if self.extention == self.Extensions.JSON:
            return self._parse_json_list(**kwargs)
        if self.extention == self.Extensions.JSONL:
            return self._parse_jsonl(**kwargs)

        raise UnableToParseExtentsonException(
            f"Parser for {self.extention} does not exist"
        )

    def _parse_text(self, **kwargs):
        # One prompt per line; surrounding whitespace is stripped.
        with open(self.filename, "r") as file:
            for line in file:
                # yield a fresh dict so consumers may safely retain items
                yield {**kwargs, "prompt": line.strip()}

    def _parse_csv(self, **kwargs):
        # Each row's columns become kwargs; values are coerced to
        # int/float/bool where possible.
        with open(self.filename, "r", newline="", encoding="utf-8-sig") as file:
            for row in csv.DictReader(file):
                parsed_row = {
                    key: parse_value_to_appropriate_type(value)
                    for key, value in row.items()
                }
                yield {**kwargs, **parsed_row}

    def _parse_json_list(self, **kwargs):
        # The file holds a JSON list of objects; each object is one entry.
        # A fresh dict per object prevents keys from one entry bleeding
        # into the next.
        with open(self.filename, "r") as file:
            json_list = json.load(file)
        for json_object in json_list:
            yield {**kwargs, **json_object}

    def _parse_jsonl(self, **kwargs):
        # One JSON object per line; blank lines (e.g. a trailing newline)
        # are skipped instead of raising a decode error.
        with open(self.filename, "r") as file:
            for jsonl in file:
                if not jsonl.strip():
                    continue
                yield {**kwargs, **json.loads(jsonl)}

    def _validate_and_return_extention(self, filename: str):
        """Return the :class:`Extensions` member matching ``filename``.

        :raises FileNotFoundError: if ``filename`` does not exist
        :raises InvalidPromptSourceDirectoryException: if the extension is
            not one of the supported formats
        """
        if not os.path.exists(filename):
            raise FileNotFoundError(filename)

        for extention in self.Extensions:
            if filename.endswith(extention.value):
                return extention

        raise InvalidPromptSourceDirectoryException(
            f"{filename} is not compatible. Select file that has "
            "extension from "
            f"{[key.name for key in self.Extensions]}"
        )