Skip to content

Commit

Permalink
Remove template in unload to fix error on offline batch generation
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss committed Sep 13, 2024
1 parent 3aab909 commit 18b6efb
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/distilabel/steps/tasks/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,12 @@ def check_column_in_template(column, template):

for column in self.columns:
check_column_in_template(column, self.template)
# NOTE: This should work, but when running on use_offline_batch_generation fails with a
# pickling error. This is a workaround to avoid the error.
# self._template = Template(self.template)

self._template = Template(self.template)

def unload(self) -> None:
super().unload()
self._template = None

@property
def inputs(self) -> "StepColumns":
Expand All @@ -251,8 +254,7 @@ def inputs(self) -> "StepColumns":
def _prepare_message_content(self, input: Dict[str, Any]) -> "ChatType":
"""Prepares the content for the template and returns the formatted messages."""
fields = {column: input[column] for column in self.columns}
# return [{"role": "user", "content": self._template.render(**fields)}]
return [{"role": "user", "content": Template(self.template).render(**fields)}]
return [{"role": "user", "content": self._template.render(**fields)}]

def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the instruction
Expand Down

0 comments on commit 18b6efb

Please sign in to comment.