From 63233d37b0d28fb0116f0c19e697de8d5f90d7d0 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 5 Feb 2024 14:58:47 -0800 Subject: [PATCH 01/25] Added Outlines logits processor for JSON schema validation --- proto/generate.proto | 2 + server/lorax_server/models/flash_causal_lm.py | 6 +- server/lorax_server/utils/logits_process.py | 109 +++++++++++++++++- server/lorax_server/utils/tokens.py | 21 ++++ 4 files changed, 136 insertions(+), 2 deletions(-) diff --git a/proto/generate.proto b/proto/generate.proto index 1743b0b3a..400c500b5 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -75,6 +75,8 @@ message NextTokenChooserParameters { bool watermark = 8; /// adapter to use with lora exchange string adapter_id = 9; + /// JSON schema used for constrained decoding (Outlines) + string schema = 10; } message StoppingCriteriaParameters { diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 7cc3112c8..e8b4601e6 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -243,8 +243,12 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) + request_tokenizers = [ + tokenizers.get_tokenizer(r.adapter_index, tokenizer) + for r in pb.requests + ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device + next_token_chooser_parameters, request_tokenizers, dtype, device ) start_slots = torch.tensor(start_slots, dtype=torch.int64) diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index f424eae40..56d8ecda3 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -7,12 +7,20 @@ from transformers import ( LogitsWarper, LogitsProcessor, + PreTrainedTokenizerBase, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, ) +try: + from outlines.fsm.fsm import RegexFSM + from outlines.fsm.json_schema import build_regex_from_object + HAS_OUTLINES = True +except ImportError: + HAS_OUTLINES = False + mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None @@ -118,7 +126,7 @@ def filter(self, indices): return None -class HeterogeneousTemperatureLogitsWarper: +class HeterogeneousTemperatureLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] for temperature (exponential scaling output probability distribution). This version allows for a separate value for each sample and runs inplace when possible. @@ -408,3 +416,102 @@ def filter(self, indices): self.processors = new_processors return self return None + + +class HeterogeneousSchemaLogitsProcessor(LogitsProcessor): + r""" + [`LogitsWarper`] for JSON schema enforcement. + This version uses Outlines to perform the constrained decoding. + + Args: + schemas (`List[Optional[str]]`): + The JSON encoded schemas to enforce. `None` means no enforcement. + tokenizers (`List[Optional[PreTrainedTokenizerBase]]`): + The tokenizers to use for each request. 
+ """ + + def __init__( + self, + schemas: List[Optional[str]], + tokenizers: List[Optional[PreTrainedTokenizerBase]], + ): + self.sequence_processors = [ + None if schema is None else OutlinesLogitsProcessor(schema, tokenizer) + for schema, tokenizer in zip(schemas, tokenizers) + ] + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + for i, processor in enumerate(self.sequence_processors): + if processor is not None: + scores[i:i + 1] = processor(input_ids[i:i + 1], scores[i:i + 1]) + return scores + + def filter(self, indices): + self.sequence_processors = [self.sequence_processors[i] for i in indices] + if any([x is not None for x in self.sequence_processors]): + return self + return None + + +# Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py +class OutlinesLogitsProcessor: + def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase): + """Compile the FSM that drives the regex-guided generation. + + Args: + schema (str): + JSON schema to enforce. + tokenizer (PreTrainedTokenizerBase): + The tokenizer to use for the FSM. + """ + if not HAS_OUTLINES: + raise ImportError("Unable to enforce JSON schema: `outlines` is not installed.") + + tokenizer = self.adapt_tokenizer(tokenizer) + + regex_string = build_regex_from_object(schema) + self.fsm = RegexFSM(regex_string, tokenizer) + self.fsm_state = 0 + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token.""" + + last_token = input_ids[-1] + self.fsm_state = self.fsm.next_state(self.fsm_state, last_token) + + allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state) + + mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + mask[allowed_tokens] = 0 + biased_scores = scores + mask + + return biased_scores + + def adapt_tokenizer(self, tokenizer): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces to + Llama's tokenizer to be able to compile FSMs for this model. + """ + if hasattr(tokenizer, "vocabulary"): + # We've already adapted the tokenizer from a previous request + return tokenizer + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + + return tokenizer diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 1d4a483d3..810a8fc75 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -18,6 +18,7 @@ HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, HeterogeneousProcessorWrapper, + HeterogeneousSchemaLogitsProcessor, ) @@ -200,15 +201,18 @@ class HeterogeneousNextTokenChooser: watermark (List[bool]): A list of booleans indicating whether watermark processing should be applied for each token. temperature (List[float]): A list of temperature values for temperature-based logits warping. 
repetition_penalty (List[float]): A list of repetition penalty values for repetition penalty-based logits warping. + schema (List[str]): A list of JSON schema strings for Outlines logits warping. top_k (List[int]): A list of top-k values for top-k-based logits warping. top_p (List[float]): A list of top-p values for top-p-based logits warping. typical_p (List[float]): A list of typical-p values for typical-p-based logits warping. do_sample (List[bool]): A list of booleans indicating whether sampling should be applied for each token. seeds (List[int]): A list of seed values for random number generation. + tokenizers (List[PreTrainedTokenizerBase]): A list of tokenizers to use for processing the tokens. Attributes: watermark_processor (HeterogeneousProcessorWrapper): The watermark logits processor. repetition_processor (HeterogeneousRepetitionPenaltyLogitsProcessor): The repetition penalty logits processor. + schema_processor (HeterogeneousSchemaLogitsProcessor): The JSON schema logits processor. warpers (List[HeterogeneousLogitsWarper]): The list of logits warpers. choice (HeterogeneousSampling or Greedy): The token choice strategy. seeds (List[int]): The list of seed values. @@ -224,11 +228,13 @@ def __init__( watermark: List[bool], temperature: List[float], repetition_penalty: List[float], + schemas: List[str], top_k: List[int], top_p: List[float], typical_p: List[float], do_sample: List[bool], seeds: List[int], + tokenizers: List[PreTrainedTokenizerBase], ): warpers = [] @@ -252,6 +258,12 @@ def __init__( else None ) + self.schema_processor = ( + HeterogeneousSchemaLogitsProcessor(schemas, tokenizers) + if any(schemas) + else None + ) + if any([x != 1.0 for x in temperature]): do_sample = [ sample or x != 1.0 for x, sample in zip(temperature, do_sample) @@ -300,6 +312,8 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor): scores = self.watermark_processor(input_ids, scores) if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) + if self.schema_processor is not None: + scores = self.schema_processor(input_ids, scores) for warper in self.warpers: scores = warper(input_ids, scores) @@ -326,6 +340,9 @@ def filter(self, indices): if self.repetition_processor is not None: self.repetition_processor = self.repetition_processor.filter(indices) + + if self.schema_processor is not None: + self.schema_processor = self.schema_processor.filter(indices) filtered_warpers = [] for warper in self.warpers: @@ -348,6 +365,7 @@ def filter(self, indices): def from_pb( cls, pb: List[generate_pb2.NextTokenChooserParameters], + tokenizers: List[PreTrainedTokenizerBase], dtype: torch.dtype, device: torch.device, ) -> "HeterogeneousNextTokenChooser": @@ -356,6 +374,7 @@ def from_pb( Args: pb (List[generate_pb2.NextTokenChooserParameters]): The protocol buffer containing the parameters. + tokenizers (List[PreTrainedTokenizerBase]): The tokenizers to use for processing the tokens. dtype (torch.dtype): The data type of the tokens. device (torch.device): The device on which the tokens are processed. 
@@ -366,11 +385,13 @@ def from_pb( watermark=[pb_.watermark for pb_ in pb], temperature=[pb_.temperature for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb], + schemas=[pb_.schema for pb_ in pb], top_k=[pb_.top_k for pb_ in pb], top_p=[pb_.top_p for pb_ in pb], typical_p=[pb_.typical_p for pb_ in pb], do_sample=[pb_.do_sample for pb_ in pb], seeds=[pb_.seed for pb_ in pb], + tokenizer=tokenizers, device=device, dtype=dtype, ) From e9fae51f0894a6228be97d28a4a35e4ad4c15436 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 5 Feb 2024 15:13:19 -0800 Subject: [PATCH 02/25] Router changes --- proto/generate.proto | 2 +- router/client/src/client.rs | 1 + router/src/health.rs | 1 + router/src/lib.rs | 10 ++++++++++ router/src/validation.rs | 2 ++ 5 files changed, 15 insertions(+), 1 deletion(-) diff --git a/proto/generate.proto b/proto/generate.proto index 400c500b5..e4172cde3 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -76,7 +76,7 @@ message NextTokenChooserParameters { /// adapter to use with lora exchange string adapter_id = 9; /// JSON schema used for constrained decoding (Outlines) - string schema = 10; + optional string schema = 10; } message StoppingCriteriaParameters { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index d1922d90e..8bc4b4609 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -126,6 +126,7 @@ impl Client { repetition_penalty: 1.2, watermark: true, adapter_id: "".to_string(), + schema: None, }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 2, diff --git a/router/src/health.rs b/router/src/health.rs index 578ac5716..2ebed317c 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -45,6 +45,7 @@ impl Health { repetition_penalty: 1.0, watermark: false, adapter_id: "".to_string(), + schema: None, }), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: 1, diff --git a/router/src/lib.rs b/router/src/lib.rs index 19ab7a191..0c7b3ada5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -249,6 +249,13 @@ pub(crate) struct GenerateParameters { example = "null" )] pub seed: Option, + #[serde(default)] + #[schema( + nullable = true, + default = "null", + example = "{\"type\": \"string\", \"title\": \"response\"}" + )] + pub schema: Option, } fn default_max_new_tokens() -> u32 { @@ -277,6 +284,7 @@ fn default_parameters() -> GenerateParameters { decoder_input_details: false, apply_chat_template: false, seed: None, + schema: None, } } @@ -582,6 +590,7 @@ impl From for CompatGenerateRequest { decoder_input_details: req.logprobs.is_some(), apply_chat_template: false, seed: None, + schema: None, }, stream: req.stream.unwrap_or(false), } @@ -616,6 +625,7 @@ impl From for CompatGenerateRequest { decoder_input_details: false, apply_chat_template: true, seed: None, + schema: None, }, stream: req.stream.unwrap_or(false), } diff --git a/router/src/validation.rs b/router/src/validation.rs index f10949c39..e241c6a02 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -147,6 +147,7 @@ impl Validation { adapter_parameters, decoder_input_details, apply_chat_template, + schema, .. 
} = request.parameters; @@ -273,6 +274,7 @@ impl Validation { seed, watermark, adapter_id, + schema, }; let stopping_parameters = StoppingCriteriaParameters { max_new_tokens, From 85fbdfd13206bede8a4b63f9f5883b25adba8d7c Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 5 Feb 2024 15:20:29 -0800 Subject: [PATCH 03/25] Added outlines as extra dep --- Dockerfile | 2 +- server/poetry.lock | 451 +++++++++++++++++++++++++++++++++++++++++- server/pyproject.toml | 2 + 3 files changed, 453 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 5cfc2dbf4..5ab91b0e6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -218,7 +218,7 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ - pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir + pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir # Install router COPY --from=builder /usr/src/target/release/lorax-router /usr/local/bin/lorax-router diff --git a/server/poetry.lock b/server/poetry.lock index d453f0348..cc2afe601 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -139,6 +139,17 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "annotated-types" +version = "0.6.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = true +python-versions = ">=3.8" +files = [ + {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, + {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, +] + [[package]] name = "async-timeout" version = "4.0.3" @@ -355,6 +366,17 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cloudpickle" +version = "3.0.0" +description = "Pickler class to extend the standard pickle.Pickler functionality" +optional = true +python-versions = ">=3.8" +files = [ + {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, + {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, +] + [[package]] name = "colorama" version = "0.4.6" @@ -440,6 +462,17 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "diskcache" +version = "5.6.3" +description = "Disk Cache -- Disk and file backed persistent cache." 
+optional = true +python-versions = ">=3" +files = [ + {file = "diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19"}, + {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, +] + [[package]] name = "einops" version = "0.6.1" @@ -930,6 +963,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "interegular" +version = "0.3.3" +description = "a regex intersection checker" +optional = true +python-versions = ">=3.7" +files = [ + {file = "interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c"}, + {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"}, +] + [[package]] name = "jinja2" version = "3.1.2" @@ -958,6 +1002,99 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "joblib" +version = "1.3.2" +description = "Lightweight pipelining with Python functions" +optional = true +python-versions = ">=3.7" +files = [ + {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, + {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, +] + +[[package]] +name = "jsonschema" +version = "4.21.1" +description = "An implementation of JSON Schema validation for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.21.1-py3-none-any.whl", hash = "sha256:7996507afae316306f9e2290407761157c6f78002dcf7419acb99822143d1c6f"}, + {file = "jsonschema-4.21.1.tar.gz", hash = "sha256:85727c00279f5fa6bedbe6238d2aa6403bedd8b4864ab11207d07df3cc1b2ee5"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] + +[[package]] +name = "jsonschema-specifications" +version = "2023.12.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = true +python-versions = ">=3.8" +files = [ + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, +] + +[package.dependencies] +referencing = ">=0.31.0" + +[[package]] +name = "lark" +version = "1.1.9" +description = "a modern parsing library" +optional = true +python-versions = ">=3.6" +files = [ + {file = "lark-1.1.9-py3-none-any.whl", hash = "sha256:a0dd3a87289f8ccbb325901e4222e723e7d745dbfc1803eaf5f3d2ace19cf2db"}, + {file = "lark-1.1.9.tar.gz", hash = "sha256:15fa5236490824c2c4aba0e22d2d6d823575dcaf4cdd1848e34b6ad836240fba"}, +] + +[package.extras] +atomic-cache = ["atomicwrites"] +interegular = ["interegular (>=0.3.1,<0.4.0)"] +nearley = ["js2py"] +regex = ["regex"] + +[[package]] +name = "llvmlite" +version = "0.42.0" 
+description = "lightweight wrapper around basic LLVM functionality" +optional = true +python-versions = ">=3.9" +files = [ + {file = "llvmlite-0.42.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3366938e1bf63d26c34fbfb4c8e8d2ded57d11e0567d5bb243d89aab1eb56098"}, + {file = "llvmlite-0.42.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c35da49666a21185d21b551fc3caf46a935d54d66969d32d72af109b5e7d2b6f"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70f44ccc3c6220bd23e0ba698a63ec2a7d3205da0d848804807f37fc243e3f77"}, + {file = "llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f8d8717a9073b9e0246998de89929071d15b47f254c10eef2310b9aac033d"}, + {file = "llvmlite-0.42.0-cp310-cp310-win_amd64.whl", hash = "sha256:8d90edf400b4ceb3a0e776b6c6e4656d05c7187c439587e06f86afceb66d2be5"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ae511caed28beaf1252dbaf5f40e663f533b79ceb408c874c01754cafabb9cbf"}, + {file = "llvmlite-0.42.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81e674c2fe85576e6c4474e8c7e7aba7901ac0196e864fe7985492b737dbab65"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb3975787f13eb97629052edb5017f6c170eebc1c14a0433e8089e5db43bcce6"}, + {file = "llvmlite-0.42.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c5bece0cdf77f22379f19b1959ccd7aee518afa4afbd3656c6365865f84903f9"}, + {file = "llvmlite-0.42.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e0c4c11c8c2aa9b0701f91b799cb9134a6a6de51444eff5a9087fc7c1384275"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:08fa9ab02b0d0179c688a4216b8939138266519aaa0aa94f1195a8542faedb56"}, + {file = "llvmlite-0.42.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b2fce7d355068494d1e42202c7aff25d50c462584233013eb4470c33b995e3ee"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebe66a86dc44634b59a3bc860c7b20d26d9aaffcd30364ebe8ba79161a9121f4"}, + {file = "llvmlite-0.42.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d47494552559e00d81bfb836cf1c4d5a5062e54102cc5767d5aa1e77ccd2505c"}, + {file = "llvmlite-0.42.0-cp312-cp312-win_amd64.whl", hash = "sha256:05cb7e9b6ce69165ce4d1b994fbdedca0c62492e537b0cc86141b6e2c78d5888"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bdd3888544538a94d7ec99e7c62a0cdd8833609c85f0c23fcb6c5c591aec60ad"}, + {file = "llvmlite-0.42.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0936c2067a67fb8816c908d5457d63eba3e2b17e515c5fe00e5ee2bace06040"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a78ab89f1924fc11482209f6799a7a3fc74ddc80425a7a3e0e8174af0e9e2301"}, + {file = "llvmlite-0.42.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7599b65c7af7abbc978dbf345712c60fd596aa5670496561cc10e8a71cebfb2"}, + {file = "llvmlite-0.42.0-cp39-cp39-win_amd64.whl", hash = "sha256:43d65cc4e206c2e902c1004dd5418417c4efa6c1d04df05c6c5675a27e8ca90e"}, + {file = "llvmlite-0.42.0.tar.gz", hash = "sha256:f92b09243c0cc3f457da8b983f67bd8e1295d0f5b3746c7a1861d7a99403854a"}, +] + [[package]] name = "loguru" version = "0.6.0" @@ -1173,6 +1310,17 @@ files = [ [package.dependencies] dill = ">=0.3.7" +[[package]] +name = "nest-asyncio" +version = "1.6.0" +description = "Patch asyncio to allow 
nested event loops" +optional = true +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + [[package]] name = "networkx" version = "3.2.1" @@ -1191,6 +1339,40 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9. extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] +[[package]] +name = "numba" +version = "0.59.0" +description = "compiling Python code using LLVM" +optional = true +python-versions = ">=3.9" +files = [ + {file = "numba-0.59.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8d061d800473fb8fef76a455221f4ad649a53f5e0f96e3f6c8b8553ee6fa98fa"}, + {file = "numba-0.59.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c086a434e7d3891ce5dfd3d1e7ee8102ac1e733962098578b507864120559ceb"}, + {file = "numba-0.59.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9e20736bf62e61f8353fb71b0d3a1efba636c7a303d511600fc57648b55823ed"}, + {file = "numba-0.59.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e86e6786aec31d2002122199486e10bbc0dc40f78d76364cded375912b13614c"}, + {file = "numba-0.59.0-cp310-cp310-win_amd64.whl", hash = "sha256:0307ee91b24500bb7e64d8a109848baf3a3905df48ce142b8ac60aaa406a0400"}, + {file = "numba-0.59.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d540f69a8245fb714419c2209e9af6104e568eb97623adc8943642e61f5d6d8e"}, + {file = "numba-0.59.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1192d6b2906bf3ff72b1d97458724d98860ab86a91abdd4cfd9328432b661e31"}, + {file = "numba-0.59.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:90efb436d3413809fcd15298c6d395cb7d98184350472588356ccf19db9e37c8"}, + {file = "numba-0.59.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd3dac45e25d927dcb65d44fb3a973994f5add2b15add13337844afe669dd1ba"}, + {file = "numba-0.59.0-cp311-cp311-win_amd64.whl", hash = "sha256:753dc601a159861808cc3207bad5c17724d3b69552fd22768fddbf302a817a4c"}, + {file = "numba-0.59.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ce62bc0e6dd5264e7ff7f34f41786889fa81a6b860662f824aa7532537a7bee0"}, + {file = "numba-0.59.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8cbef55b73741b5eea2dbaf1b0590b14977ca95a13a07d200b794f8f6833a01c"}, + {file = "numba-0.59.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:70d26ba589f764be45ea8c272caa467dbe882b9676f6749fe6f42678091f5f21"}, + {file = "numba-0.59.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e125f7d69968118c28ec0eed9fbedd75440e64214b8d2eac033c22c04db48492"}, + {file = "numba-0.59.0-cp312-cp312-win_amd64.whl", hash = "sha256:4981659220b61a03c1e557654027d271f56f3087448967a55c79a0e5f926de62"}, + {file = "numba-0.59.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe4d7562d1eed754a7511ed7ba962067f198f86909741c5c6e18c4f1819b1f47"}, + {file = "numba-0.59.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6feb1504bb432280f900deaf4b1dadcee68812209500ed3f81c375cbceab24dc"}, + {file = "numba-0.59.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:944faad25ee23ea9dda582bfb0189fb9f4fc232359a80ab2a028b94c14ce2b1d"}, + {file = 
"numba-0.59.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5516a469514bfae52a9d7989db4940653a5cbfac106f44cb9c50133b7ad6224b"}, + {file = "numba-0.59.0-cp39-cp39-win_amd64.whl", hash = "sha256:32bd0a41525ec0b1b853da244808f4e5333867df3c43c30c33f89cf20b9c2b63"}, + {file = "numba-0.59.0.tar.gz", hash = "sha256:12b9b064a3e4ad00e2371fc5212ef0396c80f41caec9b5ec391c8b04b6eaf2a8"}, +] + +[package.dependencies] +llvmlite = "==0.42.*" +numpy = ">=1.22,<1.27" + [[package]] name = "numpy" version = "1.26.2" @@ -1549,6 +1731,39 @@ files = [ {file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"}, ] +[[package]] +name = "outlines" +version = "0.0.26" +description = "Probabilistic Generative Model Programming" +optional = true +python-versions = ">=3.8" +files = [ + {file = "outlines-0.0.26-py3-none-any.whl", hash = "sha256:3ba0e0c8f00001bde35baf22e53c820d44818fb9b40e5220153161fe455b007e"}, + {file = "outlines-0.0.26.tar.gz", hash = "sha256:210d8027286cf9c88626b4052d601ff02e40900392e0c0ec889321e734188a5b"}, +] + +[package.dependencies] +cloudpickle = "*" +diskcache = "*" +interegular = "*" +jinja2 = "*" +joblib = "*" +jsonschema = "*" +lark = "*" +nest-asyncio = "*" +numba = "*" +numpy = "*" +pydantic = ">=2.0" +referencing = "*" +requests = "*" +scipy = "*" +torch = ">=2.1" +transformers = "4.36.2" + +[package.extras] +serve = ["fastapi", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"] +test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] + [[package]] name = "packaging" version = "23.2" @@ -1844,6 +2059,116 @@ files = [ {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, ] +[[package]] +name = "pydantic" +version = "2.6.1" +description = "Data validation using Python type hints" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pydantic-2.6.1-py3-none-any.whl", hash = "sha256:0b6a909df3192245cb736509a92ff69e4fef76116feffec68e93a567347bae6f"}, + {file = "pydantic-2.6.1.tar.gz", hash = "sha256:4fd5c182a2488dc63e6d32737ff19937888001e2a6d86e94b3f233104a5d1fa9"}, +] + +[package.dependencies] +annotated-types = ">=0.4.0" +pydantic-core = "2.16.2" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.16.2" +description = "" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.16.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3fab4e75b8c525a4776e7630b9ee48aea50107fea6ca9f593c98da3f4d11bf7c"}, + {file = "pydantic_core-2.16.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8bde5b48c65b8e807409e6f20baee5d2cd880e0fad00b1a811ebc43e39a00ab2"}, + {file = "pydantic_core-2.16.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2924b89b16420712e9bb8192396026a8fbd6d8726224f918353ac19c4c043d2a"}, + {file = "pydantic_core-2.16.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:16aa02e7a0f539098e215fc193c8926c897175d64c7926d00a36188917717a05"}, + {file = "pydantic_core-2.16.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:936a787f83db1f2115ee829dd615c4f684ee48ac4de5779ab4300994d8af325b"}, + {file = 
"pydantic_core-2.16.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:459d6be6134ce3b38e0ef76f8a672924460c455d45f1ad8fdade36796df1ddc8"}, + {file = "pydantic_core-2.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9ee4febb249c591d07b2d4dd36ebcad0ccd128962aaa1801508320896575ef"}, + {file = "pydantic_core-2.16.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40a0bd0bed96dae5712dab2aba7d334a6c67cbcac2ddfca7dbcc4a8176445990"}, + {file = "pydantic_core-2.16.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:870dbfa94de9b8866b37b867a2cb37a60c401d9deb4a9ea392abf11a1f98037b"}, + {file = "pydantic_core-2.16.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:308974fdf98046db28440eb3377abba274808bf66262e042c412eb2adf852731"}, + {file = "pydantic_core-2.16.2-cp310-none-win32.whl", hash = "sha256:a477932664d9611d7a0816cc3c0eb1f8856f8a42435488280dfbf4395e141485"}, + {file = "pydantic_core-2.16.2-cp310-none-win_amd64.whl", hash = "sha256:8f9142a6ed83d90c94a3efd7af8873bf7cefed2d3d44387bf848888482e2d25f"}, + {file = "pydantic_core-2.16.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:406fac1d09edc613020ce9cf3f2ccf1a1b2f57ab00552b4c18e3d5276c67eb11"}, + {file = "pydantic_core-2.16.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce232a6170dd6532096cadbf6185271e4e8c70fc9217ebe105923ac105da9978"}, + {file = "pydantic_core-2.16.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a90fec23b4b05a09ad988e7a4f4e081711a90eb2a55b9c984d8b74597599180f"}, + {file = "pydantic_core-2.16.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8aafeedb6597a163a9c9727d8a8bd363a93277701b7bfd2749fbefee2396469e"}, + {file = "pydantic_core-2.16.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9957433c3a1b67bdd4c63717eaf174ebb749510d5ea612cd4e83f2d9142f3fc8"}, + {file = "pydantic_core-2.16.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0d7a9165167269758145756db43a133608a531b1e5bb6a626b9ee24bc38a8f7"}, + {file = "pydantic_core-2.16.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dffaf740fe2e147fedcb6b561353a16243e654f7fe8e701b1b9db148242e1272"}, + {file = "pydantic_core-2.16.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8ed79883b4328b7f0bd142733d99c8e6b22703e908ec63d930b06be3a0e7113"}, + {file = "pydantic_core-2.16.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:cf903310a34e14651c9de056fcc12ce090560864d5a2bb0174b971685684e1d8"}, + {file = "pydantic_core-2.16.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:46b0d5520dbcafea9a8645a8164658777686c5c524d381d983317d29687cce97"}, + {file = "pydantic_core-2.16.2-cp311-none-win32.whl", hash = "sha256:70651ff6e663428cea902dac297066d5c6e5423fda345a4ca62430575364d62b"}, + {file = "pydantic_core-2.16.2-cp311-none-win_amd64.whl", hash = "sha256:98dc6f4f2095fc7ad277782a7c2c88296badcad92316b5a6e530930b1d475ebc"}, + {file = "pydantic_core-2.16.2-cp311-none-win_arm64.whl", hash = "sha256:ef6113cd31411eaf9b39fc5a8848e71c72656fd418882488598758b2c8c6dfa0"}, + {file = "pydantic_core-2.16.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:88646cae28eb1dd5cd1e09605680c2b043b64d7481cdad7f5003ebef401a3039"}, + {file = "pydantic_core-2.16.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b883af50eaa6bb3299780651e5be921e88050ccf00e3e583b1e92020333304b"}, + {file = 
"pydantic_core-2.16.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bf26c2e2ea59d32807081ad51968133af3025c4ba5753e6a794683d2c91bf6e"}, + {file = "pydantic_core-2.16.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99af961d72ac731aae2a1b55ccbdae0733d816f8bfb97b41909e143de735f522"}, + {file = "pydantic_core-2.16.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02906e7306cb8c5901a1feb61f9ab5e5c690dbbeaa04d84c1b9ae2a01ebe9379"}, + {file = "pydantic_core-2.16.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5362d099c244a2d2f9659fb3c9db7c735f0004765bbe06b99be69fbd87c3f15"}, + {file = "pydantic_core-2.16.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ac426704840877a285d03a445e162eb258924f014e2f074e209d9b4ff7bf380"}, + {file = "pydantic_core-2.16.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b94cbda27267423411c928208e89adddf2ea5dd5f74b9528513f0358bba019cb"}, + {file = "pydantic_core-2.16.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6db58c22ac6c81aeac33912fb1af0e930bc9774166cdd56eade913d5f2fff35e"}, + {file = "pydantic_core-2.16.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:396fdf88b1b503c9c59c84a08b6833ec0c3b5ad1a83230252a9e17b7dfb4cffc"}, + {file = "pydantic_core-2.16.2-cp312-none-win32.whl", hash = "sha256:7c31669e0c8cc68400ef0c730c3a1e11317ba76b892deeefaf52dcb41d56ed5d"}, + {file = "pydantic_core-2.16.2-cp312-none-win_amd64.whl", hash = "sha256:a3b7352b48fbc8b446b75f3069124e87f599d25afb8baa96a550256c031bb890"}, + {file = "pydantic_core-2.16.2-cp312-none-win_arm64.whl", hash = "sha256:a9e523474998fb33f7c1a4d55f5504c908d57add624599e095c20fa575b8d943"}, + {file = "pydantic_core-2.16.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:ae34418b6b389d601b31153b84dce480351a352e0bb763684a1b993d6be30f17"}, + {file = "pydantic_core-2.16.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:732bd062c9e5d9582a30e8751461c1917dd1ccbdd6cafb032f02c86b20d2e7ec"}, + {file = "pydantic_core-2.16.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b52776a2e3230f4854907a1e0946eec04d41b1fc64069ee774876bbe0eab55"}, + {file = "pydantic_core-2.16.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ef551c053692b1e39e3f7950ce2296536728871110e7d75c4e7753fb30ca87f4"}, + {file = "pydantic_core-2.16.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ebb892ed8599b23fa8f1799e13a12c87a97a6c9d0f497525ce9858564c4575a4"}, + {file = "pydantic_core-2.16.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa6c8c582036275997a733427b88031a32ffa5dfc3124dc25a730658c47a572f"}, + {file = "pydantic_core-2.16.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4ba0884a91f1aecce75202473ab138724aa4fb26d7707f2e1fa6c3e68c84fbf"}, + {file = "pydantic_core-2.16.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7924e54f7ce5d253d6160090ddc6df25ed2feea25bfb3339b424a9dd591688bc"}, + {file = "pydantic_core-2.16.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69a7b96b59322a81c2203be537957313b07dd333105b73db0b69212c7d867b4b"}, + {file = "pydantic_core-2.16.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7e6231aa5bdacda78e96ad7b07d0c312f34ba35d717115f4b4bff6cb87224f0f"}, + {file = "pydantic_core-2.16.2-cp38-none-win32.whl", hash = "sha256:41dac3b9fce187a25c6253ec79a3f9e2a7e761eb08690e90415069ea4a68ff7a"}, + 
{file = "pydantic_core-2.16.2-cp38-none-win_amd64.whl", hash = "sha256:f685dbc1fdadb1dcd5b5e51e0a378d4685a891b2ddaf8e2bba89bd3a7144e44a"}, + {file = "pydantic_core-2.16.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:55749f745ebf154c0d63d46c8c58594d8894b161928aa41adbb0709c1fe78b77"}, + {file = "pydantic_core-2.16.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b30b0dd58a4509c3bd7eefddf6338565c4905406aee0c6e4a5293841411a1286"}, + {file = "pydantic_core-2.16.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18de31781cdc7e7b28678df7c2d7882f9692ad060bc6ee3c94eb15a5d733f8f7"}, + {file = "pydantic_core-2.16.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5864b0242f74b9dd0b78fd39db1768bc3f00d1ffc14e596fd3e3f2ce43436a33"}, + {file = "pydantic_core-2.16.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8f9186ca45aee030dc8234118b9c0784ad91a0bb27fc4e7d9d6608a5e3d386c"}, + {file = "pydantic_core-2.16.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc6f6c9be0ab6da37bc77c2dda5f14b1d532d5dbef00311ee6e13357a418e646"}, + {file = "pydantic_core-2.16.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa057095f621dad24a1e906747179a69780ef45cc8f69e97463692adbcdae878"}, + {file = "pydantic_core-2.16.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ad84731a26bcfb299f9eab56c7932d46f9cad51c52768cace09e92a19e4cf55"}, + {file = "pydantic_core-2.16.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3b052c753c4babf2d1edc034c97851f867c87d6f3ea63a12e2700f159f5c41c3"}, + {file = "pydantic_core-2.16.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e0f686549e32ccdb02ae6f25eee40cc33900910085de6aa3790effd391ae10c2"}, + {file = "pydantic_core-2.16.2-cp39-none-win32.whl", hash = "sha256:7afb844041e707ac9ad9acad2188a90bffce2c770e6dc2318be0c9916aef1469"}, + {file = "pydantic_core-2.16.2-cp39-none-win_amd64.whl", hash = "sha256:9da90d393a8227d717c19f5397688a38635afec89f2e2d7af0df037f3249c39a"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5f60f920691a620b03082692c378661947d09415743e437a7478c309eb0e4f82"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:47924039e785a04d4a4fa49455e51b4eb3422d6eaacfde9fc9abf8fdef164e8a"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6294e76b0380bb7a61eb8a39273c40b20beb35e8c87ee101062834ced19c545"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe56851c3f1d6f5384b3051c536cc81b3a93a73faf931f404fef95217cf1e10d"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9d776d30cde7e541b8180103c3f294ef7c1862fd45d81738d156d00551005784"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:72f7919af5de5ecfaf1eba47bf9a5d8aa089a3340277276e5636d16ee97614d7"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:4bfcbde6e06c56b30668a0c872d75a7ef3025dc3c1823a13cf29a0e9b33f67e8"}, + {file = "pydantic_core-2.16.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ff7c97eb7a29aba230389a2661edf2e9e06ce616c7e35aa764879b6894a44b25"}, + {file = "pydantic_core-2.16.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9b5f13857da99325dcabe1cc4e9e6a3d7b2e2c726248ba5dd4be3e8e4a0b6d0e"}, + 
{file = "pydantic_core-2.16.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a7e41e3ada4cca5f22b478c08e973c930e5e6c7ba3588fb8e35f2398cdcc1545"}, + {file = "pydantic_core-2.16.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60eb8ceaa40a41540b9acae6ae7c1f0a67d233c40dc4359c256ad2ad85bdf5e5"}, + {file = "pydantic_core-2.16.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7beec26729d496a12fd23cf8da9944ee338c8b8a17035a560b585c36fe81af20"}, + {file = "pydantic_core-2.16.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:22c5f022799f3cd6741e24f0443ead92ef42be93ffda0d29b2597208c94c3753"}, + {file = "pydantic_core-2.16.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:eca58e319f4fd6df004762419612122b2c7e7d95ffafc37e890252f869f3fb2a"}, + {file = "pydantic_core-2.16.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ed957db4c33bc99895f3a1672eca7e80e8cda8bd1e29a80536b4ec2153fa9804"}, + {file = "pydantic_core-2.16.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:459c0d338cc55d099798618f714b21b7ece17eb1a87879f2da20a3ff4c7628e2"}, + {file = "pydantic_core-2.16.2.tar.gz", hash = "sha256:0ba503850d8b8dcc18391f10de896ae51d37fe5fe43dbfb6a35c5c5cad271a06"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + [[package]] name = "pytest" version = "7.4.3" @@ -1951,6 +2276,21 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "referencing" +version = "0.33.0" +description = "JSON Referencing + Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "referencing-0.33.0-py3-none-any.whl", hash = "sha256:39240f2ecc770258f28b642dd47fd74bc8b02484de54e1882b74b35ebd779bd5"}, + {file = "referencing-0.33.0.tar.gz", hash = "sha256:c775fedf74bc0f9189c2a3be1c12fd03e8c23f4d371dce795df44e06c5b412f7"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + [[package]] name = "regex" version = "2023.10.3" @@ -2069,6 +2409,114 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rpds-py" +version = "0.17.1" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = true +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.17.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:4128980a14ed805e1b91a7ed551250282a8ddf8201a4e9f8f5b7e6225f54170d"}, + {file = "rpds_py-0.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ff1dcb8e8bc2261a088821b2595ef031c91d499a0c1b031c152d43fe0a6ecec8"}, + {file = "rpds_py-0.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d65e6b4f1443048eb7e833c2accb4fa7ee67cc7d54f31b4f0555b474758bee55"}, + {file = "rpds_py-0.17.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a71169d505af63bb4d20d23a8fbd4c6ce272e7bce6cc31f617152aa784436f29"}, + {file = "rpds_py-0.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:436474f17733c7dca0fbf096d36ae65277e8645039df12a0fa52445ca494729d"}, + {file = "rpds_py-0.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10162fe3f5f47c37ebf6d8ff5a2368508fe22007e3077bf25b9c7d803454d921"}, + {file = "rpds_py-0.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:720215373a280f78a1814becb1312d4e4d1077b1202a56d2b0815e95ccb99ce9"}, + {file = 
"rpds_py-0.17.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:70fcc6c2906cfa5c6a552ba7ae2ce64b6c32f437d8f3f8eea49925b278a61453"}, + {file = "rpds_py-0.17.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:91e5a8200e65aaac342a791272c564dffcf1281abd635d304d6c4e6b495f29dc"}, + {file = "rpds_py-0.17.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:99f567dae93e10be2daaa896e07513dd4bf9c2ecf0576e0533ac36ba3b1d5394"}, + {file = "rpds_py-0.17.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:24e4900a6643f87058a27320f81336d527ccfe503984528edde4bb660c8c8d59"}, + {file = "rpds_py-0.17.1-cp310-none-win32.whl", hash = "sha256:0bfb09bf41fe7c51413f563373e5f537eaa653d7adc4830399d4e9bdc199959d"}, + {file = "rpds_py-0.17.1-cp310-none-win_amd64.whl", hash = "sha256:20de7b7179e2031a04042e85dc463a93a82bc177eeba5ddd13ff746325558aa6"}, + {file = "rpds_py-0.17.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:65dcf105c1943cba45d19207ef51b8bc46d232a381e94dd38719d52d3980015b"}, + {file = "rpds_py-0.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:01f58a7306b64e0a4fe042047dd2b7d411ee82e54240284bab63e325762c1147"}, + {file = "rpds_py-0.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:071bc28c589b86bc6351a339114fb7a029f5cddbaca34103aa573eba7b482382"}, + {file = "rpds_py-0.17.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ae35e8e6801c5ab071b992cb2da958eee76340e6926ec693b5ff7d6381441745"}, + {file = "rpds_py-0.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149c5cd24f729e3567b56e1795f74577aa3126c14c11e457bec1b1c90d212e38"}, + {file = "rpds_py-0.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e796051f2070f47230c745d0a77a91088fbee2cc0502e9b796b9c6471983718c"}, + {file = "rpds_py-0.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60e820ee1004327609b28db8307acc27f5f2e9a0b185b2064c5f23e815f248f8"}, + {file = "rpds_py-0.17.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1957a2ab607f9added64478a6982742eb29f109d89d065fa44e01691a20fc20a"}, + {file = "rpds_py-0.17.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8587fd64c2a91c33cdc39d0cebdaf30e79491cc029a37fcd458ba863f8815383"}, + {file = "rpds_py-0.17.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4dc889a9d8a34758d0fcc9ac86adb97bab3fb7f0c4d29794357eb147536483fd"}, + {file = "rpds_py-0.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2953937f83820376b5979318840f3ee47477d94c17b940fe31d9458d79ae7eea"}, + {file = "rpds_py-0.17.1-cp311-none-win32.whl", hash = "sha256:1bfcad3109c1e5ba3cbe2f421614e70439f72897515a96c462ea657261b96518"}, + {file = "rpds_py-0.17.1-cp311-none-win_amd64.whl", hash = "sha256:99da0a4686ada4ed0f778120a0ea8d066de1a0a92ab0d13ae68492a437db78bf"}, + {file = "rpds_py-0.17.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1dc29db3900cb1bb40353772417800f29c3d078dbc8024fd64655a04ee3c4bdf"}, + {file = "rpds_py-0.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:82ada4a8ed9e82e443fcef87e22a3eed3654dd3adf6e3b3a0deb70f03e86142a"}, + {file = "rpds_py-0.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d36b2b59e8cc6e576f8f7b671e32f2ff43153f0ad6d0201250a7c07f25d570e"}, + {file = "rpds_py-0.17.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3677fcca7fb728c86a78660c7fb1b07b69b281964673f486ae72860e13f512ad"}, + {file = 
"rpds_py-0.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:516fb8c77805159e97a689e2f1c80655c7658f5af601c34ffdb916605598cda2"}, + {file = "rpds_py-0.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df3b6f45ba4515632c5064e35ca7f31d51d13d1479673185ba8f9fefbbed58b9"}, + {file = "rpds_py-0.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a967dd6afda7715d911c25a6ba1517975acd8d1092b2f326718725461a3d33f9"}, + {file = "rpds_py-0.17.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dbbb95e6fc91ea3102505d111b327004d1c4ce98d56a4a02e82cd451f9f57140"}, + {file = "rpds_py-0.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:02866e060219514940342a1f84303a1ef7a1dad0ac311792fbbe19b521b489d2"}, + {file = "rpds_py-0.17.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:2528ff96d09f12e638695f3a2e0c609c7b84c6df7c5ae9bfeb9252b6fa686253"}, + {file = "rpds_py-0.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bd345a13ce06e94c753dab52f8e71e5252aec1e4f8022d24d56decd31e1b9b23"}, + {file = "rpds_py-0.17.1-cp312-none-win32.whl", hash = "sha256:2a792b2e1d3038daa83fa474d559acfd6dc1e3650ee93b2662ddc17dbff20ad1"}, + {file = "rpds_py-0.17.1-cp312-none-win_amd64.whl", hash = "sha256:292f7344a3301802e7c25c53792fae7d1593cb0e50964e7bcdcc5cf533d634e3"}, + {file = "rpds_py-0.17.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:8ffe53e1d8ef2520ebcf0c9fec15bb721da59e8ef283b6ff3079613b1e30513d"}, + {file = "rpds_py-0.17.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4341bd7579611cf50e7b20bb8c2e23512a3dc79de987a1f411cb458ab670eb90"}, + {file = "rpds_py-0.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f4eb548daf4836e3b2c662033bfbfc551db58d30fd8fe660314f86bf8510b93"}, + {file = "rpds_py-0.17.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b686f25377f9c006acbac63f61614416a6317133ab7fafe5de5f7dc8a06d42eb"}, + {file = "rpds_py-0.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4e21b76075c01d65d0f0f34302b5a7457d95721d5e0667aea65e5bb3ab415c25"}, + {file = "rpds_py-0.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b86b21b348f7e5485fae740d845c65a880f5d1eda1e063bc59bef92d1f7d0c55"}, + {file = "rpds_py-0.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f175e95a197f6a4059b50757a3dca33b32b61691bdbd22c29e8a8d21d3914cae"}, + {file = "rpds_py-0.17.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1701fc54460ae2e5efc1dd6350eafd7a760f516df8dbe51d4a1c79d69472fbd4"}, + {file = "rpds_py-0.17.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:9051e3d2af8f55b42061603e29e744724cb5f65b128a491446cc029b3e2ea896"}, + {file = "rpds_py-0.17.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:7450dbd659fed6dd41d1a7d47ed767e893ba402af8ae664c157c255ec6067fde"}, + {file = "rpds_py-0.17.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:5a024fa96d541fd7edaa0e9d904601c6445e95a729a2900c5aec6555fe921ed6"}, + {file = "rpds_py-0.17.1-cp38-none-win32.whl", hash = "sha256:da1ead63368c04a9bded7904757dfcae01eba0e0f9bc41d3d7f57ebf1c04015a"}, + {file = "rpds_py-0.17.1-cp38-none-win_amd64.whl", hash = "sha256:841320e1841bb53fada91c9725e766bb25009cfd4144e92298db296fb6c894fb"}, + {file = "rpds_py-0.17.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:f6c43b6f97209e370124baf2bf40bb1e8edc25311a158867eb1c3a5d449ebc7a"}, + {file = 
"rpds_py-0.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7d63ec01fe7c76c2dbb7e972fece45acbb8836e72682bde138e7e039906e2c"}, + {file = "rpds_py-0.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81038ff87a4e04c22e1d81f947c6ac46f122e0c80460b9006e6517c4d842a6ec"}, + {file = "rpds_py-0.17.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:810685321f4a304b2b55577c915bece4c4a06dfe38f6e62d9cc1d6ca8ee86b99"}, + {file = "rpds_py-0.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:25f071737dae674ca8937a73d0f43f5a52e92c2d178330b4c0bb6ab05586ffa6"}, + {file = "rpds_py-0.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa5bfb13f1e89151ade0eb812f7b0d7a4d643406caaad65ce1cbabe0a66d695f"}, + {file = "rpds_py-0.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dfe07308b311a8293a0d5ef4e61411c5c20f682db6b5e73de6c7c8824272c256"}, + {file = "rpds_py-0.17.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a000133a90eea274a6f28adc3084643263b1e7c1a5a66eb0a0a7a36aa757ed74"}, + {file = "rpds_py-0.17.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5d0e8a6434a3fbf77d11448c9c25b2f25244226cfbec1a5159947cac5b8c5fa4"}, + {file = "rpds_py-0.17.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:efa767c220d94aa4ac3a6dd3aeb986e9f229eaf5bce92d8b1b3018d06bed3772"}, + {file = "rpds_py-0.17.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:dbc56680ecf585a384fbd93cd42bc82668b77cb525343170a2d86dafaed2a84b"}, + {file = "rpds_py-0.17.1-cp39-none-win32.whl", hash = "sha256:270987bc22e7e5a962b1094953ae901395e8c1e1e83ad016c5cfcfff75a15a3f"}, + {file = "rpds_py-0.17.1-cp39-none-win_amd64.whl", hash = "sha256:2a7b2f2f56a16a6d62e55354dd329d929560442bd92e87397b7a9586a32e3e76"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a3264e3e858de4fc601741498215835ff324ff2482fd4e4af61b46512dd7fc83"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:f2f3b28b40fddcb6c1f1f6c88c6f3769cd933fa493ceb79da45968a21dccc920"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9584f8f52010295a4a417221861df9bea4c72d9632562b6e59b3c7b87a1522b7"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c64602e8be701c6cfe42064b71c84ce62ce66ddc6422c15463fd8127db3d8066"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:060f412230d5f19fc8c8b75f315931b408d8ebf56aec33ef4168d1b9e54200b1"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9412abdf0ba70faa6e2ee6c0cc62a8defb772e78860cef419865917d86c7342"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9737bdaa0ad33d34c0efc718741abaafce62fadae72c8b251df9b0c823c63b22"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9f0e4dc0f17dcea4ab9d13ac5c666b6b5337042b4d8f27e01b70fae41dd65c57"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1db228102ab9d1ff4c64148c96320d0be7044fa28bd865a9ce628ce98da5973d"}, + {file = "rpds_py-0.17.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:d8bbd8e56f3ba25a7d0cf980fc42b34028848a53a0e36c9918550e0280b9d0b6"}, + {file = 
"rpds_py-0.17.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:be22ae34d68544df293152b7e50895ba70d2a833ad9566932d750d3625918b82"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bf046179d011e6114daf12a534d874958b039342b347348a78b7cdf0dd9d6041"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a746a6d49665058a5896000e8d9d2f1a6acba8a03b389c1e4c06e11e0b7f40d"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0b8bf5b8db49d8fd40f54772a1dcf262e8be0ad2ab0206b5a2ec109c176c0a4"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f7f4cb1f173385e8a39c29510dd11a78bf44e360fb75610594973f5ea141028b"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7fbd70cb8b54fe745301921b0816c08b6d917593429dfc437fd024b5ba713c58"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bdf1303df671179eaf2cb41e8515a07fc78d9d00f111eadbe3e14262f59c3d0"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fad059a4bd14c45776600d223ec194e77db6c20255578bb5bcdd7c18fd169361"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3664d126d3388a887db44c2e293f87d500c4184ec43d5d14d2d2babdb4c64cad"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:698ea95a60c8b16b58be9d854c9f993c639f5c214cf9ba782eca53a8789d6b19"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:c3d2010656999b63e628a3c694f23020322b4178c450dc478558a2b6ef3cb9bb"}, + {file = "rpds_py-0.17.1-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:938eab7323a736533f015e6069a7d53ef2dcc841e4e533b782c2bfb9fb12d84b"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1e626b365293a2142a62b9a614e1f8e331b28f3ca57b9f05ebbf4cf2a0f0bdc5"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:380e0df2e9d5d5d339803cfc6d183a5442ad7ab3c63c2a0982e8c824566c5ccc"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b760a56e080a826c2e5af09002c1a037382ed21d03134eb6294812dda268c811"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5576ee2f3a309d2bb403ec292d5958ce03953b0e57a11d224c1f134feaf8c40f"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f3c3461ebb4c4f1bbc70b15d20b565759f97a5aaf13af811fcefc892e9197ba"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:637b802f3f069a64436d432117a7e58fab414b4e27a7e81049817ae94de45d8d"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffee088ea9b593cc6160518ba9bd319b5475e5f3e578e4552d63818773c6f56a"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3ac732390d529d8469b831949c78085b034bff67f584559340008d0f6041a049"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:93432e747fb07fa567ad9cc7aaadd6e29710e515aabf939dfbed8046041346c6"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = 
"sha256:7b7d9ca34542099b4e185b3c2a2b2eda2e318a7dbde0b0d83357a6d4421b5296"}, + {file = "rpds_py-0.17.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:0387ce69ba06e43df54e43968090f3626e231e4bc9150e4c3246947567695f68"}, + {file = "rpds_py-0.17.1.tar.gz", hash = "sha256:0210b2668f24c078307260bf88bdac9d6f1093635df5123789bfee4d8d7fc8e7"}, +] + [[package]] name = "s3transfer" version = "0.9.0" @@ -3122,6 +3570,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] accelerate = ["accelerate"] bnb = ["bitsandbytes"] +outlines = ["outlines"] peft = ["peft"] quantize = ["accelerate", "datasets", "hqq", "texttable"] torch = ["torch"] @@ -3129,4 +3578,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4520e26f84b492bd349447b8a8f1f090e55b321ec695eeb67de14874bb12f8a5" +content-hash = "e289f6159f0f16a53b558dd55d3262b8a354aafc338ed402f9ef78a31a20ec04" diff --git a/server/pyproject.toml b/server/pyproject.toml index d7e4bb722..d6ed4721b 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -38,6 +38,7 @@ boto3 = "^1.28.34" urllib3 = "<=1.26.18" hqq = { version = "^0.1.2", optional = true } stanford-stk = { version = "^0.7.0", markers = "sys_platform == 'linux'" } +outlines = { version = "^0.0.26", optional = true } [tool.poetry.extras] torch = ["torch"] @@ -45,6 +46,7 @@ accelerate = ["accelerate"] bnb = ["bitsandbytes"] peft = ["peft"] quantize = ["texttable", "datasets", "accelerate", "hqq"] +outlines = ["outlines"] [tool.poetry.group.dev.dependencies] grpcio-tools = "^1.51.1" From 85ff0a5897c09f8e59a66906d868cc3f8f1b40c1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 5 Feb 2024 15:25:29 -0800 Subject: [PATCH 04/25] Updated Python SDK --- clients/python/lorax/__init__.py | 2 +- clients/python/lorax/client.py | 17 ++++++++++++++++- clients/python/pyproject.toml | 2 +- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/clients/python/lorax/__init__.py b/clients/python/lorax/__init__.py index b95c7ffbe..383c20c14 100644 --- a/clients/python/lorax/__init__.py +++ b/clients/python/lorax/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.3.0" +__version__ = "0.3.1" from lorax.client import Client, AsyncClient, MergedAdapters diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 89fb956a0..b5aa2776c 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -3,7 +3,7 @@ from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError -from typing import Dict, Optional, List, AsyncIterator, Iterator +from typing import Any, Dict, Optional, List, AsyncIterator, Iterator from lorax.types import ( StreamResponse, @@ -79,6 +79,7 @@ def generate( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + schema: Optional[Dict[str, Any]] = None, decoder_input_details: bool = False, ) -> Response: """ @@ -124,6 +125,8 @@ def generate( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + schema (`Optional[Dict[str, Any]]`): + Optional JSON schema to validate the response decoder_input_details (`bool`): Return the decoder input token logprobs and ids @@ -150,6 +153,7 @@ def generate( truncate=truncate, typical_p=typical_p, watermark=watermark, + schema=json.dumps(schema) if schema is not None else None, decoder_input_details=decoder_input_details, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -185,6 +189,7 @@ def generate_stream( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + schema: Optional[Dict[str, Any]] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -227,6 +232,8 @@ def generate_stream( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + schema (`Optional[Dict[str, Any]]`): + Optional JSON schema to validate the response Returns: Iterator[StreamResponse]: stream of generated tokens @@ -252,6 +259,7 @@ def generate_stream( truncate=truncate, typical_p=typical_p, watermark=watermark, + schema=json.dumps(schema) if schema is not None else None, ) request = Request(inputs=prompt, stream=True, parameters=parameters) @@ -353,6 +361,7 @@ async def generate( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + schema: Optional[Dict[str, Any]] = None, decoder_input_details: bool = False, ) -> Response: """ @@ -398,6 +407,8 @@ async def generate( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + schema (`Optional[Dict[str, Any]]`): + Optional JSON schema to validate the response decoder_input_details (`bool`): Return the decoder input token logprobs and ids @@ -425,6 +436,7 @@ async def generate( truncate=truncate, typical_p=typical_p, watermark=watermark, + schema=json.dumps(schema) if schema is not None else None, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -457,6 +469,7 @@ async def generate_stream( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + schema: Optional[Dict[str, Any]] = None, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the 
following stream of tokens asynchronously @@ -499,6 +512,8 @@ async def generate_stream( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + schema (`Optional[Dict[str, Any]]`): + Optional JSON schema to validate the response Returns: AsyncIterator[StreamResponse]: stream of generated tokens diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 8a5dadb4c..81b16d8b4 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -3,7 +3,7 @@ name = "lorax-client" packages = [ {include = "lorax"} ] -version = "0.3.0" +version = "0.3.1" description = "LoRAX Python Client" license = "Apache-2.0" authors = ["Travis Addair ", "Olivier Dehaene "] From 395607079ae7b8dcca0932d85a349209683859be Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 5 Feb 2024 15:29:01 -0800 Subject: [PATCH 05/25] OpenAPI --- docs/reference/openapi.json | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json index 3417b409b..574ad4d8c 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -829,6 +829,12 @@ "default": "false", "example": true }, + "schema": { + "type": "string", + "default": "null", + "example": "{\"type\": \"string\", \"title\": \"response\"}", + "nullable": true + }, "adapter_id": { "type": "string", "nullable": true From 40e62b54158811c195db73c0cb9863cd2afabe10 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Thu, 8 Feb 2024 11:48:54 -0600 Subject: [PATCH 06/25] try new params --- server/lorax_server/models/causal_lm.py | 2 +- server/lorax_server/models/flash_causal_lm.py | 125 +++++++++--------- server/lorax_server/models/flash_mistral.py | 8 +- server/lorax_server/models/flash_mixtral.py | 8 +- server/lorax_server/models/seq2seq_lm.py | 8 +- server/lorax_server/models/types.py | 11 +- server/lorax_server/server.py | 2 +- server/lorax_server/utils/logits_process.py | 35 +++-- server/lorax_server/utils/tokens.py | 4 +- server/tests/models/test_bloom.py | 4 +- server/tests/models/test_causal_lm.py | 4 +- server/tests/models/test_seq2seq_lm.py | 4 +- 12 files changed, 127 insertions(+), 88 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 927e04c50..fd3e4c062 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -283,7 +283,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": + def concatenate(cls, batches: List["CausalLMBatch"], tokenizers) -> "CausalLMBatch": # Used for padding total_batch_size = 0 max_input_length = 0 diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index e8b4601e6..665594ae2 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1,39 +1,34 @@ -from collections import defaultdict -import json -import math import itertools -from loguru import logger -import torch -import torch.distributed +import math +from dataclasses import dataclass +from typing import Optional, Tuple, List, Type, Union, Dict import numpy as np - -from dataclasses import dataclass +import torch +import torch.distributed +from loguru import 
logger from opentelemetry import trace -from peft import LoraConfig from transformers import PreTrainedTokenizerBase -from typing import Optional, Set, Tuple, List, Type, Union, Dict from lorax_server.models import Model +from lorax_server.models.cache_manager import ( + get_cache_manager, + set_cache_manager, + BLOCK_SIZE, +) from lorax_server.models.types import ( Batch, PrefillTokens, Generation, GeneratedText, ) -from lorax_server.models.cache_manager import ( - get_cache_manager, - set_cache_manager, - BLOCK_SIZE, -) from lorax_server.pb import generate_pb2 from lorax_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser -from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map +from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID from lorax_server.utils.dist import MEMORY_FRACTION -from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights -from lorax_server.utils.segments import SegmentConcatBuilder, find_segments -from lorax_server.utils.weights import shard_on_dim from lorax_server.utils.graph import GraphCache +from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata +from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -150,7 +145,7 @@ def from_pb( next_token_chooser_parameters = [] stopping_criterias = [] - + adapter_indices_list = [] adapter_set = set() @@ -166,12 +161,12 @@ def from_pb( # Parse batch for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) + zip(pb.requests, batch_tokenized_inputs) ): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - tokenized_input = tokenized_input[-r.truncate :] + tokenized_input = tokenized_input[-r.truncate:] input_length = len(tokenized_input) input_lengths.append(input_length) @@ -401,7 +396,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": adapter_set.add(self.requests[idx].adapter_index) remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) request_block_table = self.block_tables[idx] @@ -414,10 +409,10 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": # Set slice slot_filtering_indices[ - self.start_slots[idx] : self.start_slots[idx] - + request_input_length - + remaining_tokens - - 1 + self.start_slots[idx]: self.start_slots[idx] + + request_input_length + + remaining_tokens + - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 @@ -490,7 +485,12 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + def concatenate( + cls, + batches: List["FlashCausalLMBatch"], + tokenizer: PreTrainedTokenizerBase, + tokenizers: TokenizerManager + ) -> "FlashCausalLMBatch": # Batch attributes requests = [] requests_idx_mapping = {} @@ -534,7 +534,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch ) total_indices_size = sum(b.adapter_meta.adapter_indices.shape[0] for b in batches) - + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(total_indices_size) adapter_set = set() adapter_segment_builder = SegmentConcatBuilder() @@ -588,11 +588,11 @@ def concatenate(cls, batches: 
List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch adapter_segment_builder.concat(batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices) all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] + start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] block_tables_tensor[ - start_index:end_index, : batch.block_tables_tensor.shape[1] + start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] start_slots.append(batch.start_slots + cumulative_slots) @@ -613,11 +613,18 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch start_slots = torch.concat(start_slots) + request_tokenizers = [ + tokenizers.get_tokenizer(r.adapter_index, tokenizer) + for r in requests + ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, + tokenizers=request_tokenizers, dtype=batches[0].next_token_chooser.dtype, device=batches[0].next_token_chooser.device, ) + # next_token_chooser.schema_processor = HeterogeneousSchemaLogitsProcessor.concatenate( + # [b.next_token_chooser.schema_processor for b in batches]) adapter_segments, adapter_segment_indices = adapter_segment_builder.build() @@ -743,7 +750,7 @@ def warmup(self, batch: FlashCausalLMBatch): if self.compile: if self.world_size > 1: raise ValueError("Cannot enable `--compile` when sharding across multiple GPUs") - + # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache. # Needs to be estimated here rather than fully initialized as the graph cache relies on the # cache manager being set. @@ -768,9 +775,9 @@ def warmup(self, batch: FlashCausalLMBatch): ) num_blocks = ( - int(free_memory // total_cache_size) - # Add batch.blocks as we allocated it above, so it is included in the peak memory. - + cache_manager.num_blocks + int(free_memory // total_cache_size) + # Add batch.blocks as we allocated it above, so it is included in the peak memory. 
+ + cache_manager.num_blocks ) del batch @@ -805,9 +812,9 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) -> prefill = batch.cu_seqlen_prefill is not None model = self.model if ( - self.model_graph_wrapper is not None and - not prefill and - self.model_graph_wrapper.can_use_graph(batch, adapter_data) + self.model_graph_wrapper is not None and + not prefill and + self.model_graph_wrapper.can_use_graph(batch, adapter_data) ): model = self.model_graph_wrapper @@ -905,8 +912,8 @@ def generate_token( # For each member of the batch for i, ( - input_length, - all_input_ids, + input_length, + all_input_ids, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length @@ -931,13 +938,13 @@ def generate_token( if prefill_logprobs: if len(batch) > 1: prefill_tokens_indices[ - out_start_index : out_end_index - 1 - ] = batch.input_ids[start_index + 1 : start_index + out_length] + out_start_index: out_end_index - 1 + ] = batch.input_ids[start_index + 1: start_index + out_length] else: # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = batch.input_ids[ - start_index + 1 : start_index + out_length - ] + start_index + 1: start_index + out_length + ] batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] @@ -954,8 +961,8 @@ def generate_token( # adjust segment lengths to account for all request lengths being 1 during decoding adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) batch.adapter_meta.adapter_segments = torch.tensor( - adapter_segments, - dtype=torch.int32, + adapter_segments, + dtype=torch.int32, device=batch.adapter_meta.adapter_segments.device, ) @@ -988,16 +995,16 @@ def generate_token( # For each member of the batch for i, ( - request, - input_length, - prefix_offset, - read_offset, - stopping_criteria, - all_input_ids, - do_sample, - seed, - next_token_id, - next_token_logprob, + request, + input_length, + prefix_offset, + read_offset, + stopping_criteria, + all_input_ids, + do_sample, + seed, + next_token_id, + next_token_logprob, ) in enumerate(iterator): # Append next token to all tokens all_input_ids.append(next_token_id) @@ -1024,7 +1031,7 @@ def generate_token( if stop: # Decode generated tokens output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :] + all_input_ids[-stopping_criteria.current_tokens:] ) generated_text = GeneratedText( output_text, @@ -1042,8 +1049,8 @@ def generate_token( # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ - out_start_index : out_end_index - 1 - ] + out_start_index: out_end_index - 1 + ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index 8fe38447a..aabda653d 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -204,9 +204,13 @@ def from_pb( max_length = max(max_length, input_length + max_new_tokens) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) - + + request_tokenizers = [ + tokenizers.get_tokenizer(r.adapter_index, tokenizer) + for r in pb.requests + ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device + next_token_chooser_parameters, request_tokenizers, dtype, device ) start_slots = torch.tensor(start_slots, dtype=torch.int64) 
diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index ea4c1ca99..675f78385 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -211,9 +211,13 @@ def from_pb( max_length = max(max_length, input_length + max_new_tokens) adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) - + + request_tokenizers = [ + tokenizers.get_tokenizer(r.adapter_index, tokenizer) + for r in pb.requests + ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device + next_token_chooser_parameters, request_tokenizers, dtype, device ) start_slots = torch.tensor(start_slots, dtype=torch.int64) diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index 00db4dd2e..a18172ad4 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -267,8 +267,12 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": - """Concatenate multiple batches together by padding internal torch tensors""" + def concatenate(cls, batches: List["Seq2SeqLMBatch"], tokenizers) -> "Seq2SeqLMBatch": + """Concatenate multiple batches together by padding internal torch tensors + + Args: + tokenizers: + """ # Used for padding total_batch_size = 0 diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index f22ff7da3..0752b2639 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -1,13 +1,13 @@ -import torch - from abc import ABC, abstractmethod from dataclasses import dataclass from typing import List, Optional +import torch from transformers import PreTrainedTokenizerBase from lorax_server.pb import generate_pb2 from lorax_server.pb.generate_pb2 import FinishReason +from lorax_server.utils.tokenizer import TokenizerManager class Batch(ABC): @@ -32,7 +32,12 @@ def filter(self, request_ids: List[int]) -> "Batch": @classmethod @abstractmethod - def concatenate(cls, batches: List["Batch"]) -> "Batch": + def concatenate( + cls, + batches: List["Batch"], + tokenizer: PreTrainedTokenizerBase, + tokenizer_mgr: TokenizerManager + ) -> "Batch": raise NotImplementedError @abstractmethod diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 066120f3d..82ada7ff7 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -114,7 +114,7 @@ async def Decode(self, request, context): raise ValueError("All batches are empty") if len(batches) > 1: - batch = self.model.batch_type.concatenate(batches) + batch = self.model.batch_type.concatenate(batches, self.model.tokenizer, self.model.tokenizers) else: batch = batches[0] diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 56d8ecda3..b13a39033 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -1,9 +1,8 @@ import math -import torch - from functools import lru_cache from typing import Optional, List, Dict, Union +import torch from transformers import ( LogitsWarper, LogitsProcessor, @@ -17,6 +16,7 @@ try: from outlines.fsm.fsm import RegexFSM from outlines.fsm.json_schema import build_regex_from_object + HAS_OUTLINES = True except ImportError: HAS_OUTLINES = False @@ 
-198,7 +198,7 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tenso # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = probs <= self.top_p_opposite # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( @@ -403,7 +403,7 @@ def __init__( def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: for i, processor in self.processors.items(): - scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1]) + scores[i: i + 1] = processor(input_ids[i: i + 1], scores[i: i + 1]) return scores def filter(self, indices): @@ -432,9 +432,14 @@ class HeterogeneousSchemaLogitsProcessor(LogitsProcessor): def __init__( self, - schemas: List[Optional[str]], - tokenizers: List[Optional[PreTrainedTokenizerBase]], + schemas: List[Optional[str]] = None, + tokenizers: List[Optional[PreTrainedTokenizerBase]] = None, ): + if schemas is None: + schemas = [] + if tokenizers is None: + tokenizers = [] + self.sequence_processors = [ None if schema is None else OutlinesLogitsProcessor(schema, tokenizer) for schema, tokenizer in zip(schemas, tokenizers) @@ -452,6 +457,16 @@ def filter(self, indices): return self return None + # @classmethod + # def concatenate( + # cls, + # processors: List["HeterogeneousSchemaLogitsProcessor"] + # ) -> "HeterogeneousSchemaLogitsProcessor": + # ret = HeterogeneousSchemaLogitsProcessor() + # for p in processors: + # ret.sequence_processors.extend(p.sequence_processors) + # return ret + # Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py class OutlinesLogitsProcessor: @@ -467,7 +482,7 @@ def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase): if not HAS_OUTLINES: raise ImportError("Unable to enforce JSON schema: `outlines` is not installed.") - tokenizer = self.adapt_tokenizer(tokenizer) + self.tokenizer = self.adapt_tokenizer(tokenizer) regex_string = build_regex_from_object(schema) self.fsm = RegexFSM(regex_string, tokenizer) @@ -487,17 +502,17 @@ def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tenso return biased_scores - def adapt_tokenizer(self, tokenizer): + def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): """Adapt vLLM's tokenizer to use to compile the FSM. The API of Outlines tokenizers is slightly different to that of - `transformers`. In addition we need to handle the missing spaces to + `transformers`. In addition, we need to handle the missing spaces to Llama's tokenizer to be able to compile FSMs for this model. """ if hasattr(tokenizer, "vocabulary"): # We've already adapted the tokenizer from a previous request return tokenizer - + tokenizer.vocabulary = tokenizer.get_vocab() tokenizer.special_tokens = set(tokenizer.all_special_tokens) diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 810a8fc75..4cb441240 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -201,7 +201,7 @@ class HeterogeneousNextTokenChooser: watermark (List[bool]): A list of booleans indicating whether watermark processing should be applied for each token. temperature (List[float]): A list of temperature values for temperature-based logits warping. 
repetition_penalty (List[float]): A list of repetition penalty values for repetition penalty-based logits warping. - schema (List[str]): A list of JSON schema strings for Outlines logits warping. + schemas (List[str]): A list of JSON schema strings for Outlines logits warping. top_k (List[int]): A list of top-k values for top-k-based logits warping. top_p (List[float]): A list of top-p values for top-p-based logits warping. typical_p (List[float]): A list of typical-p values for typical-p-based logits warping. @@ -391,7 +391,7 @@ def from_pb( typical_p=[pb_.typical_p for pb_ in pb], do_sample=[pb_.do_sample for pb_ in pb], seeds=[pb_.seed for pb_ in pb], - tokenizer=tokenizers, + tokenizers=tokenizers, device=device, dtype=dtype, ) diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 88e4ed5eb..bb69a0f9a 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -95,7 +95,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch): def test_batch_concatenate_no_prefill(default_bloom_batch): with pytest.raises(ValueError): - BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch]) + BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch], []) def test_causal_lm_batch_type(default_bloom): @@ -228,7 +228,7 @@ def test_batch_concatenate( (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values ] - next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1]) + next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1], []) assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 9ce4bd813..0fa95381d 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -91,7 +91,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): def test_batch_concatenate_no_prefill(default_causal_lm_batch): with pytest.raises(ValueError): - CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch]) + CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch], []) def test_causal_lm_batch_type(default_causal_lm): @@ -226,7 +226,7 @@ def test_batch_concatenate( (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values ] - next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1]) + next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1], []) assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 4c5f15b25..ee76382ab 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -95,7 +95,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch): with pytest.raises(ValueError): - Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch]) + Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch], []) def test_seq2seq_lm_batch_type(default_seq2seq_lm): @@ -236,7 +236,7 @@ def test_batch_concatenate( [t.clone() for t in layer] for layer in next_batch_1.past_key_values ] - next_batch = 
Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1]) + next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1], []) assert next_batch.batch_id == 0 From 857c3634d32d09d8a23f4cecf8c9faf023a231d8 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Thu, 8 Feb 2024 16:37:59 -0600 Subject: [PATCH 07/25] basic path working --- server/lorax_server/utils/logits_process.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index b13a39033..16d7352da 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -14,7 +14,7 @@ ) try: - from outlines.fsm.fsm import RegexFSM + from outlines.fsm.fsm import RegexFSM, FSMState from outlines.fsm.json_schema import build_regex_from_object HAS_OUTLINES = True @@ -486,13 +486,19 @@ def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase): regex_string = build_regex_from_object(schema) self.fsm = RegexFSM(regex_string, tokenizer) - self.fsm_state = 0 + + self.fsm_state = FSMState(0) + self.is_first_token = True def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" - last_token = input_ids[-1] - self.fsm_state = self.fsm.next_state(self.fsm_state, last_token) + if self.is_first_token: + # For the very first token generated, we want to select the allowed tokens from the FSM's initial state. + self.is_first_token = False + else: + last_token = input_ids[0][-1].item() + self.fsm_state = self.fsm.next_state(self.fsm_state, last_token) allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state) From 066f7a7563befaa53c328228b443a9eb485623e2 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Thu, 8 Feb 2024 16:45:42 -0600 Subject: [PATCH 08/25] remove old implementation --- server/lorax_server/utils/logits_process.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 16d7352da..00d083406 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -457,16 +457,6 @@ def filter(self, indices): return self return None - # @classmethod - # def concatenate( - # cls, - # processors: List["HeterogeneousSchemaLogitsProcessor"] - # ) -> "HeterogeneousSchemaLogitsProcessor": - # ret = HeterogeneousSchemaLogitsProcessor() - # for p in processors: - # ret.sequence_processors.extend(p.sequence_processors) - # return ret - # Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py class OutlinesLogitsProcessor: From 262aef1f76714352fc1acb262c3f5871609f11c1 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 14:02:53 -0600 Subject: [PATCH 09/25] concat schema processor directly --- server/lorax_server/models/causal_lm.py | 2 +- server/lorax_server/models/flash_causal_lm.py | 18 +++++------------- server/lorax_server/models/seq2seq_lm.py | 2 +- server/lorax_server/models/types.py | 11 +++-------- server/lorax_server/utils/logits_process.py | 10 ++++++++++ server/tests/models/test_bloom.py | 4 ++-- server/tests/models/test_causal_lm.py | 4 ++-- server/tests/models/test_seq2seq_lm.py | 4 ++-- 8 files changed, 26 insertions(+), 29 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index fd3e4c062..927e04c50 100644 --- 
a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -283,7 +283,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["CausalLMBatch"], tokenizers) -> "CausalLMBatch": + def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # Used for padding total_batch_size = 0 max_input_length = 0 diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 665594ae2..a93c2db3b 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -27,6 +27,7 @@ from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID from lorax_server.utils.dist import MEMORY_FRACTION from lorax_server.utils.graph import GraphCache +from lorax_server.utils.logits_process import HeterogeneousSchemaLogitsProcessor from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata from lorax_server.utils.segments import SegmentConcatBuilder, find_segments from lorax_server.utils.tokenizer import TokenizerManager @@ -485,12 +486,7 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": @classmethod @tracer.start_as_current_span("concatenate") - def concatenate( - cls, - batches: List["FlashCausalLMBatch"], - tokenizer: PreTrainedTokenizerBase, - tokenizers: TokenizerManager - ) -> "FlashCausalLMBatch": + def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": # Batch attributes requests = [] requests_idx_mapping = {} @@ -613,18 +609,14 @@ def concatenate( start_slots = torch.concat(start_slots) - request_tokenizers = [ - tokenizers.get_tokenizer(r.adapter_index, tokenizer) - for r in requests - ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, - tokenizers=request_tokenizers, + tokenizers=[], dtype=batches[0].next_token_chooser.dtype, device=batches[0].next_token_chooser.device, ) - # next_token_chooser.schema_processor = HeterogeneousSchemaLogitsProcessor.concatenate( - # [b.next_token_chooser.schema_processor for b in batches]) + next_token_chooser.schema_processor = HeterogeneousSchemaLogitsProcessor.concatenate( + [b.next_token_chooser.schema_processor for b in batches]) adapter_segments, adapter_segment_indices = adapter_segment_builder.build() diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index a18172ad4..22846d09b 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -267,7 +267,7 @@ def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]: @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["Seq2SeqLMBatch"], tokenizers) -> "Seq2SeqLMBatch": + def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch": """Concatenate multiple batches together by padding internal torch tensors Args: diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 0752b2639..f22ff7da3 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -1,13 +1,13 @@ +import torch + from abc import ABC, abstractmethod from dataclasses import dataclass from typing import List, Optional -import torch from transformers import PreTrainedTokenizerBase from lorax_server.pb import generate_pb2 from lorax_server.pb.generate_pb2 import FinishReason -from 
lorax_server.utils.tokenizer import TokenizerManager class Batch(ABC): @@ -32,12 +32,7 @@ def filter(self, request_ids: List[int]) -> "Batch": @classmethod @abstractmethod - def concatenate( - cls, - batches: List["Batch"], - tokenizer: PreTrainedTokenizerBase, - tokenizer_mgr: TokenizerManager - ) -> "Batch": + def concatenate(cls, batches: List["Batch"]) -> "Batch": raise NotImplementedError @abstractmethod diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 00d083406..177f4bd8b 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -457,6 +457,16 @@ def filter(self, indices): return self return None + @classmethod + def concatenate( + cls, + processors: List["HeterogeneousSchemaLogitsProcessor"] + ) -> "HeterogeneousSchemaLogitsProcessor": + ret = HeterogeneousSchemaLogitsProcessor() + for p in processors: + ret.sequence_processors.extend(p.sequence_processors) + return ret + # Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py class OutlinesLogitsProcessor: diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index bb69a0f9a..88e4ed5eb 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -95,7 +95,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch): def test_batch_concatenate_no_prefill(default_bloom_batch): with pytest.raises(ValueError): - BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch], []) + BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch]) def test_causal_lm_batch_type(default_bloom): @@ -228,7 +228,7 @@ def test_batch_concatenate( (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values ] - next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1], []) + next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1]) assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 0fa95381d..9ce4bd813 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -91,7 +91,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): def test_batch_concatenate_no_prefill(default_causal_lm_batch): with pytest.raises(ValueError): - CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch], []) + CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch]) def test_causal_lm_batch_type(default_causal_lm): @@ -226,7 +226,7 @@ def test_batch_concatenate( (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values ] - next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1], []) + next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1]) assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0]) assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0]) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index ee76382ab..4c5f15b25 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -95,7 +95,7 @@ def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch): def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch): with pytest.raises(ValueError): 
- Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch], []) + Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch]) def test_seq2seq_lm_batch_type(default_seq2seq_lm): @@ -236,7 +236,7 @@ def test_batch_concatenate( [t.clone() for t in layer] for layer in next_batch_1.past_key_values ] - next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1], []) + next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1]) assert next_batch.batch_id == 0 From 31f1ca35f35dcecbc8d95652032b63bcce6ead3b Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 14:07:32 -0600 Subject: [PATCH 10/25] update batch from_pb sig --- server/lorax_server/models/flash_causal_lm.py | 6 +++--- server/lorax_server/models/flash_mistral.py | 6 +++--- server/lorax_server/models/flash_mixtral.py | 6 +++--- server/lorax_server/models/types.py | 2 ++ 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index a93c2db3b..cabf94570 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -111,14 +111,14 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - tokenizers: TokenizerManager, + tokenizer_mgr: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": batch_inputs = [] max_truncation = 0 for r in pb.requests: - inputs = tokenizers.get_inputs(r, tokenizer) + inputs = tokenizer_mgr.get_inputs(r, tokenizer) batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) @@ -240,7 +240,7 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) request_tokenizers = [ - tokenizers.get_tokenizer(r.adapter_index, tokenizer) + tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index aabda653d..2feb9bda8 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -57,7 +57,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - tokenizers: TokenizerManager, + tokenizer_mgr: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": @@ -67,7 +67,7 @@ def from_pb( batch_inputs = [] max_truncation = 0 for r in pb.requests: - inputs = tokenizers.get_inputs(r, tokenizer) + inputs = tokenizer_mgr.get_inputs(r, tokenizer) batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) @@ -206,7 +206,7 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) request_tokenizers = [ - tokenizers.get_tokenizer(r.adapter_index, tokenizer) + tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 675f78385..33ed124cf 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -64,7 +64,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - tokenizers: TokenizerManager, + tokenizer_mgr: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": 
@@ -74,7 +74,7 @@ def from_pb( batch_inputs = [] max_truncation = 0 for r in pb.requests: - inputs = tokenizers.get_inputs(r, tokenizer) + inputs = tokenizer_mgr.get_inputs(r, tokenizer) batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) @@ -213,7 +213,7 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) request_tokenizers = [ - tokenizers.get_tokenizer(r.adapter_index, tokenizer) + tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index f22ff7da3..e352c7ca4 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -8,6 +8,7 @@ from lorax_server.pb import generate_pb2 from lorax_server.pb.generate_pb2 import FinishReason +from lorax_server.utils.tokenizer import TokenizerManager class Batch(ABC): @@ -21,6 +22,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, + tokenizer_mgr: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "Batch": From 29573cea29b907ed61e474ea38850513f945b084 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 14:35:15 -0600 Subject: [PATCH 11/25] revert tokenizer_mgr rename and add handling for causallm --- server/lorax_server/models/causal_lm.py | 98 ++++++++++--------- server/lorax_server/models/flash_causal_lm.py | 6 +- server/lorax_server/models/flash_mistral.py | 6 +- server/lorax_server/models/flash_mixtral.py | 6 +- server/lorax_server/models/types.py | 2 +- server/lorax_server/utils/logits_process.py | 8 +- server/lorax_server/utils/tokens.py | 36 +++++-- 7 files changed, 90 insertions(+), 72 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index 927e04c50..a99721a8b 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -1,12 +1,9 @@ -from collections import defaultdict -import json -import torch -import inspect - from dataclasses import dataclass +from typing import Optional, Tuple, List, Type, Dict + +import torch from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase -from typing import Optional, Tuple, List, Type, Dict from lorax_server.models import Model from lorax_server.models.types import ( @@ -17,9 +14,9 @@ ) from lorax_server.pb import generate_pb2 from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling -from lorax_server.utils.tokenizer import TokenizerManager -from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights +from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata from lorax_server.utils.segments import SegmentConcatBuilder, find_segments +from lorax_server.utils.tokenizer import TokenizerManager tracer = trace.get_tracer(__name__) @@ -95,7 +92,10 @@ def from_pb( requests_idx_mapping[r.id] = i req_inputs = tokenizers.get_inputs(r, tokenizer) inputs.append(req_inputs) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, + device, + tokenizers.get_tokenizer(adapter_indices_list[i], + tokenizer))) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) @@ -212,7 +212,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: 
stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) remaining_decode_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) total_remaining_decode_tokens += remaining_decode_tokens new_padding_right_offset = max( @@ -226,12 +226,14 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]: position_ids = self.position_ids[keep_indices] adapter_indices = self.adapter_meta.adapter_indices[keep_indices] self.attention_mask = self.attention_mask[ - keep_indices, - -(self.padding_right_offset + max_input_length) : ( - self.attention_mask.shape[1] - self.padding_right_offset - ) - + new_padding_right_offset, - ] + keep_indices, + -(self.padding_right_offset + max_input_length): ( + self.attention_mask.shape[ + 1] - + self.padding_right_offset + ) + + new_padding_right_offset, + ] # Ensure that past_key_values tensors can be updated in-place if type(self.past_key_values[0]) == tuple: @@ -371,17 +373,17 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # and to remove unused allocated space left_offset = max_input_length - batch.max_input_length batch_left_offset = ( - batch.attention_mask.shape[1] - - batch.max_input_length - - batch.padding_right_offset + batch.attention_mask.shape[1] + - batch.max_input_length + - batch.padding_right_offset ) attention_mask[ - start_index:end_index, - left_offset:-padding_right_offset, + start_index:end_index, + left_offset:-padding_right_offset, ] = batch.attention_mask[ :, - batch_left_offset : -batch.padding_right_offset, - ] + batch_left_offset: -batch.padding_right_offset, + ] # Create empty tensor # position_ids is always of shape [batch_size, 1] @@ -405,7 +407,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # Add eventual padding tokens that were added while concatenating max_tokens += batch.max_tokens + ( - max_input_length - batch.max_input_length + max_input_length - batch.max_input_length ) * len(batch) start_index = end_index @@ -447,12 +449,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": past_seq_len = batch.max_input_length - 1 if batch.keys_head_dim_last: padded_past_keys[ - start_index:end_index, :, -past_seq_len:, : + start_index:end_index, :, -past_seq_len:, : ] = past_keys[:, :, -past_seq_len:, :] else: # BLOOM case padded_past_keys[ - start_index:end_index, :, :, -past_seq_len: + start_index:end_index, :, :, -past_seq_len: ] = past_keys[:, :, :, -past_seq_len:] del past_keys @@ -472,7 +474,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": # We slice the past values to remove the padding from previous batches past_seq_len = batch.max_input_length - 1 padded_past_values[ - start_index:end_index, :, -past_seq_len:, : + start_index:end_index, :, -past_seq_len:, : ] = past_values[:, :, -past_seq_len:, :] del past_values @@ -525,7 +527,7 @@ def __init__( ): if compile: raise ValueError("`--compile` is not supported with CausalLM") - + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype @@ -580,7 +582,7 @@ def __init__( @property def batch_type(self) -> Type[CausalLMBatch]: return CausalLMBatch - + @property def has_adapter_data(self) -> bool: return False @@ -592,19 +594,19 @@ def decode(self, generated_ids: List[int]) -> str: def forward( self, - input_ids, - attention_mask, - position_ids, - past_key_values: Optional = 
None, + input_ids, + attention_mask, + position_ids, + past_key_values: Optional = None, adapter_data: Optional[AdapterBatchData] = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward kwargs = { - "input_ids": input_ids, - "attention_mask": attention_mask, + "input_ids" : input_ids, + "attention_mask" : attention_mask, "past_key_values": past_key_values, - "use_cache": True, - "return_dict": True, + "use_cache" : True, + "return_dict" : True, } if self.has_position_ids: kwargs["position_ids"] = position_ids @@ -653,14 +655,14 @@ def generate_token( # For each member of the batch for i, ( - request, - input_length, - prefix_offset, - read_offset, - logits, - next_token_chooser, - stopping_criteria, - all_input_ids, + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, ) in enumerate(iterator): # Select next token next_token_id, logprobs = next_token_chooser( @@ -693,7 +695,7 @@ def generate_token( if stop: # Decode generated tokens output_text = self.decode( - all_input_ids[-stopping_criteria.current_tokens :, 0] + all_input_ids[-stopping_criteria.current_tokens:, 0] ) # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -713,8 +715,8 @@ def generate_token( prefill_logprobs = [float("nan")] + torch.log_softmax( logits, -1 ).gather(1, all_input_ids[1:]).squeeze(1)[ - -new_input_length:-1 - ].tolist() + -new_input_length:-1 + ].tolist() prefill_token_ids = all_input_ids[-new_input_length:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index cabf94570..a93c2db3b 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -111,14 +111,14 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - tokenizer_mgr: TokenizerManager, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": batch_inputs = [] max_truncation = 0 for r in pb.requests: - inputs = tokenizer_mgr.get_inputs(r, tokenizer) + inputs = tokenizers.get_inputs(r, tokenizer) batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) @@ -240,7 +240,7 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) request_tokenizers = [ - tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer) + tokenizers.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py index 2feb9bda8..aabda653d 100644 --- a/server/lorax_server/models/flash_mistral.py +++ b/server/lorax_server/models/flash_mistral.py @@ -57,7 +57,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - tokenizer_mgr: TokenizerManager, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": @@ -67,7 +67,7 @@ def from_pb( batch_inputs = [] max_truncation = 0 for r in pb.requests: - inputs = tokenizer_mgr.get_inputs(r, tokenizer) + inputs = tokenizers.get_inputs(r, tokenizer) batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) @@ -206,7 +206,7 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) request_tokenizers = [ - 
tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer) + tokenizers.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py index 33ed124cf..675f78385 100644 --- a/server/lorax_server/models/flash_mixtral.py +++ b/server/lorax_server/models/flash_mixtral.py @@ -64,7 +64,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - tokenizer_mgr: TokenizerManager, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": @@ -74,7 +74,7 @@ def from_pb( batch_inputs = [] max_truncation = 0 for r in pb.requests: - inputs = tokenizer_mgr.get_inputs(r, tokenizer) + inputs = tokenizers.get_inputs(r, tokenizer) batch_inputs.append(inputs) max_truncation = max(max_truncation, r.truncate) @@ -213,7 +213,7 @@ def from_pb( adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device) request_tokenizers = [ - tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer) + tokenizers.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests ] next_token_chooser = HeterogeneousNextTokenChooser.from_pb( diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index e352c7ca4..610c41c6c 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -22,7 +22,7 @@ def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - tokenizer_mgr: TokenizerManager, + tokenizers: TokenizerManager, dtype: torch.dtype, device: torch.device, ) -> "Batch": diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 177f4bd8b..8b6d53798 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -432,8 +432,8 @@ class HeterogeneousSchemaLogitsProcessor(LogitsProcessor): def __init__( self, - schemas: List[Optional[str]] = None, - tokenizers: List[Optional[PreTrainedTokenizerBase]] = None, + schemas: Optional[List[Optional[str]]] = None, + tokenizers: Optional[List[Optional[PreTrainedTokenizerBase]]] = None, ): if schemas is None: schemas = [] @@ -441,7 +441,7 @@ def __init__( tokenizers = [] self.sequence_processors = [ - None if schema is None else OutlinesLogitsProcessor(schema, tokenizer) + None if schema is None or tokenizer is None else OutlinesLogitsProcessor(schema, tokenizer) for schema, tokenizer in zip(schemas, tokenizers) ] @@ -469,7 +469,7 @@ def concatenate( # Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py -class OutlinesLogitsProcessor: +class OutlinesLogitsProcessor(LogitsProcessor): def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase): """Compile the FSM that drives the regex-guided generation. diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 4cb441240..f9258fa24 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -18,7 +18,7 @@ HeterogeneousTopPLogitsWarper, HeterogeneousTypicalLogitsWarper, HeterogeneousProcessorWrapper, - HeterogeneousSchemaLogitsProcessor, + HeterogeneousSchemaLogitsProcessor, OutlinesLogitsProcessor, ) @@ -30,12 +30,14 @@ class NextTokenChooser: watermark (bool): Whether to apply watermark processing to logits. Default is False. temperature (float): The temperature value for warping logits. Default is 1.0. 
repetition_penalty (float): The penalty value for repetition in logits. Default is 1.0. + schema (str): A JSON schema string for Outlines logits warping. top_k (int): The value for top-k warping of logits. Default is None. top_p (float): The value for top-p warping of logits. Default is None. typical_p (float): The value for typical-p warping of logits. Default is None. do_sample (bool): Whether to perform sampling. Default is False. seed (int): The seed value for random number generation. Default is 0. device (str): The device to use for computation. Default is "cpu". + tokenizer (PreTrainedTokenizerBase): A tokenizer to use for processing the tokens. Returns: next_id (torch.Tensor): The next token ID. @@ -44,15 +46,17 @@ class NextTokenChooser: def __init__( self, - watermark=False, - temperature=1.0, - repetition_penalty=1.0, - top_k=None, - top_p=None, - typical_p=None, - do_sample=False, - seed=0, - device="cpu", + watermark: bool = False, + temperature: float = 1.0, + repetition_penalty: float = 1.0, + schema: str = None, + top_k: int = None, + top_p: float = None, + typical_p: float = None, + do_sample: bool = False, + seed: int = 0, + device: str = "cpu", + tokenizer: Optional[PreTrainedTokenizerBase] = None, ): self.watermark_processor = ( WatermarkLogitsProcessor(device=device) if watermark else None @@ -63,6 +67,12 @@ def __init__( else None ) + self.schema_processor = ( + OutlinesLogitsProcessor(schema, tokenizer) + if schema is not None and tokenizer is not None + else None + ) + has_warpers = ( (temperature is not None and temperature != 1.0) or (top_k is not None and top_k != 0) @@ -84,6 +94,8 @@ def __call__(self, input_ids, scores): scores = self.watermark_processor(input_ids, scores) if self.repetition_processor is not None: scores = self.repetition_processor(input_ids, scores) + if self.schema_processor is not None: + scores = self.schema_processor(input_ids, scores) if self.static_warper is None: next_logprob = torch.log_softmax(scores, -1) @@ -99,6 +111,7 @@ def from_pb( cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device, + tokenizer: PreTrainedTokenizerBase, ) -> "NextTokenChooser": """ Create a NextTokenChooser instance from a protobuf message. @@ -106,6 +119,7 @@ def from_pb( Args: pb (generate_pb2.NextTokenChooserParameters): The protobuf message containing the parameters. device (torch.device): The device to use for computation. + tokenizer (PreTrainedTokenizerBase): A tokenizer for use in processing the tokens. Returns: NextTokenChooser: The NextTokenChooser instance. 
@@ -114,12 +128,14 @@ def from_pb( watermark=pb.watermark, temperature=pb.temperature, repetition_penalty=pb.repetition_penalty, + schema=pb.schema, top_k=pb.top_k, top_p=pb.top_p, typical_p=pb.typical_p, do_sample=pb.do_sample, seed=pb.seed, device=device, + tokenizer=tokenizer, ) From 903895c7446a0a71077907222ec058cf17af4009 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 14:41:34 -0600 Subject: [PATCH 12/25] pass tokenizer and enable schema in other lms --- server/lorax_server/models/causal_lm.py | 5 +---- server/lorax_server/models/galactica.py | 2 +- server/lorax_server/models/seq2seq_lm.py | 2 +- server/lorax_server/utils/tokens.py | 2 +- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py index a99721a8b..bb3a0248d 100644 --- a/server/lorax_server/models/causal_lm.py +++ b/server/lorax_server/models/causal_lm.py @@ -92,10 +92,7 @@ def from_pb( requests_idx_mapping[r.id] = i req_inputs = tokenizers.get_inputs(r, tokenizer) inputs.append(req_inputs) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, - device, - tokenizers.get_tokenizer(adapter_indices_list[i], - tokenizer))) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) diff --git a/server/lorax_server/models/galactica.py b/server/lorax_server/models/galactica.py index 77a8730bb..be43d0244 100644 --- a/server/lorax_server/models/galactica.py +++ b/server/lorax_server/models/galactica.py @@ -95,7 +95,7 @@ def from_pb( # Add escape_custom_split_sequence to the CausalLMBatch logic req_inputs = tokenizers.get_inputs(r, tokenizer) inputs.append(escape_custom_split_sequence(req_inputs)) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index 22846d09b..23829a394 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -96,7 +96,7 @@ def from_pb( inputs.append(req_inputs) requests_idx_mapping[r.id] = i decoder_input_lengths.append(1) - next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index f9258fa24..421308c38 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -111,7 +111,7 @@ def from_pb( cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device, - tokenizer: PreTrainedTokenizerBase, + tokenizer: Optional[PreTrainedTokenizerBase], ) -> "NextTokenChooser": """ Create a NextTokenChooser instance from a protobuf message. 
From ee032175022587b0969ce2d74531d9b8a9c29051 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 14:52:21 -0600 Subject: [PATCH 13/25] bump rust to 1.74.0 --- .github/workflows/tests.yaml | 2 +- rust-toolchain.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 1ea61d715..48bc85c07 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -33,7 +33,7 @@ jobs: - name: Install Rust uses: actions-rs/toolchain@v1 with: - toolchain: 1.70.0 + toolchain: 1.74.0 override: true components: rustfmt, clippy - name: Install Protoc diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 2db1883c8..4bf510f12 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.70.0" +channel = "1.74.0" components = ["rustfmt", "clippy"] \ No newline at end of file From 29f5824af1ad5ff3eb818fa80248893edc1c2107 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 15:23:48 -0600 Subject: [PATCH 14/25] check schema and tokenizer via truthiness --- server/lorax_server/utils/logits_process.py | 2 +- server/lorax_server/utils/tokens.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 8b6d53798..24a6ee0cd 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -441,7 +441,7 @@ def __init__( tokenizers = [] self.sequence_processors = [ - None if schema is None or tokenizer is None else OutlinesLogitsProcessor(schema, tokenizer) + OutlinesLogitsProcessor(schema, tokenizer) if schema and tokenizer else None for schema, tokenizer in zip(schemas, tokenizers) ] diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 421308c38..7a8f53487 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -69,7 +69,7 @@ def __init__( self.schema_processor = ( OutlinesLogitsProcessor(schema, tokenizer) - if schema is not None and tokenizer is not None + if schema and tokenizer else None ) From 9c874fc15972d20d6f07ceaa8fb3f1a530f9005a Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 15:41:12 -0600 Subject: [PATCH 15/25] add rust backtrace --- .github/workflows/tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 48bc85c07..412aad6e7 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -22,6 +22,7 @@ jobs: env: SCCACHE_GHA_ENABLED: "on" RUSTC_WRAPPER: /usr/local/bin/sccache + RUST_BACKTRACE: 1 SCCACHE: 0.3.3 steps: From 57810733fcefc3a4e639238714cab829bf8cd436 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 15:42:43 -0600 Subject: [PATCH 16/25] add outlines to make gen-server --- server/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/Makefile b/server/Makefile index b757e884d..0d9973901 100644 --- a/server/Makefile +++ b/server/Makefile @@ -20,7 +20,7 @@ install: gen-server pip install pip --upgrade pip install torch==2.2.0 pip install -r requirements.txt - pip install -e ".[bnb, accelerate, quantize, peft]" + pip install -e ".[bnb, accelerate, quantize, peft, outlines]" run-dev: # SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve gpt2 From df88b16b61e6f60d8cd034861e46e8e7d90f36bd Mon Sep 17 00:00:00 2001 
From: Jeffrey Tang Date: Fri, 9 Feb 2024 15:54:06 -0600 Subject: [PATCH 17/25] pass schema in test --- server/tests/conftest.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 7d77b6fe0..59e0e2bf2 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -5,6 +5,25 @@ @pytest.fixture def default_pb_parameters(): + schema = """{ + "$defs": { + "Armor": { + "enum": ["leather", "chainmail", "plate"], + "title": "Armor", + "type": "string" + } + }, + "properties": { + "name": {"maxLength": 10, "title": "Name", "type": "string"}, + "age": {"title": "Age", "type": "integer"}, + "armor": {"$ref": "#/$defs/Armor"}, + "strength": {"title": "Strength", "type": "integer"}\ + }, + "required": ["name", "age", "armor", "strength"], + "title": "Character", + "type": "object" +}""" + return generate_pb2.NextTokenChooserParameters( temperature=1.0, repetition_penalty=1.0, @@ -12,6 +31,7 @@ def default_pb_parameters(): top_p=1.0, typical_p=1.0, do_sample=False, + schema=schema, ) From 48645395203e78388794c9734602c76cf78a8ef8 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 17:27:35 -0600 Subject: [PATCH 18/25] new fixtures --- server/tests/conftest.py | 29 +++++++--- server/tests/models/test_causal_lm.py | 77 ++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 8 deletions(-) diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 59e0e2bf2..4173ade2f 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -5,7 +5,24 @@ @pytest.fixture def default_pb_parameters(): - schema = """{ + return generate_pb2.NextTokenChooserParameters( + temperature=1.0, + repetition_penalty=1.0, + top_k=0, + top_p=1.0, + typical_p=1.0, + do_sample=False, + ) + + +@pytest.fixture +def default_pb_stop_parameters(): + return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10) + + +@pytest.fixture +def default_json_schema(): + return """{ "$defs": { "Armor": { "enum": ["leather", "chainmail", "plate"], @@ -24,6 +41,9 @@ def default_pb_parameters(): "type": "object" }""" + +@pytest.fixture +def schema_constrained_pb_parameters(default_json_schema): return generate_pb2.NextTokenChooserParameters( temperature=1.0, repetition_penalty=1.0, @@ -31,10 +51,5 @@ def default_pb_parameters(): top_p=1.0, typical_p=1.0, do_sample=False, - schema=schema, + schema=default_json_schema, ) - - -@pytest.fixture -def default_pb_stop_parameters(): - return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 9ce4bd813..4b94f30cc 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -59,7 +59,82 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer): ) -def test_batch_from_pb(default_pb_batch, default_causal_lm_batch): +@pytest.fixture +def schema_constrained_pb_request(schema_constrained_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="Test", + prefill_logprobs=True, + truncate=100, + parameters=schema_constrained_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def schema_constrained_pb_batch(schema_constrained_pb_request): + return generate_pb2.Batch(id=0, requests=[schema_constrained_pb_request], size=1) + + +@pytest.fixture +def schema_constrained_causal_lm_batch(schema_constrained_pb_batch, 
gpt2_tokenizer): + return CausalLMBatch.from_pb( + schema_constrained_pb_batch, gpt2_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") + ) + + +@pytest.fixture +def schema_constrained_multi_requests_causal_lm_batch(schema_constrained_pb_request, gpt2_tokenizer): + req_0 = copy(default_pb_request) + req_0.id = 1 + req_1 = default_pb_request + req_1.id = 2 + req_1.stopping_parameters.max_new_tokens = 5 + + batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2) + return CausalLMBatch.from_pb( + batch_pb, gpt2_tokenizer, TokenizerManager(), torch.float32, torch.device("cpu") + ) + + +@pytest.mark.parametrize("pb_batch, causal_lm_batch", [ + ("default_pb_batch", "default_causal_lm_batch"), + ("schema_constrained_pb_batch", "schema_constrained_causal_lm_batch") +]) +def test_batch_from_pb(pb_batch, causal_lm_batch, request): + pb_batch = request.getfixturevalue(pb_batch) + causal_lm_batch = request.getfixturevalue(causal_lm_batch) + + batch = causal_lm_batch + + assert batch.batch_id == pb_batch.id + assert batch.requests == pb_batch.requests + + assert len(batch.input_ids) == pb_batch.size + assert batch.input_ids[0][-1] == 14402 + assert torch.all(batch.input_ids[0][:-1] == 50256) + + assert batch.attention_mask[0, 0] == 1 + assert torch.all(batch.attention_mask[0, 1:] == 0) + + assert batch.past_key_values is None + + assert all( + [ + torch.equal(input_ids, all_input_ids[:, 0]) + for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) + ] + ) + + assert batch.input_lengths == [1] + + assert len(batch) == pb_batch.size + assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) + + assert batch.max_input_length == batch.input_lengths[0] + + +def test_batch_with_schema_from_pb(default_pb_batch, default_causal_lm_batch): batch = default_causal_lm_batch assert batch.batch_id == default_pb_batch.id From 2c13886ce7e29d94b4f488ff468b1d56ef1c6def Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 17:40:35 -0600 Subject: [PATCH 19/25] remove paste, parametrize geen token --- server/tests/models/test_causal_lm.py | 42 ++++++--------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 4b94f30cc..1e0796194 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -134,36 +134,6 @@ def test_batch_from_pb(pb_batch, causal_lm_batch, request): assert batch.max_input_length == batch.input_lengths[0] -def test_batch_with_schema_from_pb(default_pb_batch, default_causal_lm_batch): - batch = default_causal_lm_batch - - assert batch.batch_id == default_pb_batch.id - assert batch.requests == default_pb_batch.requests - - assert len(batch.input_ids) == default_pb_batch.size - assert batch.input_ids[0][-1] == 14402 - assert torch.all(batch.input_ids[0][:-1] == 50256) - - assert batch.attention_mask[0, 0] == 1 - assert torch.all(batch.attention_mask[0, 1:] == 0) - - assert batch.past_key_values is None - - assert all( - [ - torch.equal(input_ids, all_input_ids[:, 0]) - for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids) - ] - ) - - assert batch.input_lengths == [1] - - assert len(batch) == default_pb_batch.size - assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch) - - assert batch.max_input_length == batch.input_lengths[0] - - def test_batch_concatenate_no_prefill(default_causal_lm_batch): with pytest.raises(ValueError): 
CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch]) @@ -173,9 +143,15 @@ def test_causal_lm_batch_type(default_causal_lm): assert default_causal_lm.batch_type == CausalLMBatch -def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch): - sequence_length = len(default_causal_lm_batch.all_input_ids[0]) - generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch) +@pytest.mark.parametrize("causal_lm_batch", [ + "default_causal_lm_batch", + "schema_constrained_causal_lm_batch", +]) +def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, request): + causal_lm_batch = request.getfixturevalue(causal_lm_batch) + + sequence_length = len(causal_lm_batch.all_input_ids[0]) + generations, next_batch = default_causal_lm.generate_token(causal_lm_batch) assert len(generations) == len(next_batch) assert isinstance(next_batch, CausalLMBatch) From ad7861673b72ef973c75fa906c418768cd585f3a Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 18:01:08 -0600 Subject: [PATCH 20/25] fix one token ID --- server/tests/models/test_causal_lm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 1e0796194..091add5a8 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -143,11 +143,11 @@ def test_causal_lm_batch_type(default_causal_lm): assert default_causal_lm.batch_type == CausalLMBatch -@pytest.mark.parametrize("causal_lm_batch", [ - "default_causal_lm_batch", - "schema_constrained_causal_lm_batch", +@pytest.mark.parametrize("causal_lm_batch, generated_token_id", [ + ("default_causal_lm_batch", 13), + ("schema_constrained_causal_lm_batch", 90), ]) -def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, request): +def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_token_id, request): causal_lm_batch = request.getfixturevalue(causal_lm_batch) sequence_length = len(causal_lm_batch.all_input_ids[0]) @@ -159,7 +159,10 @@ def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, request): assert len(next_batch.all_input_ids) == len(next_batch) assert len(next_batch.all_input_ids[0]) == sequence_length + 1 assert len(next_batch.attention_mask[0]) == 11 - assert next_batch.all_input_ids[0][-1] == 13 + assert next_batch.all_input_ids[0][-1] == generated_token_id + + print(f"\n\ngen_token: {default_causal_lm.tokenizer.decode(next_batch.all_input_ids[0][-1])}") + assert next_batch.all_input_ids[0][-2] == 14402 assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) From 71de5869bd8c6ed9ca88ca61440f1a263937eb4f Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 18:13:17 -0600 Subject: [PATCH 21/25] fix next token ref --- server/tests/models/test_causal_lm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 091add5a8..b2e7deed0 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -160,9 +160,6 @@ def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_ assert len(next_batch.all_input_ids[0]) == sequence_length + 1 assert len(next_batch.attention_mask[0]) == 11 assert next_batch.all_input_ids[0][-1] == generated_token_id - - print(f"\n\ngen_token: {default_causal_lm.tokenizer.decode(next_batch.all_input_ids[0][-1])}") - assert 
next_batch.all_input_ids[0][-2] == 14402 assert torch.all(next_batch.all_input_ids[0][:-2] == 50256) @@ -170,7 +167,7 @@ def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_ assert torch.all(next_batch.attention_mask[0][2:] == 0) assert next_batch.input_ids.shape == (len(next_batch), 1) - assert next_batch.input_ids[0, 0] == 13 + assert next_batch.input_ids[0, 0] == generated_token_id assert next_batch.input_lengths == [2] assert next_batch.max_input_length == next_batch.input_lengths[0] From b9e414a02e4c271a321a8496005a918a0f8c74f4 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 18:21:51 -0600 Subject: [PATCH 22/25] one more reference --- server/tests/models/test_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index b2e7deed0..43c4adb11 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -181,7 +181,7 @@ def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_ ) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) - assert all([generation.token_id.item() == 13 for generation in generations]) + assert all([generation.token_id.item() == generated_token_id for generation in generations]) assert all([generation.token_text == "." for generation in generations]) assert generations[0].request_id == 0 From 94eb8446d364aa65dbc869cf883863e1dc082726 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 18:44:21 -0600 Subject: [PATCH 23/25] fix token text --- server/tests/models/test_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 43c4adb11..3bbff2001 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -182,7 +182,7 @@ def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_ assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([generation.token_id.item() == generated_token_id for generation in generations]) - assert all([generation.token_text == "." 
for generation in generations]) + assert all([generation.token_text == "{" for generation in generations]) assert generations[0].request_id == 0 From 5e6a7f76cae5852d851923f77b4b12f4157a258a Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Fri, 9 Feb 2024 18:53:41 -0600 Subject: [PATCH 24/25] decode token --- server/tests/models/test_causal_lm.py | 71 ++++++++++++++------------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 3bbff2001..e81a780cd 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -1,11 +1,11 @@ +from copy import copy + import pytest import torch - -from copy import copy from transformers import AutoTokenizer -from lorax_server.pb import generate_pb2 from lorax_server.models.causal_lm import CausalLM, CausalLMBatch +from lorax_server.pb import generate_pb2 from lorax_server.utils.tokenizer import TokenizerManager @@ -182,7 +182,8 @@ def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_ assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([generation.token_id.item() == generated_token_id for generation in generations]) - assert all([generation.token_text == "{" for generation in generations]) + assert all([generation.token_text == default_causal_lm.tokenizer.decode(generated_token_id) for generation in + generations]) assert generations[0].request_id == 0 @@ -201,8 +202,8 @@ def test_causal_lm_generate_token_completion( assert generations[0].generated_text.text == ".java:784) at net.minecraft." assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert ( - generations[0].generated_text.generated_tokens - == default_causal_lm_batch.stopping_criterias[0].max_new_tokens + generations[0].generated_text.generated_tokens + == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -212,7 +213,7 @@ def test_causal_lm_generate_token_completion_multi( next_batch = default_multi_requests_causal_lm_batch for i in range( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 + default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1 ): generations, next_batch = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) @@ -223,12 +224,12 @@ def test_causal_lm_generate_token_completion_multi( assert len(generations) == 2 assert generations[1].generated_text.text == ".java:784)" assert ( - generations[1].request_id - == default_multi_requests_causal_lm_batch.requests[1].id + generations[1].request_id + == default_multi_requests_causal_lm_batch.requests[1].id ) assert ( - generations[1].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + generations[1].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) # Copy stopping_criterias before filtering stopping_criterias = ( @@ -238,7 +239,7 @@ def test_causal_lm_generate_token_completion_multi( next_batch = next_batch.filter([next_batch.requests[0].id]) for _ in range( - stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 + stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1 ): generations, next_batch = default_causal_lm.generate_token(next_batch) assert len(generations) 
== len(next_batch) @@ -249,12 +250,12 @@ def test_causal_lm_generate_token_completion_multi( assert len(generations) == 1 assert generations[0].generated_text.text == ".java:784) at net.minecraft." assert ( - generations[0].request_id - == default_multi_requests_causal_lm_batch.requests[0].id + generations[0].request_id + == default_multi_requests_causal_lm_batch.requests[0].id ) assert ( - generations[0].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens + generations[0].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -287,7 +288,7 @@ def test_batch_concatenate( next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1 ) assert torch.all( - next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1 + next_batch.attention_mask[1:, 1: -next_batch.padding_right_offset] == 1 ) assert torch.all(next_batch.attention_mask[1:, 3:] == 0) @@ -323,7 +324,7 @@ def test_batch_concatenate( ) for _ in range( - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 + default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2 ): generations, next_batch = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) @@ -334,12 +335,12 @@ def test_batch_concatenate( assert len(generations) == 3 assert generations[2].generated_text.text == ".java:784)" assert ( - generations[2].request_id - == default_multi_requests_causal_lm_batch.requests[1].id + generations[2].request_id + == default_multi_requests_causal_lm_batch.requests[1].id ) assert ( - generations[2].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + generations[2].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) next_batch = next_batch.filter( @@ -347,9 +348,9 @@ def test_batch_concatenate( ) for _ in range( - default_causal_lm_batch.stopping_criterias[0].max_new_tokens - - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - - 2 + default_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + - 2 ): generations, next_batch = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) @@ -361,17 +362,17 @@ def test_batch_concatenate( assert generations[0].generated_text.text == ".java:784) at net.minecraft." 
assert generations[0].request_id == default_causal_lm_batch.requests[0].id assert ( - generations[0].generated_text.generated_tokens - == default_causal_lm_batch.stopping_criterias[0].max_new_tokens + generations[0].generated_text.generated_tokens + == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) next_batch = next_batch.filter([next_batch.requests[1].id]) for _ in range( - default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens - - default_causal_lm_batch.stopping_criterias[0].max_new_tokens - - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - - 4 + default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_causal_lm_batch.stopping_criterias[0].max_new_tokens + - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens + - 4 ): generations, next_batch = default_causal_lm.generate_token(next_batch) assert len(generations) == len(next_batch) @@ -382,10 +383,10 @@ def test_batch_concatenate( assert len(generations) == 1 assert generations[0].generated_text.text == ".java:784) at net.minecraft." assert ( - generations[0].request_id - == default_multi_requests_causal_lm_batch.requests[0].id + generations[0].request_id + == default_multi_requests_causal_lm_batch.requests[0].id ) assert ( - generations[0].generated_text.generated_tokens - == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens + generations[0].generated_text.generated_tokens + == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) From 0650fc62d2f1c9909177668fb30e68152b0b0c43 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Mon, 12 Feb 2024 16:59:15 -0600 Subject: [PATCH 25/25] fix pydantic errors around schema name clash --- clients/python/lorax/client.py | 3 ++- clients/python/lorax/types.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index b5aa2776c..b20dc6985 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -160,7 +160,7 @@ def generate( resp = requests.post( self.base_url, - json=request.dict(), + json=request.dict(by_alias=True), headers=self.headers, cookies=self.cookies, timeout=self.timeout, @@ -168,6 +168,7 @@ def generate( payload = resp.json() if resp.status_code != 200: raise parse_error(resp.status_code, payload) + return Response(**payload[0]) def generate_stream( diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index a34f0e612..26cdbd8ff 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -1,5 +1,5 @@ from enum import Enum -from pydantic import BaseModel, validator +from pydantic import BaseModel, validator, Field from typing import Optional, List from lorax.errors import ValidationError @@ -98,6 +98,8 @@ class Parameters(BaseModel): details: bool = False # Get decoder input token logprobs and ids decoder_input_details: bool = False + # Optional JSON schema string to constrain the generated text + json_schema: Optional[str] = Field(alias="schema") @validator("adapter_id") def valid_adapter_id(cls, v, values):
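
Editor's note: taken together, these patches thread a JSON schema from the client request through the proto `schema` field into the server's `NextTokenChooser`, where an Outlines FSM masks the logits so that only schema-valid tokens remain candidates at each step. The sketch below is an editorial illustration, not part of the patch series: it shows roughly what a schema-constrained request could look like once the changes land. The endpoint path `/generate`, the local address and port, the prompt text, the `max_new_tokens` value, and the `generated_text` response key are assumptions about a locally running LoRAX server; the `schema` parameter name itself comes from the proto and client changes above.

import json
import requests

# JSON schema the generated text must conform to (a trimmed-down version of
# the Character schema used in the test fixtures).
character_schema = json.dumps({
    "properties": {
        "name": {"maxLength": 10, "title": "Name", "type": "string"},
        "age": {"title": "Age", "type": "integer"},
    },
    "required": ["name", "age"],
    "title": "Character",
    "type": "object",
})

payload = {
    "inputs": "Create a character for my fantasy game:",
    "parameters": {
        "max_new_tokens": 64,
        # Wire-level name from generate.proto; the Python client's
        # Parameters.json_schema field serializes back to this key
        # through its alias when dict(by_alias=True) is used.
        "schema": character_schema,
    },
}

# Assumes a LoRAX server listening locally on its default port.
resp = requests.post("http://127.0.0.1:8080/generate", json=payload, timeout=60)
resp.raise_for_status()

# With constrained decoding active, the returned text is biased to follow the
# schema (assuming enough new tokens are allowed to close the JSON object).
print(resp.json().get("generated_text"))

The Python client builds the same payload internally: its Parameters model stores the schema under the `json_schema` field to avoid clashing with pydantic's own `schema()` method, then re-emits it under the wire name `schema`, which is why the final patch switches `client.py` to posting `request.dict(by_alias=True)`.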