Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove mlflow dependency #23

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ install:
just python -m pip install types-requests

format:
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --remove-all-unused-imports --quiet --in-place -r --exclude third_party --exclude ultravox/model/gazelle
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --remove-all-unused-imports --quiet --in-place -r --exclude third_party
. ./activate ${VENV_NAME} && isort ${PROJECT_DIR} --force-single-line-imports
. ./activate ${VENV_NAME} && black ${PROJECT_DIR}

check:
. ./activate ${VENV_NAME} && black ${PROJECT_DIR} --check
. ./activate ${VENV_NAME} && isort ${PROJECT_DIR} --check --force-single-line-imports
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --check --quiet --remove-all-unused-imports -r --exclude third_party --exclude ultravox/model/gazelle
. ./activate ${VENV_NAME} && autoflake ${PROJECT_DIR} --check --quiet --remove-all-unused-imports -r --exclude third_party
. ./activate ${VENV_NAME} && mypy ${PROJECT_DIR}

test *ARGS=".":
Expand Down
4 changes: 2 additions & 2 deletions mcloud.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Gazelle POC training configuration
# Ultravox POC training configuration

name: gazelle-poc
name: ultravox
image: mosaicml/composer:latest
compute:
gpus: 8
Expand Down
2 changes: 0 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
[mypy]
ignore_missing_imports = True

[mypy-ultravox/model/gazelle.*]
ignore_errors = True
3 changes: 0 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,3 @@ jiwer
# Monitoring
tensorboardx
wandb
neptune
mlflow

2 changes: 1 addition & 1 deletion ultravox/model/ultravox_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]:
sample.messages, tokenize=False
)

# Process audio and text using GazelleProcessor.
# Process audio and text using UltravoxProcessor.
# Audio is expanded to be a [C x M] array, although C=1 for mono audio.
audio = (
np.expand_dims(sample.audio, axis=0) if sample.audio is not None else None
Expand Down
2 changes: 1 addition & 1 deletion ultravox/training/configs/llama3_whisper.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SLM with gazelle & llama3
# SLM with ultravox & llama3
exp_name: "llama3_whisper_s"

# Make sure to accept the license agreement on huggingface hub
Expand Down
33 changes: 0 additions & 33 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import List, Optional

import datasets as hf_datasets
import mlflow
import safetensors.torch
import simple_parsing
import torch
Expand All @@ -20,7 +19,6 @@

from ultravox.data import datasets
from ultravox.inference import infer
from ultravox.inference import ultravox_infer
from ultravox.model import ultravox_config
from ultravox.model import ultravox_model
from ultravox.model import ultravox_processing
Expand All @@ -33,17 +31,6 @@
OUTPUT_EXAMPLE = {"text": "Hello, world!"}


class GazelleMlflowWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that serves predictions through UltravoxInference.

    ``load_context`` must run before ``predict``: it creates the
    ``self.inference`` object that ``predict`` delegates to.
    """

    def predict(self, context, model_input):
        """Run inference on a single text/audio input.

        Args:
            context: MLflow model context; not used by this method.
            model_input: Mapping providing a ``"text"`` entry (the prompt) and
                an ``"audio"`` entry (the audio buffer), both forwarded to
                ``datasets.VoiceSample.from_prompt_and_buf``.

        Returns:
            Whatever ``self.inference.infer`` returns for the built sample.
        """
        sample = datasets.VoiceSample.from_prompt_and_buf(
            model_input["text"], model_input["audio"]
        )
        return self.inference.infer(sample)

    def load_context(self, context):
        """Create the inference object from the ``"model_id"`` artifact.

        Args:
            context: MLflow model context whose ``artifacts["model_id"]``
                entry is passed to ``ultravox_infer.UltravoxInference``.
        """
        self.inference = ultravox_infer.UltravoxInference(context.artifacts["model_id"])


def fix_hyphens(arg: str) -> str:
    """Replace hyphens with underscores in the name of a ``--option`` argument.

    Only the option name is rewritten: the text directly after a leading
    ``--`` up to (but not including) the first ``=``. Anything after the
    ``=`` and any argument that does not start with ``--`` is returned
    unchanged.

    Args:
        arg: A single command-line token, e.g. ``"--exp-name=my-run"``.

    Returns:
        The token with hyphens in the option name replaced by underscores,
        e.g. ``"--exp_name=my-run"``.
    """
    return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)

Expand Down Expand Up @@ -143,12 +130,6 @@ def main() -> None:
dir="runs",
)

# Starting MLflow; we need to set the experiment name before training starts.
if "mlflow" in args.report_logs_to and is_master:
mlflow.set_tracking_uri("runs/mlruns")
db_exp_name = f"/Shared/{args.exp_name}"
mlflow.set_experiment(db_exp_name)

if args.model_load_dir:
logging.info(f"Loading model state dict from {args.model_load_dir}")
load_path = args.model_load_dir
Expand Down Expand Up @@ -274,20 +255,6 @@ def main() -> None:
)
trainer.train()
trainer.save_model(args.output_dir)
if "mlflow" in args.report_logs_to and is_master:
signature = mlflow.models.signature.infer_signature(
INPUT_EXAMPLE, OUTPUT_EXAMPLE
)
model_info = mlflow.pyfunc.log_model(
python_model=GazelleMlflowWrapper(),
artifact_path="model",
pip_requirements="requirements.txt",
registered_model_name="ultravox",
input_example=INPUT_EXAMPLE,
signature=signature,
)
logging.info(f"Model logged to MLflow: {model_info.model_uri}")

t_end = datetime.now()
logging.info(f"end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")
Expand Down
Loading