DSPy Only Predict #567

Merged: 26 commits, Mar 13, 2024

Changes from 5 commits

Commits (26)
ba42637
feat: make template extraction greedy and accomodate for incomplete g…
KCaverly Mar 5, 2024
9eb5074
feat: added recursive recovery for lm template calls
KCaverly Mar 5, 2024
b586c3e
fix: update lm kwargs during recovery
KCaverly Mar 6, 2024
f2411f6
feat: added history to backend
KCaverly Mar 6, 2024
f866cd9
feat: removed all dsp objects from Predict
KCaverly Mar 6, 2024
4fe6503
fix: remove lm history
KCaverly Mar 7, 2024
e442115
feat: add JSONBackend
KCaverly Mar 7, 2024
713c826
chore: added tests for JSONBackend
KCaverly Mar 7, 2024
e15ae4e
feat: add full lm history tracking, and standardized Completion, Comp…
KCaverly Mar 8, 2024
6d66ee0
fix: small changes to default_factory fields for Backend and BaseLM h…
KCaverly Mar 10, 2024
96b3eb7
fix: simplify DummyLanguageModel
KCaverly Mar 11, 2024
1bdfcf0
chore: replace Completion with Example to simplify generation
KCaverly Mar 11, 2024
05d6bac
feat: updated aggregation for new Completions
KCaverly Mar 11, 2024
20c8098
feat: updated series of predict modules to leverage the TemplateBackend
KCaverly Mar 11, 2024
c449df7
fix: passing all predict tests
KCaverly Mar 11, 2024
fc41d55
fix: update tests for evaluate
KCaverly Mar 11, 2024
484368b
fix: update tests for primitives/program
KCaverly Mar 11, 2024
8a84188
fix: passing all functional tests
KCaverly Mar 12, 2024
51db259
wip: progress on signature optimization tests
KCaverly Mar 12, 2024
e83ba6b
chore: update formatting for signature_opt
KCaverly Mar 12, 2024
c5a96f1
fix: update teleprompt for signature_opt for updated Completions api
KCaverly Mar 12, 2024
a062370
fix: autoformat on bootstrap
KCaverly Mar 13, 2024
79be6fa
fix: run autoformatting on signature_opt_bayesian
KCaverly Mar 13, 2024
560ec20
fix: passing all signature_opt_bayesian tests
KCaverly Mar 13, 2024
53ac9af
fix: passing all bootstrap tests
KCaverly Mar 13, 2024
477f90a
fix: remove artifacts from pring debugging
KCaverly Mar 13, 2024
1 change: 1 addition & 0 deletions dspy/backends/__init__.py
@@ -0,0 +1 @@
from .template import TemplateBackend
111 changes: 105 additions & 6 deletions dspy/backends/base.py
@@ -2,21 +2,120 @@
import typing as t

from pydantic import BaseModel
from dspy.primitives.prediction import Completions
from dspy.primitives.example import Example

from dspy.signatures.signature import Signature
from dspy.signatures.signature import Signature, ensure_signature


GeneratedOutput = t.TypeVar("GeneratedOutput", bound=dict)
class Completion:
def __init__(self, example: Example, complete: bool):
self.example = example
self.complete = complete

def __len__(self) -> int:
return len(self.example.keys())


def convert_to_completion(signature: Signature, example: Example) -> Completion:
complete = True
for field in signature.output_fields:
if field not in example:
complete = False

return Completion(example=example, complete=complete)


BackendEvent = t.TypeVar("BackendEvent", bound=dict[str, t.Any])


class BaseBackend(BaseModel, ABC):
"""A backend takes a signature, its params, and returns a list of structured predictions."""

@abstractmethod
_history: t.List[BackendEvent] = []

def __call__(
self,
signature: Signature,
recover: bool = False,
max_generations: int = 5,
**kwargs,
) -> Completions:
# Recursively attempt generation until at least one complete completion is available.
signature = ensure_signature(signature)

event = {
"signature": signature,
"recover": recover,
"max_generations": max_generations,
"kwargs": kwargs,
}
complete_examples = None
i = 0

while i < max_generations:
# Returns a List of Completions
# which may or may not be complete
output = self.generate(signature, **kwargs)

# Filter for only complete completions
complete_examples = [
completion.example for completion in output if completion.complete
]

# If 1 or more complete generations exist, simply return all
if len(complete_examples) > 0:
print(
"BREAKING DURING RECOVERY AS AT LEAST ONE COMPLETE EXAMPLE EXISTS"
)
break
# If not, recursively regenerate, seeding the furthest (most complete) generation as the example
elif recover:
max_example = output[0].example

for completion in output:
if len(max_example) < len(completion.example):
max_example = completion.example

# Currently this only updates the example fields;
# we will want to update kwargs as well for max_tokens,
# temperature, etc.
for field in signature.fields:
if field in max_example:
kwargs[field] = max_example.get(field)

# Update lm arguments
# Setting temperature to 0.0 leads to greedy decoding
kwargs["temperature"] = 0.0

# Cut the max_tokens in half each time if it has been set
if "max_tokens" in kwargs:
kwargs["max_tokens"] = int(kwargs["max_tokens"] / 2)

else:
break

i += 1

if complete_examples is None:
raise Exception(
"Generation failed, recursively attempts to complete did not succeed."
)

completions = Completions(complete_examples)
event["completions"] = completions
self._history.append(event)

return completions

@property
def history(self) -> t.List[BackendEvent]:
return self._history

@abstractmethod
def generate(
self,
signature: Signature,
**kwargs,
) -> list[GeneratedOutput]:
"""Generates `n` predictions for the signature output."""
pass
) -> t.List[Completion]:
"""Generates `n` predictions (complete/partial) for the signature output."""
1 change: 1 addition & 0 deletions dspy/backends/lm/__init__.py
@@ -0,0 +1 @@
from .litellm import *
10 changes: 4 additions & 6 deletions dspy/backends/lm/base.py
@@ -11,18 +11,16 @@
_cachedir = os.environ.get("DSP_CACHEDIR") or str(Path.home() / ".joblib_cache")
_cache_memory = Memory(_cachedir, verbose=0)

Completion = t.TypeVar("Completion", bound=dict)


MinimalLM = t.Callable[[str, float, int, int], list[Completion]]
GeneratedOutput = t.TypeVar("GeneratedOutput", bound=dict[str, t.Any])
LMEvent = t.TypeVar("LMEvent", bound=dict[str, t.Any])


class BaseLM(BaseModel, ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._generate_with_cache = _cache_memory.cache(self.generate)

def __call__(self, prompt: str, **kwargs) -> list[str]:
def __call__(self, prompt: str, **kwargs) -> list[GeneratedOutput]:
"""Generates `n` predictions for the signature output."""
generator = self._generate_with_cache if dspy.settings.cache else self.generate
return generator(prompt, **kwargs)
@@ -32,7 +30,7 @@ def generate(
self,
prompt: str,
**kwargs,
) -> list[Completion]:
) -> list[GeneratedOutput]:
"""Generates `n` predictions for the signature output."""
...

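As a companion sketch (an assumption, not code from this PR), a canned BaseLM subclass makes the implicit contract visible: each item returned by generate() must expose .message.content, mirroring litellm's Choice objects, since that is what TemplateBackend.generate extracts fields from; BaseLM.__call__ then routes through the joblib cache whenever dspy.settings.cache is enabled.

from types import SimpleNamespace

from dspy.backends.lm.base import BaseLM


class CannedLM(BaseLM):
    """Toy LM (hypothetical) that returns the same text for every prompt."""

    text: str = "Answer: blue"

    def generate(self, prompt: str, **kwargs) -> list:
        n = kwargs.get("n", 1)
        # Each item mimics a litellm Choice: .message.content holds the text,
        # which is what TemplateBackend.generate extracts fields from.
        return [
            SimpleNamespace(message=SimpleNamespace(content=self.text))
            for _ in range(n)
        ]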
7 changes: 2 additions & 5 deletions dspy/backends/lm/litellm.py
@@ -4,10 +4,7 @@
from pydantic import Field


from .base import BaseLM


Choice = t.TypeVar("Choice", bound=dict[str, t.Any])
from .base import BaseLM, GeneratedOutput


class LiteLM(BaseLM):
@@ -26,7 +23,7 @@ def generate(
self,
prompt: str,
**kwargs,
) -> list[Choice]:
) -> list[GeneratedOutput]:
"""Generates `n` predictions for the signature output."""
options = {**self.STANDARD_PARAMS, **self.default_params, **kwargs}
response = completion(
19 changes: 9 additions & 10 deletions dspy/backends/template.py
@@ -1,9 +1,10 @@
from dspy.primitives.prediction import Completions
from dspy.signatures.signature import Signature, signature_to_template
from dspy.primitives.example import Example
from dspy.primitives.template import Template

import typing as t
from .base import BaseBackend, GeneratedOutput
from .base import BaseBackend, Completion, convert_to_completion
from .lm.litellm import BaseLM


@@ -12,12 +13,16 @@ class TemplateBackend(BaseBackend):

lm: BaseLM

def __call__(
@property
def lm_history(self):
return self.lm.history

KCaverly marked this conversation as resolved.
def generate(
self,
signature: Signature,
demos: t.List[str] = [],
**kwargs,
) -> list[GeneratedOutput]:
) -> t.List[Completion]:
"""Wrap the signature and demos into an example, and pass through the Language Model, returning Signature compliant output"""

if not all(k in kwargs for k in signature.input_fields):
@@ -44,10 +49,4 @@ def __call__(
template.extract(example, prediction.message.content) for prediction in pred
]

outputs = []
for extract in extracted:
outputs.append({})
for key in signature.model_fields:
outputs[-1][key] = extract.get(key)

return outputs
return [convert_to_completion(signature, example) for example in extracted]
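To make the Completion semantics concrete, here is a small hedged illustration (the QA signature is hypothetical): convert_to_completion flags extractions that are missing output fields, and it is exactly these partial completions that trigger the recovery path in BaseBackend.__call__.

import dspy
from dspy.backends.base import convert_to_completion
from dspy.primitives.example import Example


class QA(dspy.Signature):
    """Answer the question."""

    question = dspy.InputField()
    answer = dspy.OutputField()


full = Example(question="What colour is the sky?", answer="Blue.")
partial = Example(question="What colour is the sky?")  # generation cut off early

print(convert_to_completion(QA, full).complete)     # True: returned as-is
print(convert_to_completion(QA, partial).complete)  # False: drives recovery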
77 changes: 24 additions & 53 deletions dspy/predict/predict.py
@@ -1,11 +1,12 @@
import dsp
import dspy
import random

from dspy.predict.parameter import Parameter
from dspy.primitives.prediction import Prediction

from dspy.signatures.signature import ensure_signature, signature_to_template


class Predict(Parameter):
def __init__(self, signature, **config):
self.stage = random.randbytes(8).hex()
@@ -14,20 +15,22 @@ def __init__(self, signature, **config):
self.reset()

def reset(self):
self.lm = None
self.backend = None
self.traces = []
self.train = []
self.demos = []

def dump_state(self):
state_keys = ["lm", "traces", "train", "demos"]
state_keys = ["backend", "traces", "train", "demos"]
state = {k: getattr(self, k) for k in state_keys}

# Cache the signature instructions and the last field's name.
state["signature_instructions"] = self.signature.instructions

*_, last_key = self.signature.fields.keys()
state["signature_prefix"] = self.signature.fields[last_key].json_schema_extra['prefix']
state["signature_prefix"] = self.signature.fields[last_key].json_schema_extra[
"prefix"
]

return state

@@ -50,66 +53,34 @@ def __call__(self, **kwargs):

def forward(self, **kwargs):
# Extract the three privileged keyword arguments.
new_signature = ensure_signature(kwargs.pop("new_signature", None))
signature = ensure_signature(kwargs.pop("signature", self.signature))
signature = kwargs.pop("new_signature", kwargs.pop("signature", self.signature))
demos = kwargs.pop("demos", self.demos)
config = dict(**self.config, **kwargs.pop("config", {}))

# Get the right LM to use.
lm = kwargs.pop("lm", self.lm) or dsp.settings.lm
assert lm is not None, "No LM is loaded."

# If temperature is 0.0 but its n > 1, set temperature to 0.7.
temperature = config.get("temperature", None)
temperature = lm.kwargs["temperature"] if temperature is None else temperature

num_generations = config.get("n", None)
if num_generations is None:
num_generations = lm.kwargs.get("n", lm.kwargs.get("num_generations", None))

if (temperature is None or temperature <= 0.15) and num_generations > 1:
config["temperature"] = 0.7
# print(f"#> Setting temperature to 0.7 since n={num_generations} and prior temperature={temperature}.")
config = dict(**self.config, **kwargs.pop("config", {}), **kwargs)

# All of the other kwargs are presumed to fit a prefix of the signature.
# Get the right Backend to use.
backend = kwargs.pop("backend", self.backend) or dspy.settings.get(
"backend", None
)
assert backend is not None, "No Backend is configured."

x = dsp.Example(demos=demos, **kwargs)

if new_signature is not None:
signature = new_signature
x = dspy.Example(demos=demos, **kwargs)

if not all(k in kwargs for k in signature.input_fields):
present = [k for k in signature.input_fields if k in kwargs]
missing = [k for k in signature.input_fields if k not in kwargs]
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")

# Switch to legacy format for dsp.generate
template = signature_to_template(signature)

if self.lm is None:
x, C = dsp.generate(template, **config)(x, stage=self.stage)
else:
# Note: query_only=True means the instructions and examples are not included.
# I'm not really sure why we'd want to do that, but it's there.
with dsp.settings.context(lm=self.lm, query_only=True):
x, C = dsp.generate(template, **config)(x, stage=self.stage)

assert self.stage in x, "The generated (input, output) example was not stored"
print(
f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}."
)

completions = []
completions = backend(signature, **config)

for c in C:
completions.append({})
for field in template.fields:
if field.output_variable not in kwargs.keys():
completions[-1][field.output_variable] = getattr(
c, field.output_variable
)
# TODO: What purpose does stage play here?
# assert self.stage in x, "The generated (input, output) example was not stored"

Comment on lines +75 to +76 (Collaborator): @okhat can you comment here?

pred = Prediction.from_completions(completions, signature=signature)
pred = Prediction.from_completions(completions)

if kwargs.pop("_trace", True) and dsp.settings.trace is not None:
trace = dsp.settings.trace
trace = dspy.settings.get("trace")
if trace is not None and kwargs.pop("_trace", True):
trace.append((self, {**kwargs}, pred))

return pred
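End to end, the new Predict flow looks roughly like the sketch below. The dspy.settings.configure(backend=...) call and the LiteLM constructor keyword are assumptions consistent with this diff rather than confirmed API; the point is that Predict now resolves a Backend instead of an LM and builds its Prediction from the backend's Completions.

import dspy
from dspy.backends import TemplateBackend
from dspy.backends.lm.litellm import LiteLM

# Assumed wiring: a TemplateBackend over a litellm-based LM registered in settings.
lm = LiteLM(model="gpt-3.5-turbo")  # constructor kwargs are an assumption
dspy.settings.configure(backend=TemplateBackend(lm=lm))

qa = dspy.Predict("question -> answer")
pred = qa(question="What colour is the sky?")
print(pred.answer)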