Added Outlines logits processor for JSON schema validation #224

Merged 25 commits on Feb 12, 2024

.github/workflows/tests.yaml (2 additions, 1 deletion)
@@ -22,6 +22,7 @@ jobs:
env:
SCCACHE_GHA_ENABLED: "on"
RUSTC_WRAPPER: /usr/local/bin/sccache
+ RUST_BACKTRACE: 1
SCCACHE: 0.3.3

steps:
@@ -33,7 +34,7 @@
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
- toolchain: 1.70.0
+ toolchain: 1.74.0
override: true
components: rustfmt, clippy
- name: Install Protoc
Dockerfile (1 addition, 1 deletion)
@@ -218,7 +218,7 @@ COPY server/Makefile server/Makefile

RUN cd server && \
make gen-server && \
pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir

# Install router
COPY --from=builder /usr/src/target/release/lorax-router /usr/local/bin/lorax-router
clients/python/lorax/__init__.py (1 addition, 1 deletion)
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.3.0"
__version__ = "0.3.1"

from lorax.client import Client, AsyncClient, MergedAdapters
clients/python/lorax/client.py (18 additions, 2 deletions)
@@ -3,7 +3,7 @@

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
- from typing import Dict, Optional, List, AsyncIterator, Iterator
+ from typing import Any, Dict, Optional, List, AsyncIterator, Iterator

from lorax.types import (
StreamResponse,
@@ -79,6 +79,7 @@ def generate(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
+ schema: Optional[Dict[str, Any]] = None,
decoder_input_details: bool = False,
) -> Response:
"""
@@ -124,6 +125,8 @@ def generate(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+ schema (`Optional[Dict[str, Any]]`):
+     Optional JSON schema to validate the response
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids

@@ -150,20 +153,22 @@
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
+ schema=json.dumps(schema) if schema is not None else None,
decoder_input_details=decoder_input_details,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)

resp = requests.post(
self.base_url,
- json=request.dict(),
+ json=request.dict(by_alias=True),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)

return Response(**payload[0])

def generate_stream(
@@ -185,6 +190,7 @@ def generate_stream(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
+ schema: Optional[Dict[str, Any]] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
@@ -227,6 +233,8 @@ def generate_stream(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+ schema (`Optional[Dict[str, Any]]`):
+     Optional JSON schema to validate the response

Returns:
Iterator[StreamResponse]: stream of generated tokens
@@ -252,6 +260,7 @@
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
+ schema=json.dumps(schema) if schema is not None else None,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)

@@ -353,6 +362,7 @@ async def generate(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
+ schema: Optional[Dict[str, Any]] = None,
decoder_input_details: bool = False,
) -> Response:
"""
@@ -398,6 +408,8 @@ async def generate(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+ schema (`Optional[Dict[str, Any]]`):
+     Optional JSON schema to validate the response
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids

@@ -425,6 +437,7 @@ async def generate(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
+ schema=json.dumps(schema) if schema is not None else None,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -457,6 +470,7 @@ async def generate_stream(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
+ schema: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@@ -499,6 +513,8 @@ async def generate_stream(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+ schema (`Optional[Dict[str, Any]]`):
+     Optional JSON schema to validate the response

Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
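
Taken together, the client changes simply JSON-encode a schema dict and forward it as a request parameter. A minimal usage sketch (the endpoint URL, prompt, and person schema here are illustrative, not from the PR):

```python
from lorax import Client

client = Client("http://127.0.0.1:8080")  # assumes a locally running LoRAX server

# Illustrative schema: an object with two required fields.
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

response = client.generate(
    "Generate a person named Alice who is 30 years old:",
    max_new_tokens=64,
    schema=schema,  # JSON-encoded by the client via json.dumps(...)
)
print(response.generated_text)  # text constrained to match the schema
```

The same schema keyword applies to generate_stream and to both AsyncClient methods, per the hunks above.
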
clients/python/lorax/types.py (3 additions, 1 deletion)
@@ -1,5 +1,5 @@
from enum import Enum
- from pydantic import BaseModel, validator
+ from pydantic import BaseModel, validator, Field
from typing import Optional, List

from lorax.errors import ValidationError
@@ -98,6 +98,8 @@ class Parameters(BaseModel):
details: bool = False
# Get decoder input token logprobs and ids
decoder_input_details: bool = False
+ # Optional JSON schema string to constrain the generated text
+ json_schema: Optional[str] = Field(alias="schema")

@validator("adapter_id")
def valid_adapter_id(cls, v, values):
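
The model field is named json_schema (schema would shadow pydantic's BaseModel.schema method) but is aliased to schema on the wire, which is why the client now serializes with request.dict(by_alias=True). A small sketch of the difference, assuming Pydantic v1 (consistent with the validator import above):

```python
from typing import Optional
from pydantic import BaseModel, Field

class Parameters(BaseModel):
    json_schema: Optional[str] = Field(None, alias="schema")

p = Parameters(schema='{"type": "string"}')  # populated via the alias
print(p.dict())               # {'json_schema': '{"type": "string"}'}
print(p.dict(by_alias=True))  # {'schema': '{"type": "string"}'} (what the server expects)
```
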
clients/python/pyproject.toml (1 addition, 1 deletion)
@@ -3,7 +3,7 @@ name = "lorax-client"
packages = [
{include = "lorax"}
]
version = "0.3.0"
version = "0.3.1"
description = "LoRAX Python Client"
license = "Apache-2.0"
authors = ["Travis Addair <[email protected]>", "Olivier Dehaene <[email protected]>"]
docs/reference/openapi.json (6 additions)
@@ -829,6 +829,12 @@
"default": "false",
"example": true
},
"schema": {
"type": "string",
"default": "null",
"example": "{\"type\": \"string\", \"title\": \"response\"}",
"nullable": true
},
"adapter_id": {
"type": "string",
"nullable": true
proto/generate.proto (2 additions)
@@ -75,6 +75,8 @@ message NextTokenChooserParameters {
bool watermark = 8;
/// adapter to use with lora exchange
string adapter_id = 9;
+ /// JSON schema used for constrained decoding (Outlines)
+ optional string schema = 10;
}

message StoppingCriteriaParameters {
router/client/src/client.rs (1 addition)
@@ -126,6 +126,7 @@ impl Client {
repetition_penalty: 1.2,
watermark: true,
adapter_id: "".to_string(),
+ schema: None,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 2,
router/src/health.rs (1 addition)
@@ -45,6 +45,7 @@ impl Health {
repetition_penalty: 1.0,
watermark: false,
adapter_id: "".to_string(),
+ schema: None,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
router/src/lib.rs (10 additions)
@@ -249,6 +249,13 @@ pub(crate) struct GenerateParameters {
example = "null"
)]
pub seed: Option<u64>,
+ #[serde(default)]
+ #[schema(
+     nullable = true,
+     default = "null",
+     example = "{\"type\": \"string\", \"title\": \"response\"}"
+ )]
+ pub schema: Option<String>,
}

fn default_max_new_tokens() -> u32 {
@@ -277,6 +284,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false,
apply_chat_template: false,
seed: None,
+ schema: None,
}
}

@@ -582,6 +590,7 @@ impl From<CompletionRequest> for CompatGenerateRequest {
decoder_input_details: req.logprobs.is_some(),
apply_chat_template: false,
seed: None,
+ schema: None,
},
stream: req.stream.unwrap_or(false),
}
@@ -616,6 +625,7 @@ impl From<ChatCompletionRequest> for CompatGenerateRequest {
decoder_input_details: false,
apply_chat_template: true,
seed: None,
+ schema: None,
},
stream: req.stream.unwrap_or(false),
}
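
On the router side, the schema travels as an Option<String> through GenerateParameters into NextTokenChooserParameters. Over HTTP that corresponds to a plain string field in the request body; a sketch of a raw request, assuming a default local deployment and the standard /generate route:

```python
import json
import requests

payload = {
    "inputs": "Respond with a single word:",
    "parameters": {
        "max_new_tokens": 16,
        # The schema is sent as a JSON-encoded string, matching the
        # Option<String> field above and the OpenAPI example earlier.
        "schema": json.dumps({"type": "string", "title": "response"}),
    },
}
resp = requests.post("http://127.0.0.1:8080/generate", json=payload)
print(resp.json()["generated_text"])
```
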
router/src/validation.rs (2 additions)
@@ -147,6 +147,7 @@ impl Validation {
adapter_parameters,
decoder_input_details,
apply_chat_template,
+ schema,
..
} = request.parameters;

@@ -273,6 +274,7 @@ impl Validation {
seed,
watermark,
adapter_id,
+ schema,
};
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
rust-toolchain.toml (1 addition, 1 deletion)
@@ -1,3 +1,3 @@
[toolchain]
channel = "1.70.0"
channel = "1.74.0"
components = ["rustfmt", "clippy"]
server/Makefile (1 addition, 1 deletion)
@@ -20,7 +20,7 @@ install: gen-server
pip install pip --upgrade
pip install torch==2.2.0
pip install -r requirements.txt
pip install -e ".[bnb, accelerate, quantize, peft]"
pip install -e ".[bnb, accelerate, quantize, peft, outlines]"

run-dev:
# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve gpt2
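
The server-side processor that actually consumes the schema ships with the new outlines extra and is not among the files loaded above. Conceptually, Outlines compiles the JSON schema into a finite-state machine over the tokenizer's vocabulary, and a logits processor walks that FSM during decoding, masking every token that would step outside the schema. A rough sketch of the idea, not the PR's implementation; the fsm method names mirror the Outlines FSM interface of that era but should be treated as assumptions:

```python
import math
import torch

class JsonSchemaLogitsProcessor:
    """Mask logits so only tokens accepted by a schema-derived FSM survive."""

    def __init__(self, fsm):
        self.fsm = fsm  # FSM compiled from the JSON schema (e.g. by Outlines)
        self.state = 0  # FSM start state

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        if input_ids.numel() > 0:
            # Advance the FSM with the most recently sampled token.
            self.state = self.fsm.next_state(self.state, input_ids[-1].item())
        allowed = self.fsm.allowed_token_ids(self.state)
        # Drive every disallowed token's logit to -inf before sampling.
        mask = torch.full_like(scores, -math.inf)
        mask[allowed] = 0.0
        return scores + mask
```
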