feat: move model loader functionality to augmentation #119

Merged 7 commits on Jan 22, 2025
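
In plugin-interface terms, this PR replaces the custom model-loading hook with an augmentation hook: previously the plugin loaded the model itself, whereas now it receives an already-loaded model together with the training arguments and the modifiable args, and returns both. The sketch below is a rough before/after outline derived only from the diff that follows; the class names and stub bodies are illustrative, not the real ScatterMoE plugin.

# Rough before/after sketch of the plugin hooks touched by this PR, based only
# on the diff below. The classes here are illustrative stubs.
from typing import Tuple

from peft import LoraConfig
from transformers import AutoModelForCausalLM, TrainingArguments


class BeforeSketch:
    # old shape: the plugin is responsible for loading the model itself
    @property
    def requires_custom_loading(self):
        return True

    def model_loader(self, model_name: str, **kwargs):
        return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)


class AfterSketch:
    # new shape: the caller loads the model; the plugin only augments it
    @property
    def requires_augmentation(self):
        return True

    def augmentation(
        self,
        model,
        train_args: TrainingArguments,
        modifiable_args: Tuple[LoraConfig],
    ):
        # the real plugin shards the MoE here via prepare_scattermoe(...)
        return model, modifiable_args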
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 # Standard
-from typing import Dict
+from typing import Dict, Tuple
 
 # Third Party
 from fms_acceleration import AccelerationPlugin
-from transformers import AutoModelForCausalLM
+from peft import LoraConfig
+from transformers import TrainingArguments
 import torch
 
 # Local
@@ -52,21 +53,27 @@ def __init__(self, configurations: Dict[str, Dict]):
         )
 
     @property
-    def requires_custom_loading(self):
+    def requires_augmentation(self):
         return True
 
-    def model_loader(self, model_name: str, **kwargs):
-
-        # load the model
-        model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
-
+    def augmentation(
+        self,
+        model,
+        train_args: TrainingArguments,
+        modifiable_args: Tuple[LoraConfig],
+    ):
         rank, world_size = 0, 1
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             rank = torch.distributed.get_rank()
 
         # shard the MOE, and store the component names, eventually needed
         # to configure the FSDP
+        if not hasattr(model.config, "name_or_path") or not model.config.name_or_path:
+            raise ValueError(
+                "The model configuration is missing the 'name_or_path' attribute."
+            )
+
+        model_name = model.config.name_or_path

[Inline review comment from a Contributor]
I would say add a check for the presence of name_or_path in model.config, and if not there, raise a ValueError explaining that for scattermoe, we require a name_or_path to point to the model in the config. (A standalone sketch of why this guard is needed follows the diff.)

         self._moe_component_module_names = prepare_scattermoe(
             model,
             checkpoint_name_or_path=model_name,
@@ -75,13 +82,7 @@ def model_loader(self, model_name: str, **kwargs):
             ep_degree=self._ep_degree,
             mixed_precision=False,  # Currently this is hardcoded to OFF
         )
-
-        # NOTE: there is currently no good way to get the mixed precision
-        # flag from train_args. It will be better to handle this if
-        # when we move the sharding to augmentation.
-        # https://github.com/foundation-model-stack/fms-acceleration/issues/103
-
-        return model
+        return model, modifiable_args
 
     def get_callbacks_and_ready_for_train(
         self, model: torch.nn.Module = None, accelerator=None
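
Why the new guard matters (an aside, not part of the PR): as the review comment notes, ScatterMoE needs model.config.name_or_path to point to the model checkpoint, since it is passed to prepare_scattermoe as checkpoint_name_or_path. That attribute is populated when a model is loaded with from_pretrained, but not when a model is constructed directly from a config object, which is the failure mode the new ValueError surfaces early. The snippet below is a minimal standalone illustration of that distinction; it assumes only the transformers library, and the gpt2 checkpoint plus the printed messages are illustrative stand-ins.

# Standalone illustration (assumes only the transformers library; "gpt2" is an
# arbitrary example checkpoint, unrelated to this plugin).
from transformers import AutoModelForCausalLM, GPT2Config, GPT2LMHeadModel

# Loaded from a checkpoint: from_pretrained records the location on the config,
# so the new guard in augmentation() passes.
loaded = AutoModelForCausalLM.from_pretrained("gpt2")
print(repr(loaded.config.name_or_path))  # e.g. 'gpt2'

# Built directly from a config: there is no checkpoint path to record, which is
# exactly the case the new ValueError reports early.
built = GPT2LMHeadModel(GPT2Config())
if not getattr(built.config, "name_or_path", None):
    print("no name_or_path on model.config -> augmentation() would raise ValueError")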