-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: draft dvc stages data pre-pro Signed-off-by: Valentin De Matos <[email protected]> * feat: draft DVC steps Signed-off-by: Valentin De Matos <[email protected]> * feat: dvc stages Signed-off-by: Valentin De Matos <[email protected]> * fix: get_random_data_samples now actually shuffle the dataset Signed-off-by: Valentin De Matos <[email protected]> * feat: store task id in row Signed-off-by: Valentin De Matos <[email protected]> * fix: inccease max lenght to 4k Signed-off-by: Valentin De Matos <[email protected]> * refactor: increase dataset size to 100k Signed-off-by: Valentin De Matos <[email protected]> * chore: remove deprecated code Signed-off-by: Valentin De Matos <[email protected]> * chore: remove deprecated code Signed-off-by: Valentin De Matos <[email protected]> * feat: draft of switching to parquet Signed-off-by: Valentin De Matos <[email protected]> * feat: draft of switching to parquet Signed-off-by: Valentin De Matos <[email protected]> * feat: switching to parquet Signed-off-by: Valentin De Matos <[email protected]> * fix: train and test set no longer are inverted Signed-off-by: Valentin De Matos <[email protected]> * refactor(generate_ready_to_use_dataset): keep only input_ids and labels Signed-off-by: Valentin De Matos <[email protected]> * fix(train): no longer load train set twice Signed-off-by: Valentin De Matos <[email protected]> * refactor(generate_ready_to_use_dataset): keep only input_ids for labels Signed-off-by: Valentin De Matos <[email protected]> * refactor(generate_ready_to_use_dataset): force clear cache Signed-off-by: Valentin De Matos <[email protected]> * feat: draft training stages Signed-off-by: Valentin De Matos <[email protected]> * fix(model): provides default params Signed-off-by: Valentin De Matos <[email protected]> * refactor(train): exepct path_to_dataset Signed-off-by: Valentin De Matos <[email protected]> * fix(dvc): stage generate-poc-dataset Signed-off-by: Valentin De Matos <[email protected]> * feat: draft 
training Signed-off-by: Valentin De Matos <[email protected]> * feat: draft training Signed-off-by: Valentin De Matos <[email protected]> * build: add depencies to req.txt Signed-off-by: Valentin De Matos <[email protected]> * feat: draft training v2 Signed-off-by: Valentin De Matos <[email protected]> * feat: draft training-v2 w/ causal Signed-off-by: Valentin De Matos <[email protected]> * doc: usage section in readme * feat: new module merge_lora_to_model Signed-off-by: Valentin De Matos <[email protected]> * chore: remove deprecated file Signed-off-by: Valentin De Matos <[email protected]> * refactor: change runs by date only Signed-off-by: Valentin De Matos <[email protected]> --------- Signed-off-by: Valentin De Matos <[email protected]>
- Loading branch information
Showing
31 changed files
with
775 additions
and
541 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,147 @@ | ||
general: | ||
project-name: &project_name StockLLM | ||
|
||
|
||
model: &model | ||
|
||
model_max_length: &model_max_length 2048 | ||
|
||
bitesandbytes_parameters: | ||
load_in_4bit: True | ||
bnb_4bit_use_double_quant: True | ||
bnb_4bit_quant_type: nf4 | ||
bnb_4bit_compute_dtype: bfloat16 | ||
|
||
|
||
training: | ||
use_bf16: &use_bf16 True | ||
optimizer: &optimizer paged_adamw_8bit | ||
|
||
stages: | ||
|
||
add-chessmoves-to-model-vocab: | ||
module: "model.add_chessmoves_to_vocabulary" | ||
params: | ||
model_parameters: *model | ||
path_to_output: &path-to-untrained-model-w-vocab "./outputs/untrained-model-w-vocab/" | ||
|
||
generate-causal-dataset: | ||
module: "data_processing.generate_causal_dataset" | ||
params: | ||
number_of_training_samples: 47500 | ||
number_of_testing_samples: 2500 | ||
path_to_output_dataset: &path-to-causal-dataset "./outputs/causal-dataset/" | ||
|
||
train-model-on-new-vocab: | ||
module: "training.train" | ||
params: | ||
path_to_model: *path-to-untrained-model-w-vocab | ||
path_to_dataset: *path-to-causal-dataset | ||
path_to_outputs: "./outputs/causal/" | ||
project_name: *project_name | ||
subproject_name: Causal | ||
model_parameters: | ||
<<: *model | ||
training_parameters: | ||
formatting_func: causal_formatting_prompts_func | ||
max_steps: 4000 | ||
warmup_steps: 3 | ||
eval_steps: 250 | ||
save_steps: 250 | ||
logging_steps: 25 | ||
per_device_train_batch_size: 16 | ||
per_device_eval_batch_size: 16 | ||
learning_rate: 2.0E-5 | ||
|
||
merge-causal-adapter-to-model: | ||
module: "model.merge_lora_to_model" | ||
params: | ||
path_to_model: *path-to-untrained-model-w-vocab | ||
path_to_adapter: "./outputs/causal/checkpoint-2000/" | ||
path_to_output: &path-to-causal-model "./outputs/causal-merged/" | ||
|
||
create-task-find-last-move: | ||
module: "tasks.create_task_find_last_move" | ||
params: | ||
number_of_samples: 15000 | ||
path_to_output_dataset: "outputs/tasks/findLastMove.parquet" | ||
path_to_output_dataset: "./outputs/tasks/findLastMove.parquet" | ||
|
||
create-task-find-score: | ||
module: "tasks.create_task_find_score" | ||
params: | ||
number_of_samples: 20000 | ||
path_to_output_dataset: "outputs/tasks/findScore.parquet" | ||
path_to_output_dataset: "./outputs/tasks/findScore.parquet" | ||
|
||
create-task-MLM-on-moves: | ||
module: "tasks.create_task_MLM_on_moves" | ||
params: | ||
number_of_samples: 15000 | ||
path_to_output_dataset: "outputs/tasks/MLM.parquet" | ||
|
||
# Considered deprecated: I'm not convinced it's helpful for the | ||
# model, while it is the most time-consuming task to generate | ||
# create-task-find-best-n-worst-positions: | ||
# module: "tasks.create_task_find_best_and_worst_positions" | ||
# params: | ||
# number_of_samples: 20000 | ||
# path_to_output_dataset: "outputs/tasks/bestAndWorstPositions.parquet" | ||
path_to_output_dataset: "./outputs/tasks/MLM.parquet" | ||
|
||
create-task-find-who-is-winning: | ||
module: "tasks.create_task_find_who_is_winning" | ||
params: | ||
number_of_samples: 20000 | ||
path_to_output_dataset: "outputs/tasks/whoIsWinning.parquet" | ||
path_to_output_dataset: "./outputs/tasks/whoIsWinning.parquet" | ||
|
||
create-task-sort-FENs: | ||
module: "tasks.create_task_sort_FENs" | ||
params: | ||
number_of_samples: 10000 | ||
path_to_output_dataset: "outputs/tasks/sortFENs.parquet" | ||
path_to_output_dataset: "./outputs/tasks/sortFENs.parquet" | ||
|
||
create-task-find-next-best-move: | ||
module: "tasks.create_task_find_next_best_move" | ||
params: | ||
number_of_samples: 20000 | ||
path_to_output_dataset: "outputs/tasks/bestMove.parquet" | ||
path_to_output_dataset: "./outputs/tasks/bestMove.parquet" | ||
|
||
merge-tasks-into-dataset: | ||
module: "data_processing.merge_tasks_into_dataset" | ||
params: | ||
paths: | ||
- "outputs/tasks/findLastMove.parquet" | ||
- "outputs/tasks/findScore.parquet" | ||
- "outputs/tasks/MLM.parquet" | ||
# - "outputs/tasks/bestAndWorstPositions.parquet" | ||
- "outputs/tasks/whoIsWinning.parquet" | ||
- "outputs/tasks/sortFENs.parquet" | ||
- "outputs/tasks/bestMove.parquet" | ||
- "./outputs/tasks/findLastMove.parquet" | ||
- "./outputs/tasks/findScore.parquet" | ||
- "./outputs/tasks/MLM.parquet" | ||
- "./outputs/tasks/whoIsWinning.parquet" | ||
- "./outputs/tasks/sortFENs.parquet" | ||
- "./outputs/tasks/bestMove.parquet" | ||
test_size: 0.01 | ||
path_to_train_set: "outputs/raw/train.csv" | ||
path_to_test_set: "outputs/raw/test.csv" | ||
path_to_train_set: "./outputs/raw/train.csv" | ||
path_to_test_set: "./outputs/raw/test.csv" | ||
|
||
generate-ready-to-use-dataset: | ||
generate-instruct-dataset: | ||
module: "data_processing.generate_ready_to_use_dataset" | ||
params: | ||
path_to_test_set: "outputs/raw/test.csv" | ||
path_to_train_set: "outputs/raw/train.csv" | ||
path_to_output_dataset: "outputs/dataset/" | ||
path_to_test_set: "./outputs/raw/test.csv" | ||
path_to_train_set: "./outputs/raw/train.csv" | ||
path_to_output_dataset: &path-to-instruct-dataset "./outputs/instruct-dataset/" | ||
model_max_length: *model_max_length | ||
|
||
train-instruct-model: | ||
module: "training.train" | ||
params: | ||
path_to_outputs: "./outputs/instruct/" | ||
path_to_model: *path-to-causal-model | ||
path_to_dataset: *path-to-instruct-dataset | ||
project_name: *project_name | ||
subproject_name: Instruct | ||
model_parameters: | ||
<<: *model | ||
training_parameters: | ||
formatting_func: instruct_formatting_prompts_func | ||
max_steps: 2000 | ||
warmup_steps: 3 | ||
eval_steps: 100 | ||
save_steps: 100 | ||
logging_steps: 25 | ||
per_device_train_batch_size: 16 | ||
per_device_eval_batch_size: 16 | ||
learning_rate: 2.5E-5 | ||
|
||
merge-instruct-adapter-to-model: | ||
module: "model.merge_lora_to_model" | ||
params: | ||
path_to_model: *path-to-causal-model | ||
path_to_adapter: "./outputs/instruct/checkpoint-2000/" | ||
path_to_output: &path-to-instruct-model "./outputs/instruct-merged/" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import io | ||
import json | ||
import chess.pgn | ||
|
||
from typing import List, Dict | ||
from .generate_PGNs import generate_pgn | ||
|
||
|
||
def causal_formatting_prompts_func(example: Dict[str, list]) -> List[str]:
    """Render every row of a batched example as PGN text.

    Args:
        example: Column-oriented batch. The ``"Moves"``, ``"Result"`` and
            ``"Termination"`` columns are read and indexed row by row.

    Returns:
        One string per entry of ``example["Moves"]``: the string form of the
        parsed game after its ``Result`` and ``Termination`` headers have
        been overwritten with the row's values.
    """

    def _create_sample(moves, result, termination):
        # Round-trip through python-chess so the emitted text is the
        # library's own rendering of the game, then stamp the headers.
        # NOTE(review): ``read_game`` returns None when it finds no game in
        # the stream; that would surface here as an AttributeError. Assumes
        # ``generate_pgn`` always yields a parsable PGN - confirm.
        game = chess.pgn.read_game(io.StringIO(generate_pgn(moves)))
        game.headers["Result"] = result
        game.headers["Termination"] = termination

        return str(game)

    # "Moves" drives the row count; the other columns are indexed by position
    # so a shorter column still raises instead of silently truncating.
    return [
        _create_sample(
            moves,
            example["Result"][idx],
            example["Termination"][idx],
        )
        for idx, moves in enumerate(example["Moves"])
    ]
|
||
|
||
def instruct_formatting_prompts_func(example: Dict[str, List[str]]) -> List[str]:
    """Build one instruction prompt per row of a batched example.

    Args:
        example: Column-oriented batch with the keys ``"task"``, ``"input"``
            and ``"expected_output"``. Each ``"input"`` and
            ``"expected_output"`` entry is a JSON-encoded object.

    Returns:
        One prompt per entry of ``example["task"]``, laid out as
        ``<s>[INST]task[/INST]\\n[IN]...[/IN]\\n[OUT]...[/OUT]</s>``.
    """

    def _format_dict_to_string(data: str) -> str:
        # Decode the JSON object and render it as "key: value" lines, each
        # (including the first) preceded by a newline.
        fields = json.loads(data)

        return "\n" + "\n".join(f"{k}: {v}" for k, v in fields.items())

    output_texts = []

    # "task" drives the row count; the other columns are indexed by position
    # so a shorter column still raises instead of silently truncating.
    for idx, task in enumerate(example["task"]):
        inputs = _format_dict_to_string(example["input"][idx])
        expected_output = _format_dict_to_string(example["expected_output"][idx])

        output_texts.append(
            f"<s>[INST]{task}[/INST]\n[IN]{inputs}[/IN]\n[OUT]{expected_output}[/OUT]</s>"
        )

    return output_texts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.