-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
36 lines (28 loc) · 1.02 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from steps.model_promoter import model_promoter
from steps.model_trainer import model_trainer
from steps.data_loader import data_loader
from steps.data_preprocessor import data_preprocessor
from steps.data_splitter import data_splitter
from steps.model_evaluator import model_evaluator
from typing import Optional
from uuid import UUID
from zenml import pipeline
from zenml.client import Client
from zenml.logger import get_logger
logger = get_logger(__name__)


@pipeline
def training(
    model_type: Optional[str] = "sgd",
    target_column: str = "log_price",
    random_state: int = 42,
):
    """Compose the training DAG: load, preprocess, split, train, evaluate, promote.

    Args:
        model_type: Model identifier. It is passed to ``model_trainer`` wrapped
            in a one-element list and to ``model_evaluator`` as-is.
        target_column: Name of the target column handed to ``model_trainer``.
            Defaults to ``"log_price"``, the value previously hard-coded here.
        random_state: Seed handed to ``data_loader``. Defaults to 42, the
            value previously hard-coded here.
    """
    dataframe = data_loader(random_state=random_state)
    encoded_df = data_preprocessor(dataframe)
    dataset_trn, dataset_tst = data_splitter(encoded_df)
    trained_models = model_trainer([model_type], dataset_trn, target_column)
    score = model_evaluator(trained_models, dataset_trn, dataset_tst, model_type)
    # This function body runs when ZenML composes the DAG, not when the steps
    # execute. So `score` and the promoter's output are artifact references,
    # not concrete values.
    # The previous `if final_ans:` branch therefore always took the "promoted"
    # path and logged the reference object instead of a score. Any reporting
    # of the promotion outcome has to happen inside the `model_promoter` step,
    # or be read from the finished run's artifacts.
    model_promoter(score)
if __name__ == "__main__":
    # Script entry point: launch the training pipeline with its default parameters.
    training()