Skip to content

Commit

Permalink
Add recipes/A5000_24GB_x8/translator-en-ja-oasst1.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
yuiseki committed Mar 21, 2024
1 parent 27f0458 commit f651b0b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 58 deletions.
28 changes: 28 additions & 0 deletions recipes/A5000_24GB_x8/translator-en-ja-oasst1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Recipe: LoRA fine-tune of TinyLlama as an English-to-Japanese translator.
# NOTE(review): target_task names the hate-speech-detection task, while every
# other field in this recipe describes translation - confirm the path is intended.
target_task: tasks/nlp/hate-speech-detection.md
base_model_id: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_name: tinyllama-translator-en-ja-oasst1-v1
output_base_dir: output
# Dataset and the fields used to build each training prompt.
dataset_id: kunishou/oasst1-89k-ja
dataset_input_hint: Given the text, translate to Japanese.
dataset_input_field_name: text
dataset_output_field_name: text_ja
# Row filter: src/train.py keeps only rows whose filter field equals the filter value.
dataset_filter_field_name: lang
dataset_filter_field_value: en
# NOTE(review): the visible part of src/train.py calls
# train_test_split(seed=42, test_size=0.2) with literals - confirm whether
# these two keys are actually read.
dataset_train_split_seed: 42
dataset_train_split_test_size: 0.2
# LoRA adapter hyperparameters.
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# Training hyperparameters.
# NOTE(review): both an epoch count and a step cap are set; check which one the
# trainer honours when both are given.
train_per_device_train_batch_size: 8
train_gradient_accumulation_steps: 4
train_num_train_epochs: 4
train_max_steps: 400
train_fp16: True
inference_max_new_tokens: 16
# Prompt / expected-output pairs; presumably used as a post-training sanity
# check - confirm against the evaluation code.
evaluations:
-
prompt: "thank you"
expected_output: "ありがとう"
-
prompt: "Hello"
expected_output: "こんにちは"
124 changes: 66 additions & 58 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,71 +22,78 @@ def load_yaml(file_path):
filepath = sys.argv[1]
train_config = load_yaml(filepath)

#
# Template
#
def simple_template_for_train(input, output) -> str:
    """Render one user/assistant exchange as a ChatML training prompt.

    Args:
        input: Text placed in the user turn.
        output: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # The prompt is assembled from separate parts instead of stripping every
    # line of an indented f-string after interpolation: that post-processing
    # also removed leading whitespace inside ``input``/``output`` (for example
    # the indentation of code samples) and rewrote their line separators.
    parts = [
        "<|im_start|>user",
        f"{input}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{output}",
        "<|im_end|>",
    ]
    return "\n".join(parts)

def hint_template_for_train(hint, question, answer):
    """Render a ChatML training prompt whose user turn is a hint plus a question.

    Args:
        hint: Instruction line placed first in the user turn.
        question: Text placed after the hint in the user turn.
        answer: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # The prompt is assembled from separate parts instead of stripping every
    # line of an indented f-string after interpolation: that post-processing
    # also removed leading whitespace inside the interpolated values (for
    # example the indentation of code samples) and rewrote their line
    # separators.
    parts = [
        "<|im_start|>user",
        f"{hint}",
        f"{question}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{answer}",
        "<|im_end|>",
    ]
    return "\n".join(parts)

def context_template_for_train(context, question, answer):
    """Render a ChatML training prompt whose user turn is a question plus context.

    Args:
        context: Supporting text placed after the question in the user turn.
        question: Text placed first in the user turn.
        answer: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # The prompt is assembled from separate parts instead of stripping every
    # line of an indented f-string after interpolation: that post-processing
    # also removed leading whitespace inside the interpolated values (for
    # example the indentation of code samples) and rewrote their line
    # separators.
    parts = [
        "<|im_start|>user",
        f"{question}",  # the question comes before the context in this layout
        f"{context}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{answer}",
        "<|im_end|>",
    ]
    return "\n".join(parts)

def context_hint_template_for_train(hint, context, question, answer):
    """Render a ChatML training prompt with a hint, labelled context and question.

    Args:
        hint: Instruction line placed first in the user turn.
        context: Text placed under the ``context:`` label.
        question: Text placed under the ``question:`` label.
        answer: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # The prompt is assembled from separate parts instead of stripping every
    # line of an indented f-string after interpolation: that post-processing
    # also removed leading whitespace inside the interpolated values (for
    # example the indentation of code samples) and rewrote their line
    # separators.
    parts = [
        "<|im_start|>user",
        f"{hint}",
        "context:",
        f"{context}",
        "question:",
        f"{question}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{answer}",
        "<|im_end|>",
    ]
    return "\n".join(parts)

#
# Prepare train data
#
def prepare_train_data(dataset_id):
def simple_template_for_train(input, output) -> str:
    """Render one user/assistant exchange as a ChatML training prompt.

    Args:
        input: Text placed in the user turn.
        output: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # Joining explicit segments keeps the interpolated text verbatim. Stripping
    # each line of the rendered f-string also stripped leading whitespace that
    # belonged to ``input``/``output`` (such as code indentation) and rewrote
    # their line separators.
    segments = (
        "<|im_start|>user",
        f"{input}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{output}",
        "<|im_end|>",
    )
    return "\n".join(segments)

def hint_template_for_train(hint, question, answer):
    """Render a ChatML training prompt whose user turn is a hint plus a question.

    Args:
        hint: Instruction line placed first in the user turn.
        question: Text placed after the hint in the user turn.
        answer: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # Joining explicit segments keeps the interpolated text verbatim. Stripping
    # each line of the rendered f-string also stripped leading whitespace that
    # belonged to the interpolated values (such as code indentation) and
    # rewrote their line separators.
    segments = (
        "<|im_start|>user",
        f"{hint}",
        f"{question}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{answer}",
        "<|im_end|>",
    )
    return "\n".join(segments)

def context_template_for_train(context, question, answer):
    """Render a ChatML training prompt whose user turn is a question plus context.

    Args:
        context: Supporting text placed after the question in the user turn.
        question: Text placed first in the user turn.
        answer: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # Joining explicit segments keeps the interpolated text verbatim. Stripping
    # each line of the rendered f-string also stripped leading whitespace that
    # belonged to the interpolated values (such as code indentation) and
    # rewrote their line separators.
    segments = (
        "<|im_start|>user",
        f"{question}",  # the question comes before the context in this layout
        f"{context}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{answer}",
        "<|im_end|>",
    )
    return "\n".join(segments)

def context_hint_template_for_train(hint, context, question, answer):
    """Render a ChatML training prompt with a hint, labelled context and question.

    Args:
        hint: Instruction line placed first in the user turn.
        context: Text placed under the ``context:`` label.
        question: Text placed under the ``question:`` label.
        answer: Text placed in the assistant turn.

    Returns:
        The prompt as a single string: the segments are joined with newline
        characters and there is no trailing newline.
    """
    # Joining explicit segments keeps the interpolated text verbatim. Stripping
    # each line of the rendered f-string also stripped leading whitespace that
    # belonged to the interpolated values (such as code indentation) and
    # rewrote their line separators.
    segments = (
        "<|im_start|>user",
        f"{hint}",
        "context:",
        f"{context}",
        "question:",
        f"{question}",
        "<|im_end|>",
        "<|im_start|>assistant",
        f"{answer}",
        "<|im_end|>",
    )
    return "\n".join(segments)

data = load_dataset(dataset_id, split="train")

data_df = data.to_pandas()

if "dataset_filter_field_name" in train_config:
data_df = data_df[data_df[train_config['dataset_filter_field_name']] == train_config['dataset_filter_field_value']]

input_field_name = train_config['dataset_input_field_name']
output_field_name = train_config['dataset_output_field_name']
if "dataset_context_field_name" in train_config:
Expand All @@ -104,6 +111,7 @@ def context_hint_template_for_train(hint, context, question, answer):

data = Dataset.from_pandas(data_df)
data = data.train_test_split(seed=42, test_size=0.2)
print(len(data["train"]))
return data

dataset_id = train_config['dataset_id']
Expand Down

0 comments on commit f651b0b

Please sign in to comment.