forked from ai-hero/llm-research-fine-tuning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlaunch.py
23 lines (16 loc) · 888 Bytes
/
launch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""run script for fine-tuning a model."""
import os
from aihero.research.config.schema import BatchInferenceJob, TrainingJob
from aihero.research.finetuning.infer import BatchInferenceJobRunner
from aihero.research.finetuning.train import TrainingJobRunner
from fire import Fire
def train(training_config_file: str = "/mnt/config/training/config.yaml") -> None:
    """Load a training job description and execute it.

    Args:
        training_config_file: Path to the training job config passed to
            ``TrainingJob.load``. The default points at a YAML file under
            ``/mnt/config``.

    The run is flagged as distributed when the ``WORLD_SIZE`` environment
    variable parses to an integer greater than 1. When the variable is unset,
    it is treated as 1 (single process).
    """
    job = TrainingJob.load(training_config_file)
    # WORLD_SIZE is presumably exported by a distributed launcher (e.g. torchrun);
    # that launcher is not visible here, so verify against the deployment setup.
    world_size = int(os.getenv("WORLD_SIZE", 1))
    runner = TrainingJobRunner(job, is_distributed=world_size > 1)
    runner.run()
def infer(
    batch_inference_config_file: str = "/mnt/config/batch_inference/config.yaml",
) -> None:
    """Load a batch-inference job description and execute it.

    Args:
        batch_inference_config_file: Path to the batch-inference config passed
            to ``BatchInferenceJob.load``. The default points at a YAML file
            under ``/mnt/config``.
    """
    job = BatchInferenceJob.load(batch_inference_config_file)
    runner = BatchInferenceJobRunner(job)
    runner.run()
if __name__ == "__main__":
    # Fire exposes each key of this mapping as a CLI subcommand, so the script
    # is invoked as `<script> train [...]` or `<script> infer [...]`.
    commands = {"train": train, "infer": infer}
    Fire(commands)