Skip to content

Commit

Permalink
[Chat pipeline] session context manager (#1276)
Browse files Browse the repository at this point in the history
Co-authored-by: rahul-tuli
  • Loading branch information
bfineran authored Sep 25, 2023
1 parent d241b13 commit d13cc2d
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 3 deletions.
46 changes: 44 additions & 2 deletions src/deepsparse/transformers/pipelines/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextvars
import logging
from typing import Any, Dict, List, Tuple, Type, Union
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import numpy
from pydantic import Field, validator
Expand All @@ -37,6 +39,7 @@


_LOGGER = logging.getLogger(__name__)
_SESSION_IDS_CONTEXT = contextvars.ContextVar("_SESSION_ID", default=None)

__all__ = ["ChatPipeline"]

Expand Down Expand Up @@ -117,6 +120,41 @@ def output_schema(self) -> Type[ChatOutput]:
"""
return ChatOutput

@contextmanager
def session(
    self,
    session_ids: Union[None, List[str], str] = None,
    inference_batch_size: int = 1,
) -> Callable[[Any, Any], Any]:
    """
    Context manager that sets and keeps default session id(s) within
    the context. The previous value of the session id context variable
    is restored on exit, including when the body of the ``with`` block
    raises, so session ids never leak past the context.

    example:
    In the following - both responses in the context will share the same
    session id
    ```
    with chat_pipeline.session():
        first_response = chat_pipeline("first prompt")
        second_response = chat_pipeline("second prompt")
    ```

    :param session_ids: actual value to set session ids to in context
        must match the inference batch size. If not supplied, will
        create default values. Default None
    :param inference_batch_size: if generating default session ids, number
        of session ids to create. default 1
    :return: generator consumed by ``contextmanager``; nothing is bound
        by ``with ... as``
    """

    if session_ids is None:
        session_ids = [generate_session_id() for _ in range(inference_batch_size)]

    # set session_ids contextvar, keeping the token so the prior value
    # (possibly from an enclosing session) can be restored
    token = _SESSION_IDS_CONTEXT.set(session_ids)
    try:
        yield
    finally:
        # always reset, even if the with-body raised; otherwise later
        # inferences would silently reuse this session's ids
        _SESSION_IDS_CONTEXT.reset(token)

def process_inputs(
self, inputs: ChatInput
) -> Tuple[List[numpy.ndarray], Dict[str, Any]]:
Expand Down Expand Up @@ -234,7 +272,11 @@ def add_session_ids_to_engine_input(
:return: the engine input with the session ids
"""
session_ids = inputs.session_ids
if session_ids is None:
if session_ids is None and _SESSION_IDS_CONTEXT.get() is not None:
# respect directly setting session IDs first, then try to pull
# from context
session_ids = _SESSION_IDS_CONTEXT.get()
elif session_ids is None:
# session_ids is None, so we need to generate
# a session id for each input sequence
# TODO: Talk to Dipika whether this aligns with the
Expand Down
5 changes: 4 additions & 1 deletion src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,17 @@ def parse_inputs(self, *args, **kwargs) -> TextGenerationInput:
these kwargs will be used to instantiate one
:return: parsed TextGenerationInput object
"""
if "sequences" in kwargs and "prompt" not in kwargs:
# support prompt and sequences interchangeably
kwargs["prompt"] = kwargs["sequences"]
if (
args
and not isinstance(args[0], TextGenerationInput)
and "prompt" not in kwargs
and "sequences" not in kwargs
):
# assume first argument is "sequences" (prompt) by default
kwargs["sequences"] = args[0]
kwargs["prompt"] = args[0]
args = args[1:]

return super().parse_inputs(*args, **kwargs)
Expand Down
45 changes: 45 additions & 0 deletions tests/deepsparse/transformers/pipelines/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from deepsparse import Pipeline


@pytest.mark.parametrize(
    "pipeline_kwargs",
    [
        {
            "model_path": "zoo:nlg/text_generation/codegen_mono-350m/pytorch/"
            "huggingface/bigpython_bigquery_thepile/base-none",
            "engine_type": "onnxruntime",
        },
    ],
)
@pytest.mark.skip(reason="too heavy for now to run in gha")
def test_chat_pipeline_session_manager(pipeline_kwargs):
    """
    Check that inferences run inside ``session()`` share session ids and
    that an inference run after the context exits gets different ones.
    """
    pipeline = Pipeline.create(task="chat", **pipeline_kwargs)

    def run_single_token(prompt):
        # a fresh generation config per call; one new token keeps it cheap
        return pipeline(prompt=prompt, generation_config=dict(max_new_tokens=1))

    with pipeline.session():
        in_session_first = run_single_token("first")
        in_session_second = run_single_token("second")

    # both inferences made within the same context must share session ids
    assert in_session_first.session_ids == in_session_second.session_ids

    # an inference made after the context closed must not reuse those ids
    after_session = run_single_token("third")
    assert after_session.session_ids != in_session_first.session_ids

0 comments on commit d13cc2d

Please sign in to comment.