Skip to content

Commit

Permalink
Merge branch 'main' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Jan 22, 2025
2 parents a3320ed + 8717dd1 commit 43d1bb1
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 30 deletions.
15 changes: 11 additions & 4 deletions src/distilabel/models/llms/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

if TYPE_CHECKING:
from litellm import Choices
from litellm.types.utils import ModelResponse
from pydantic import BaseModel


class LiteLLM(AsyncLLM):
Expand Down Expand Up @@ -124,7 +126,7 @@ def load(self) -> None:
client=self._aclient,
framework="litellm",
)
self._aclient = result.get("client")
self._aclient = result.get("client").messages.create
if structured_output := result.get("structured_output"):
self.structured_output = structured_output

Expand Down Expand Up @@ -209,7 +211,7 @@ async def agenerate( # type: ignore # noqa: C901
client=self._aclient,
framework="litellm",
)
self._aclient = result.get("client")
self._aclient = result.get("client").messages.create

if structured_output is None and self.structured_output is not None:
structured_output = self.structured_output
Expand Down Expand Up @@ -244,8 +246,13 @@ async def agenerate( # type: ignore # noqa: C901
async def _call_aclient_until_n_choices() -> List["Choices"]:
choices = []
while len(choices) < num_generations:
completion = await self._aclient(**kwargs) # type: ignore
if not self.structured_output:
completion: Union["ModelResponse", "BaseModel"] = await self._aclient(
**kwargs
) # type: ignore
if self.structured_output:
# Prevent pydantic model from being cast to list during list extension
completion = [completion]
else:
completion = completion.choices
choices.extend(completion)
return choices
Expand Down
88 changes: 62 additions & 26 deletions src/distilabel/steps/tasks/structured_outputs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,34 @@ def json_schema_to_model(json_schema: Dict[str, Any]) -> Type[BaseModel]:
return create_model(model_name, **field_definitions)


def resolve_refs(schema_element: Any, defs: Optional[Dict[str, Any]] = None) -> Any:
"""Resolves JSON schema references.
Args:
schema_element: The JSON schema or part of it to resolve.
defs: The definitions of the JSON schema.
Returns:
The resolved JSON schema.
"""
if isinstance(schema_element, dict):
# If the schema contains a $ref, resolve it
if "$ref" in schema_element and len(schema_element) == 1:
ref_key = schema_element["$ref"].split("/")[-1]
resolved = defs.get(ref_key, {})
return resolve_refs(resolved, defs) # Resolve recursively
else:
# Recursively resolve properties of the dictionary
return {
key: resolve_refs(value, defs) for key, value in schema_element.items()
}
elif isinstance(schema_element, list):
# Resolve each item in the list
return [resolve_refs(item, defs) for item in schema_element]
# Return other types as-is
return schema_element


def json_schema_to_pydantic_field(
name: str,
json_schema: Dict[str, Any],
Expand All @@ -90,13 +118,7 @@ def json_schema_to_pydantic_field(
# NOTE(plaguss): This needs more testing, nested classes need extra work to be converted
# here if we pass a reference to another class it will crash, we have to find the original
# definition and insert it here
# This takes into account single items referred to other classes
if ref := json_schema.get("$ref"):
json_schema = defs.get(ref.split("/")[-1])

# This takes into account lists of items referred to other classes
if "items" in json_schema and (ref := json_schema["items"].get("$ref")):
json_schema["items"] = defs.get(ref.split("/")[-1])
json_schema = resolve_refs(json_schema, defs)

# Get the field type.
type_ = json_schema_to_pydantic_type(json_schema)
Expand All @@ -119,6 +141,29 @@ def json_schema_to_pydantic_field(
)


def handle_any_of(json_schema: Dict[str, Any]) -> Any:
"""Handle 'anyOf' in JSON schema."""
types = [json_schema_to_pydantic_type(schema) for schema in json_schema["anyOf"]]
return Union[tuple(types)]


def handle_array(json_schema: Dict[str, Any]) -> Any:
"""Handle 'array' type in JSON schema."""
items_schema = json_schema.get("items")
if items_schema:
item_type = json_schema_to_pydantic_type(items_schema)
return List[item_type]
return List


def handle_object(json_schema: Dict[str, Any]) -> Any:
"""Handle 'object' type in JSON schema."""
properties = json_schema.get("properties")
if properties:
return json_schema_to_model(json_schema)
return Dict


def json_schema_to_pydantic_type(json_schema: Dict[str, Any]) -> Any:
"""Converts a JSON schema type to a Pydantic type.
Expand All @@ -128,34 +173,25 @@ def json_schema_to_pydantic_type(json_schema: Dict[str, Any]) -> Any:
Returns:
A Pydantic type.
"""
if "anyOf" in json_schema:
return handle_any_of(json_schema)

type_ = json_schema.get("type")

if type_ == "string":
type_val = str
return str
elif type_ == "integer":
type_val = int
return int
elif type_ == "number":
type_val = float
return float
elif type_ == "boolean":
type_val = bool
return bool
elif type_ == "array":
items_schema = json_schema.get("items")
if items_schema:
item_type = json_schema_to_pydantic_type(items_schema)
type_val = List[item_type]
else:
type_val = List
return handle_array(json_schema)
elif type_ == "object":
# Handle nested models.
properties = json_schema.get("properties")
if properties:
nested_model = json_schema_to_model(json_schema)
type_val = nested_model
else:
type_val = Dict
return handle_object(json_schema)
elif type_ == "null":
type_val = Optional[Any] # Use Optional[Any] for nullable fields
return Optional[Any] # Use Optional[Any] for nullable fields
else:
raise ValueError(f"Unsupported JSON schema type: {type_}")

return type_val

0 comments on commit 43d1bb1

Please sign in to comment.