[WIP] Refactor to Introduce Backend Abstraction #2011

Closed
Commits (28)
50fcfc0
modify parallelization strategy
zhenglongjiepheonix Aug 14, 2024
4114d3b
only support model id in api now
zhenglongjiepheonix Aug 14, 2024
c689402
more comments
zhenglongjiepheonix Aug 15, 2024
252c3b7
more comments
zhenglongjiepheonix Aug 16, 2024
1be77ed
Merge remote-tracking branch 'upstream/main' into longjie/generalize_…
zhenglongjiepheonix Aug 16, 2024
22d6766
address comments
zhenglongjiepheonix Aug 20, 2024
febac9b
remove idle runner
zhenglongjiepheonix Aug 20, 2024
bf99175
fix
zhenglongjiepheonix Aug 20, 2024
4d9d036
format
zhenglongjiepheonix Aug 20, 2024
44a87f4
more comments
zhenglongjiepheonix Aug 26, 2024
513d516
generalize api & add backend abstraction
zhenglongjiepheonix Aug 26, 2024
8335a35
fix
zhenglongjiepheonix Aug 27, 2024
d051217
copyright
zhenglongjiepheonix Aug 27, 2024
6b03855
fix api
zhenglongjiepheonix Aug 28, 2024
6466ccc
Merge remote-tracking branch 'upstream/main' into longjie/add_backend…
zhenglongjiepheonix Aug 29, 2024
b4166ac
move weights intialization inside post process
zhenglongjiepheonix Aug 29, 2024
576104c
seperate meta update and parallel layer construction
zhenglongjiepheonix Aug 30, 2024
8bbc2e9
move weight intialization & binding inside backend
zhenglongjiepheonix Sep 2, 2024
d68df89
add weights tying for nanotron backend
zhenglongjiepheonix Sep 2, 2024
c752e29
fix
zhenglongjiepheonix Sep 3, 2024
82d1cf9
resolve
zhenglongjiepheonix Sep 3, 2024
3a1a195
fix
zhenglongjiepheonix Sep 3, 2024
b5b371f
fix conflict
zhenglongjiepheonix Sep 20, 2024
0ff39bb
address comments
zhenglongjiepheonix Sep 20, 2024
5137f68
address comments
zhenglongjiepheonix Sep 20, 2024
9dd77de
fix
zhenglongjiepheonix Sep 20, 2024
a375b6d
fix
zhenglongjiepheonix Sep 20, 2024
40880a3
fix
zhenglongjiepheonix Sep 20, 2024
78 changes: 48 additions & 30 deletions optimum/fx/parallelization/api.py
@@ -15,14 +15,14 @@
 import importlib
 import os
 from functools import partial
-from typing import Callable, List
+from typing import Callable, List, Optional, Type

 import torch
 import torch.nn as nn
 from torch.fx import GraphModule
-from transformers import AutoConfig
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel

 from .core import Config, ParallelExecutionCtx
-from .passes import build_parallel_pass_pipeline
 from .utils import (
     MetaAwareMethodsPatcher,
     download_model_from_hf,
@@ -34,32 +34,40 @@

 def parallelize_backend(
     graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config
-) -> GraphModule:
+) -> nn.Module:
     ctx.example_inputs = example_inputs
-    pass_pipeline = build_parallel_pass_pipeline()
+    pass_pipeline = ctx.backend.init_parallelization_pass_pipeline()
+    graph_module = ctx.backend.pre_process(graph_module=graph_module, ctx=ctx, config=config)
     graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config)
+    finalized_module = ctx.backend.post_process(graph_module=graph_module, ctx=ctx, config=config)
     ctx.compile_times += 1
-    ctx.last_optimized_graph_module = graph_module
-    return graph_module
+    ctx.last_optimized_module = finalized_module
+    return finalized_module


 def parallelize_model(
-    model: str,
     parallel_ctx: ParallelExecutionCtx,
     *model_args,
+    model_id_or_path: Optional[str] = None,
+    model_cls: Optional[Type[PreTrainedModel]] = None,
+    model_config: Optional[PretrainedConfig] = None,
     **kwargs,
 ) -> Callable:
     """
     API for automatic model parallelism through Pytorch FX.

     Args:
-        model (`str`):
-            Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights
-            of the model.
         parallel_ctx (`ParallelExecutionCtx`):
             Parallel execution context containing process groups the current process belongs to.
         *model_args (`Any`):
             Additional postional arguments for intializing the model if a model id is passed.
+        model_id_or_path (`Optional[str]`, defaults to `None`):
+            Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights
+            of the model.
+        model_cls (`Optional[Type[PreTrainedModel]]`, defaults to `None`):
+            Model class in transformers library, i.e, `LlamaForCausalLM`.
+        model_config (`Optional[PretrainedConfig]`, defaults to `None`):
+            Model config to intialize the model.
         revision (`str`, defaults to `main`):
             Model revision for weights downloading if a model id is passed.
         cache_dir (`Optional[str]`, defaults to `None`):
@@ -82,29 +82,39 @@ def parallelize_model(
         setattr(parallel_config, k, v)
         kwargs.pop(k)

-    is_local = os.path.isdir(model)
-    if not is_local:
-        hf_folder = download_model_from_hf(
-            model_name_or_path=model,
-            cache_dir=cache_dir,
-            revision=revision,
-            local_files_only=local_files_only,
-            skip_download_weights=skip_load_weights,
+    if model_id_or_path is not None and (model_cls is not None or model_config is not None):
+        raise ValueError(
+            "Can not accept passing in all of `model_id_or_path`, `model_cls` and `model_config`. Only specify "
+            "`model_id_or_path` or `model_cls` and `model_config` because there might be conflicts otherwise"
         )
-    else:
-        hf_folder = model
-
-    # should be able to load config using only local files
-    model_config, kwargs = AutoConfig.from_pretrained(
-        hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
-    )
+
+    # Init model instance
+    if model_id_or_path is not None:
+        is_local = os.path.isdir(model_id_or_path)
+        if not is_local:
+            hf_folder = download_model_from_hf(
+                model_name_or_path=model_id_or_path,
+                cache_dir=cache_dir,
+                revision=revision,
+                local_files_only=local_files_only,
+                skip_download_weights=skip_load_weights,
+            )
+        else:
+            hf_folder = model_id_or_path
+
+        # should be able to load config using only local files
+        model_config, kwargs = AutoConfig.from_pretrained(
+            hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
+        )

-    # try getting model class info from config
-    model_arch = model_config.architectures
-    model_cls = getattr(importlib.import_module("transformers"), model_arch[0])
+        # try getting model class info from config
+        model_arch = model_config.architectures
+        model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

-    if not skip_load_weights:
-        parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)
+        if not skip_load_weights:
+            parallel_ctx.weight_map = try_collect_weight_map(model_id_or_path, cache_dir, hf_folder)
+    elif model_cls is None or model_config is None:
+        raise ValueError("must provide `model_cls` and `model_config` in the case of not providing `model_id_or_path`")

     torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
     if torch_dtype is not None:
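
For reviewers who want to exercise the refactored entry point, here is a minimal usage sketch based only on the new signature and docstring above. The model id, the process-group setup, and the `ParallelExecutionCtx` constructor arguments are illustrative assumptions, not something this PR prescribes.

```python
# Hypothetical usage sketch of the refactored `parallelize_model` API; not part of this PR.
# Assumes a distributed environment launched with `torchrun` and NCCL available.
import torch
import torch.distributed as dist

from optimum.fx.parallelization.api import parallelize_model
from optimum.fx.parallelization.core import ParallelExecutionCtx

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# The constructor arguments shown here (tp_group, current_device) are assumptions for illustration.
ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=torch.device("cuda", rank))

# Option 1: pass a model id or local path; config and weights are resolved automatically.
model = parallelize_model(ctx, model_id_or_path="meta-llama/Llama-3.2-1B")  # model id is illustrative

# Option 2: pass an explicit class + config instead of a model id (both are required together).
# from transformers import LlamaConfig, LlamaForCausalLM
# model = parallelize_model(ctx, model_cls=LlamaForCausalLM, model_config=LlamaConfig())
```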
15 changes: 15 additions & 0 deletions optimum/fx/parallelization/backend/__init__.py
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .base import Backend, DefaultBackend