Validator, logging and modelling improvements #127

Merged · 14 commits · Jan 29, 2024
Changes from all commits
3 changes: 3 additions & 0 deletions src/tanuki/__init__.py
@@ -294,10 +294,13 @@ def wrapper(*args, **kwargs) -> Union[Embedding, Any]:
# Configure the function modeler using incoming parameters
function_modeler.environment_id = environment_id
if ignore_finetuning:
logging.info(f"The flag for ignoring finetuning has been set True for {test_func.__name__}. No model distillation will be performed.")
function_modeler.execute_finetune_blacklist.append(func_hash)
if ignore_finetune_fetching:
logging.info(f"The flag for ignoring searching for finetuned models has been set True for {test_func.__name__}. No already finetuned models will be looked for.")
function_modeler.check_finetune_blacklist.append(func_hash)
if ignore_data_storage:
logging.info(f"The flag for ignoring data storage has been set True for {test_func.__name__}. No data will be read or saved and model distillation will not be performed.")
function_modeler.store_data_blacklist.append(func_hash)
task_type = function_description.type
if len(teacher_models) > 0:
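Reviewer note: these three log lines make the opt-out flags visible at startup instead of silently mutating the blacklists. A minimal usage sketch, assuming the public decorator is `tanuki.patch` (as in the project README) and using the flag names from this hunk:

```python
import tanuki

@tanuki.patch(ignore_finetuning=True, ignore_data_storage=True)
def sentiment(text: str) -> str:
    """Classify the sentiment of the given text as 'positive' or 'negative'."""

# With these flags, the wrapper appends the function's hash to
# execute_finetune_blacklist and store_data_blacklist: no distillation is
# performed and no datapoints are read or saved, and the new logging.info
# calls announce exactly that once at configuration time.
```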
2 changes: 1 addition & 1 deletion src/tanuki/constants.py
@@ -30,7 +30,7 @@

# default models
DEFAULT_TEACHER_MODEL_NAMES = ["gpt-4", "gpt-4-32k", ]
DEFAULT_DISTILLED_MODEL_NAME = "gpt-3.5-finetune"
DEFAULT_DISTILLED_MODEL_NAME = "gpt-3.5-turbo-1106"
DEFAULT_EMBEDDING_MODEL_NAME = "ada-002"

# provider names
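Reviewer note: the old default `gpt-3.5-finetune` is not a real OpenAI model id, so finetune jobs launched from the default base could not succeed; `gpt-3.5-turbo-1106` is an actual finetunable snapshot as of this PR. A hypothetical guard sketch, with an assumed, non-exhaustive allowlist (nothing like this ships in the repo):

```python
# Hypothetical sanity check; the allowlist below is an assumption.
FINETUNABLE_OPENAI_BASES = {"gpt-3.5-turbo-1106", "gpt-3.5-turbo-0613",
                            "babbage-002", "davinci-002"}

def assert_finetunable(model_name: str) -> None:
    # Guard against placeholder ids like "gpt-3.5-finetune" sneaking back in.
    if model_name not in FINETUNABLE_OPENAI_BASES:
        raise ValueError(f"{model_name} is not a finetunable OpenAI base model")

assert_finetunable("gpt-3.5-turbo-1106")
```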
36 changes: 23 additions & 13 deletions src/tanuki/function_modeler.py
@@ -4,12 +4,12 @@
import json
from typing import List, Tuple, Dict, Union

import openai
import logging

from tanuki.constants import EXAMPLE_ELEMENT_LIMIT, PATCHES, SYMBOLIC_ALIGNMENTS, POSITIVE_EMBEDDABLE_ALIGNMENTS, \
NEGATIVE_EMBEDDABLE_ALIGNMENTS, OPENAI_PROVIDER
from tanuki.models.function_type import FunctionType
from tanuki.language_models.llm_configs import DEFAULT_GENERATIVE_MODELS, DEFAULT_EMBEDDING_MODELS
from tanuki.language_models.llm_configs import DEFAULT_TEACHER_MODELS, DEFAULT_EMBEDDING_MODELS
from tanuki.language_models.llm_configs.abc_base_config import BaseModelConfig
from tanuki.language_models.llm_finetune_api_abc import LLM_Finetune_API
from tanuki.models.finetune_job import FinetuneJob
@@ -42,6 +42,7 @@ def __init__(self, data_worker: DatasetWorker,
self.store_data_blacklist = []
self.api_provider = api_provider
self.teacher_models_override = {}
self.startup_logging_checker = {}

def _get_dataset_info(self, dataset_type, func_hash, type="length"):
"""
@@ -65,7 +66,7 @@ def _configure_teacher_models(self,
if task_type == FunctionType.EMBEDDABLE:
preconfigured_models = DEFAULT_EMBEDDING_MODELS
elif task_type == FunctionType.SYMBOLIC:
preconfigured_models = DEFAULT_GENERATIVE_MODELS
preconfigured_models = DEFAULT_TEACHER_MODELS
for model in teacher_models:
if isinstance(model, str):
if model not in preconfigured_models:
@@ -318,7 +319,7 @@ def load_function_config(self, func_hash, function_description):

def _check_for_finetunes(self, function_description: FunctionDescription, finetune_provider : str) -> Tuple[bool, Dict]:
# hash the function_hash into 16 characters (to embed it into the name of OpenAI finetunes, for later retrieval)

logging.info(f"Checking for finetunes for {function_description.name} using {finetune_provider}")
finetune_hash = function_description.__hash__(purpose="finetune") + encode_int(self.environment_id)
# List 10 fine-tuning jobs
finetunes: List[FinetuneJob] = self.api_provider[finetune_provider].list_finetuned(limit=1000)
@@ -333,10 +334,12 @@ def _check_for_finetunes(self, function_description: FunctionDescription, finetu
config = self._construct_config_from_finetune(finetune_hash, finetune)
# save the config
self.data_worker.update_function_config(function_description.__hash__(), config)
logging.info(f"Found finetuned model for {function_description.name} [{config.distilled_model.model_name}]")
return True, config
except:
logging.info(f"Found finetuned model for {function_description.name} [{finetune.fine_tuned_model.model_name}] but could not load it")
return False, {}

logging.info(f"No finetuned model found for {function_description.name}")
return False, {}

def _construct_config_from_finetune(self, finetune_hash: str, finetune: FinetuneJob):
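Reviewer note: the lookup now logs all three outcomes (found and loaded, found but unloadable, not found), where it previously stayed silent. A standalone sketch of that pattern, with the provider client and config construction stubbed out:

```python
import logging
from typing import Callable, Dict, List, Tuple

def check_for_finetunes(name: str, finetune_hash: str, jobs: List[dict],
                        construct_config: Callable[[dict], Dict]) -> Tuple[bool, Dict]:
    # jobs stands in for api_provider[provider].list_finetuned(limit=1000).
    logging.info(f"Checking for finetunes for {name}")
    for job in jobs:
        model_id = job.get("fine_tuned_model") or ""
        if finetune_hash not in model_id:
            continue
        try:
            config = construct_config(job)  # stand-in for _construct_config_from_finetune
            logging.info(f"Found finetuned model for {name} [{model_id}]")
            return True, config
        except Exception:
            logging.info(f"Found finetuned model for {name} [{model_id}] but could not load it")
            return False, {}
    logging.info(f"No finetuned model found for {name}")
    return False, {}
```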
@@ -426,16 +429,16 @@ def check_for_finetuning(self, function_description, func_hash):
# check if already finetuning
if "job_id" in self.function_configs[func_hash].current_training_run:
# check for job status
self._check_finetuning_status(func_hash)
self._check_finetuning_status(func_hash, function_description)
else:
# check for finetuning condition
if self._check_finetuning_condition(func_hash):
if self._check_finetuning_condition(func_hash, function_description):
self._execute_finetuning(function_description, func_hash)
except Exception as e:
print(e)
print("Error checking for finetuning")

def _check_finetuning_condition(self, func_hash):
def _check_finetuning_condition(self, func_hash, function_description):
"""
Check if the finetuning condition is met
Currently finetuning condition is dependent on the number of symbolic datapoints since last finetuning
@@ -453,6 +456,11 @@ def _check_finetuning_condition(self, func_hash):
# if havent read in the patch dataset size, read it in
patch_dataset_size = self._get_dataset_info(PATCHES, func_hash, type="length")
self.dataset_sizes[PATCHES][func_hash] = patch_dataset_size
if func_hash not in self.startup_logging_checker:
logging.info(f"Function {function_description.name} [{align_dataset_size} aligns | {patch_dataset_size} runs] will be finetuned from"\
f" {self.function_configs[func_hash].teacher_models[0].model_name} using {self.function_configs[func_hash].distilled_model.provider} in "\
f"{training_threshold-(patch_dataset_size + align_dataset_size)} runs")
self.startup_logging_checker[func_hash] = True

return (patch_dataset_size + align_dataset_size) > training_threshold
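Reviewer note: the `startup_logging_checker` guard makes the "will be finetuned in N runs" banner fire once per function rather than on every call. The guard in isolation, as a condensed sketch:

```python
import logging

def log_finetune_eta_once(seen: dict, func_hash: str, name: str, aligns: int,
                          runs: int, threshold: int, teacher: str, provider: str) -> None:
    # Emit the banner only the first time this function's hash is checked,
    # mirroring the startup_logging_checker dict added in this hunk.
    if func_hash in seen:
        return
    remaining = threshold - (runs + aligns)
    logging.info(f"Function {name} [{aligns} aligns | {runs} runs] will be finetuned "
                 f"from {teacher} using {provider} in {remaining} runs")
    seen[func_hash] = True

seen = {}
log_finetune_eta_once(seen, "abc123", "sentiment", 5, 40, 100, "gpt-4", "openai")
log_finetune_eta_once(seen, "abc123", "sentiment", 5, 41, 100, "gpt-4", "openai")  # silent
```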

@@ -529,8 +537,10 @@ def _execute_finetuning(self, function_description, func_hash):
# Use the stream as a file
try:
finetune_provider = self.function_configs[func_hash].distilled_model.provider
logging.info(f"Starting finetuning for {function_description.name} using {finetune_provider}")
finetuning_response: FinetuneJob = self.api_provider[finetune_provider].finetune(file=temp_file, suffix=finetune_hash)
except Exception as e:
logging.info(f"Could not start finetuning for {function_description.name} using {finetune_provider}. Error: {e}")
return

self.function_configs[func_hash].current_training_run = {"job_id": finetuning_response.id,
@@ -544,7 +554,7 @@ def _execute_finetuning(self, function_description, func_hash):
print(e)
print("Could not update config file to register a finetuning run")

def _check_finetuning_status(self, func_hash):
def _check_finetuning_status(self, func_hash, function_description):
"""
Check the status of the current finetuning job
If the job is finished, update the config file to reflect the new model
@@ -560,18 +570,18 @@ def _check_finetuning_status(self, func_hash):
self.function_configs[func_hash].current_training_run["last_checked"] = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")
if response.status == "succeeded" or response.status == "failed":
self._update_finetune_config(response, func_hash)
self._update_finetune_config(response, func_hash, function_description)
else:
self._update_config_file(func_hash)

def _update_finetune_config(self, response: FinetuneJob, func_hash):
def _update_finetune_config(self, response: FinetuneJob, func_hash, function_description):
"""
Update the config file to reflect the new model and switch the current model to the finetuned model
"""
self.function_configs[func_hash].update_with_finetuned_response(response)
logging.info(f"Finetuning for {function_description.name} using {self.function_configs[func_hash].distilled_model.provider} finished with status: {response.status}")
try:
self._update_config_file(func_hash)
except Exception as e:
print(e)
print("Could not update config file after a successful finetuning run")
logging.info(f"Could not update the function configuration file with the finetuned model for {function_description.name}. Error: {e}")
pass
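Reviewer note: `function_description` is threaded through `_check_finetuning_status` and `_update_finetune_config` solely so the logs can name the function. A condensed sketch of the resulting polling flow, with the config writers passed in as callables:

```python
import logging
from typing import Callable

def on_finetune_poll(job_status: str, name: str, provider: str,
                     apply_finetune: Callable[[], None],
                     write_config: Callable[[], None]) -> None:
    # Terminal states swap the finetuned model into the function config;
    # any other state only refreshes the last_checked timestamp on disk.
    if job_status in ("succeeded", "failed"):
        apply_finetune()
        logging.info(f"Finetuning for {name} using {provider} finished with status: {job_status}")
        try:
            write_config()
        except Exception as e:
            logging.info(f"Could not update the function configuration file with the "
                         f"finetuned model for {name}. Error: {e}")
    else:
        write_config()
```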
14 changes: 7 additions & 7 deletions src/tanuki/language_models/embedding_model_manager.py
@@ -12,7 +12,7 @@ class EmbeddingModelManager(object):
def __init__(self, function_modeler, api_provider: APIManager):
self.function_modeler = function_modeler
self.api_provider = api_provider
self.current_generators = {}
self.initialized_functions = {}

def get_embedding_case(self, args, function_description: FunctionDescription, kwargs, examples=None):
# example_input = f"Examples:{examples}\n" if examples else ""
@@ -25,12 +25,12 @@ def get_embedding_case(self, args, function_description: FunctionDescription, kw


# loggings
if function_hash not in self.current_generators:
logging.info(f"Generating function embeddings with {model.model_name}")
self.current_generators[function_hash] = model.model_name
elif self.current_generators[function_hash] != model.model_name:
logging.info(f"Switching embeddings generation from {self.current_generators[function_hash]} to {model.model_name}")
self.current_generators[function_hash] = model.model_name
if function_hash not in self.initialized_functions:
logging.info(f"Generating function embeddings for {function_description.name} with {model.model_name}")
self.initialized_functions[function_hash] = model.model_name
elif self.initialized_functions[function_hash] != model.model_name:
logging.info(f"Switching embeddings generation for {function_description.name} from {self.initialized_functions[function_hash]} to {model.model_name}")
self.initialized_functions[function_hash] = model.model_name

return content, model

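Reviewer note: `current_generators` is renamed to `initialized_functions`, and both log lines now name the function as well as the model. The first-use/model-switch pattern in isolation:

```python
import logging

def note_embedding_model(initialized: dict, func_hash: str,
                         func_name: str, model_name: str) -> None:
    # Log once when a function first gets an embedding model, and again
    # whenever the model backing that function changes.
    previous = initialized.get(func_hash)
    if previous is None:
        logging.info(f"Generating function embeddings for {func_name} with {model_name}")
    elif previous != model_name:
        logging.info(f"Switching embeddings generation for {func_name} from {previous} to {model_name}")
    initialized[func_hash] = model_name
```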
41 changes: 28 additions & 13 deletions src/tanuki/language_models/language_model_manager.py
@@ -28,7 +28,7 @@ def __init__(self,
self.api_provider = api_provider
self.function_modeler = function_modeler
self.default_generation_length = generation_token_limit
self.current_generators = {}
self.initialized_functions = {}
self.token_counts = {}

def __call__(self,
@@ -83,17 +83,24 @@ def generate(self, args, kwargs, function_description, llm_parameters={}):
The main generation function, given the args, kwargs, function description and model type, generate a response and check if the datapoint can be saved to the finetune dataset
"""

func_hash = function_description.__hash__()
prompt, model, save_to_finetune, is_distilled_model = self.get_generation_case(args, kwargs,
function_description,
llm_parameters)
func_hash = function_description.__hash__()
llm_parameters,
func_hash)
# loggings
if func_hash not in self.current_generators:
logging.info(f"Generating function outputs with {model.model_name}")
self.current_generators[func_hash] = model.model_name
elif self.current_generators[func_hash] != model.model_name:
logging.info(f"Switching output generation from {self.current_generators[func_hash]} to {model.model_name}")
self.current_generators[func_hash] = model.model_name
current_function_setup = self.initialized_functions.get(func_hash, None) # getting the current function setup - model and align statements
if current_function_setup:
generator_model = current_function_setup["model"]
if is_distilled_model:
logging.info(f"Generating function outputs for {function_description.name} with a finetuned model: {model.model_name}.")
self.initialized_functions[func_hash]["model"] = model.model_name
elif generator_model == "":
logging.info(f"Found {len(current_function_setup['examples'])} align statements for {function_description.name}. Generating function outputs with {model.model_name}.")
self.initialized_functions[func_hash]["model"] = model.model_name
elif generator_model != model.model_name:
logging.info(f"Switching output generation from {generator_model} to {model.model_name} for function {function_description.name}.")
self.initialized_functions[func_hash]["model"] = model.model_name

choice = self._synthesise_answer(prompt, model, llm_parameters)
output = LanguageModelOutput(choice, save_to_finetune, is_distilled_model)
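Reviewer note: `generate` now keeps per-function state of the form `{"model": str, "examples": list}` (seeded in `get_generation_case` below) and distinguishes three cases: a finetuned model in use, a first call that found N align statements, and a model switch. The branching in isolation:

```python
import logging

def log_generation_model(state: dict, func_hash: str, name: str,
                         model_name: str, is_distilled: bool) -> None:
    setup = state.get(func_hash)
    if setup is None:  # get_generation_case has not initialised this function yet
        return
    if is_distilled:
        logging.info(f"Generating function outputs for {name} with a finetuned model: {model_name}.")
    elif setup["model"] == "":
        logging.info(f"Found {len(setup['examples'])} align statements for {name}. "
                     f"Generating function outputs with {model_name}.")
    elif setup["model"] != model_name:
        logging.info(f"Switching output generation from {setup['model']} to {model_name} for function {name}.")
    state[func_hash]["model"] = model_name
```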
@@ -114,7 +121,7 @@ def _synthesise_answer(self, prompt, model, llm_parameters):
return self.api_provider[model.provider].generate(model, system_message, prompt, **llm_parameters)


def get_generation_case(self, args, kwargs, function_description, llm_parameters):
def get_generation_case(self, args, kwargs, function_description, llm_parameters, func_hash):
"""
Get the generation case with the correct prompt and model
First get the current model, then if distilled model, do zero-shot prompt and return False as suitable_for_finetune
@@ -126,6 +133,9 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters
is_distilled_model = distilled_model.model_name != ""
suitable_for_distillation, input_prompt_token_count = self.suitable_for_finetuning_token_check(args, kwargs, f,
distilled_model)
if func_hash not in self.initialized_functions:
# initialise the initialized_functions dict
self.initialized_functions[func_hash] = {"model": "", "examples": []}
# no examples needed, using a finetuned model. Dont save to finetune dataset
if is_distilled_model and suitable_for_distillation:
prompt = self.construct_prompt(f, args, kwargs, [], distilled_model)
@@ -136,13 +146,18 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters
examples = [f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput: {align['output']}" for align in
aligns]

# update the examples in the initialized_functions dict
self.initialized_functions[func_hash]["examples"] = examples

examples_token_count = sum([approximate_token_count(example) for example in examples])
generation_tokens = llm_parameters.get("max_new_tokens", self.default_generation_length)
model = self.choose_model_from_tokens(teacher_models,
examples_token_count + input_prompt_token_count + generation_tokens,
len(examples))
if model:
prompt = self.construct_prompt(f, args, kwargs, examples, model)
examples_with_parsing_tokens = [f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput:{model.parsing_helper_tokens['start_token']}{align['output']}{model.parsing_helper_tokens['end_token']}" for align in
aligns]
prompt = self.construct_prompt(f, args, kwargs, examples_with_parsing_tokens, model)
return prompt, model, suitable_for_distillation, False
else:
raise ValueError(
@@ -179,14 +194,14 @@ def construct_prompt(self, f, args, kwargs, examples, model):
"""
if examples:
final_examples = "\n".join(
[f"{model.parsing_helper_tokens['start_token']}{align}{model.parsing_helper_tokens['end_token']}" for align in
[f"{align}" for align in
examples])
example_input = f"Examples:{final_examples}\n"
else:
example_input = ""

instruction_prompt = model.instructions
content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\n{model.parsing_helper_tokens['start_token']}Inputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:"
content = f"{instruction_prompt}\nFunction: {f}\n{example_input}---\nInputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:{model.parsing_helper_tokens['start_token']}"
return content

def repair_generate(self, args, kwargs, f, failed_outputs_list, aligns, models, llm_parameters):
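Reviewer note: the parsing helper tokens move from wrapping whole examples in `construct_prompt` to wrapping only the example outputs in `get_generation_case`, and the prompt now ends with the start token so the model begins its answer inside the delimited region. A before/after sketch of the prompt tail, with placeholder token values ("[START]"/"[END]" standing in for `model.parsing_helper_tokens`):

```python
START, END = "[START]", "[END]"  # placeholders for model.parsing_helper_tokens

def example_block(args, kwargs, output):
    # New behaviour: only the output is wrapped, teaching the model to emit
    # [START]...[END] around its answer rather than around whole examples.
    return f"Inputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:{START}{output}{END}"

def prompt_tail(args, kwargs):
    # New behaviour: the prompt ends with the start token, priming the model
    # to produce the delimited answer and stop at the end token.
    return f"Inputs:\nArgs: {args}\nKwargs: {kwargs}\nOutput:{START}"

print(example_block((1,), {}, "True"))
print(prompt_tail((2,), {}))
```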
6 changes: 5 additions & 1 deletion src/tanuki/language_models/llama_bedrock_api.py
@@ -55,5 +55,9 @@ def generate(self, model: BaseModelConfig, system_message: str, prompt: str, **k
if model.parsing_helper_tokens["end_token"]:
# remove the end token from the choice
choice = choice.split(model.parsing_helper_tokens["end_token"])[0]
# check if starting token is in choice
if model.parsing_helper_tokens["start_token"] in choice:
# remove the starting token from the choice
choice = choice.split(model.parsing_helper_tokens["start_token"])[-1]

return choice
return choice.strip()
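Reviewer note: response cleanup for Bedrock Llama now also removes the start token and trims whitespace, matching the prompt change above. The cleanup in isolation, again with placeholder token values:

```python
def clean_choice(choice: str, start_token: str = "[START]", end_token: str = "[END]") -> str:
    # Cut everything after the end token, then everything up to and including
    # the start token, then trim surrounding whitespace.
    if end_token and end_token in choice:
        choice = choice.split(end_token)[0]
    if start_token and start_token in choice:
        choice = choice.split(start_token)[-1]
    return choice.strip()

assert clean_choice("preamble [START] True [END] trailing") == "True"
```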