Mingyuanm/add back fp8 support to sd #9070

Merged: 57 commits (May 1, 2024)
Commits
e91a66d
update branch
ericharper Jan 29, 2024
305ad9c
Add dist ckpt support for regular optimizers (#7749)
mikolajblaz Jan 31, 2024
40da002
Pin lhotse=1.19.2 in r1.23.0 (#8303)
pzelasko Feb 1, 2024
d3bad4b
Cache Aware Streaming tutorial notebook (#8296)
erastorgueva-nv Feb 1, 2024
17f09e4
fix path location and branch (#8304)
nithinraok Feb 2, 2024
991dad9
add deallocate pipeline output optimization (#8279)
JimmyZhang12 Feb 2, 2024
e9320ed
Fix memory leak caused by context parallelism hanging references by o…
JimmyZhang12 Feb 2, 2024
8b18cfc
remove assertion (#8302)
dimapihtar Feb 2, 2024
d9f1409
Update PEFT Doc (#8262)
cuichenx Feb 3, 2024
a592517
Attention encoder-decoder models for multiple speech-to-text tasks …
titu1994 Feb 3, 2024
c3c766e
Multimodal r1.23.0 bug fix (#8315)
yaoyu-33 Feb 6, 2024
1434979
Fixes for MoE parameter passing & use of AutoTokenizer/Model for mist…
akoumpa Feb 6, 2024
ec8f413
Keep max_seqlen and cu_seqlens_argmin for later micro-batches when PP…
erhoo82 Feb 6, 2024
50864db
Remove asr webapp (#8347)
titu1994 Feb 6, 2024
498e9e4
remove _target_ at model level in aed config (#8351)
krishnacpuvvada Feb 6, 2024
2f72846
Add change_vocabulary and save_tokenizers() support to Multitask ASR …
titu1994 Feb 7, 2024
931c53c
Change default (#8371)
titu1994 Feb 8, 2024
0e13348
bug fix in fast-conformer-aed.yaml and adding jenkins test for speech…
krishnacpuvvada Feb 9, 2024
138a7ab
Enable megatron core loggers for GPT pretraining (#8354)
ashbhandare Feb 9, 2024
4ee9c58
mcore ds fix (#8283)
dimapihtar Feb 9, 2024
02ec761
Add Finetuning tutorial with HF Datasets (#8356)
nithinraok Feb 9, 2024
88d7b21
release updates (#8378)
dimapihtar Feb 9, 2024
400c4a1
MCore dataset compatibility for tokenizers (#8390)
vysarge Feb 11, 2024
3112091
Mcore customization doc (#8298)
HuiyingLi Feb 12, 2024
68eba36
wer fix (#8404)
tbartley94 Feb 12, 2024
5b8f18c
updated link to pubmed (#8402)
nithinraok Feb 13, 2024
0f7b49b
Update NFA video download link (#8406)
erastorgueva-nv Feb 13, 2024
f897a77
revert changes (#8410)
cuichenx Feb 13, 2024
371de5b
Fix dreambooth data sampler issue (#8400)
yaoyu-33 Feb 13, 2024
98186c2
Fixed errors in the CTM gen functions (#8416)
tango4j Feb 14, 2024
8689bc0
add ensemble decoding fix (#8427)
nithinraok Feb 15, 2024
770f73b
SDE bugfix log (#8430)
Jorjeous Feb 15, 2024
05122bd
mcore customization doc minor fix (#8421)
HuiyingLi Feb 16, 2024
2e77f20
NeMo-Mistral to HF converter bugfix. (#8353)
akoumpa Feb 16, 2024
9588494
Fixing mcore bert for TP, PP and SP (#8336)
shanmugamr1992 Feb 16, 2024
71ce00c
Add settings to suppress bf16 compile errors in CI on V100 (#8481)
athitten Feb 22, 2024
c98b9c1
MoE parameter passing (#8255)
akoumpa Feb 23, 2024
a836fce
Update k2 version (#8478) (#8492)
artbataev Feb 23, 2024
0dc8a19
Add fp8 support for SD/Update notebook paths (#8489)
Victor49152 Feb 25, 2024
1d80d00
pin to 0.5.0 (#8465)
ericharper Feb 26, 2024
fcf1044
Update NeMo Multimodal Requirements (#8515)
yaoyu-33 Feb 26, 2024
d2283e3
update github raw content link (#8517)
cuichenx Feb 26, 2024
e6b7354
Add dep notice for notebooks (#8522)
ericharper Feb 27, 2024
ae9a2aa
Revert FP8 integration (#8520)
Victor49152 Feb 27, 2024
e772dbf
Update data prep notebook (#8532)
Victor49152 Feb 27, 2024
e65d3de
Add back fp8 support
Victor49152 Feb 29, 2024
82911a4
SD-FP8: fix the bug of normalization location
Mar 6, 2024
e7b29ae
map potential FP8 ckpt to FP16
Victor49152 Apr 25, 2024
11bfe29
Merge branch 'main' into mingyuanm/add_back_fp8_support_to_sd
Victor49152 Apr 26, 2024
712fc8a
Add TE fp8 training
Victor49152 Apr 29, 2024
d5c5686
Only overwrite unet precision when self.megatron_amp_O2 is true
Victor49152 Apr 29, 2024
cab6816
New structure is now compatible with old ckpts
Victor49152 Apr 29, 2024
f70a50f
Add support on mapping old unet checkpoint to new structure and FP8 s…
Victor49152 Apr 30, 2024
14135da
Merge branch 'main' into mingyuanm/add_back_fp8_support_to_sd
Victor49152 Apr 30, 2024
9bbcbb5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2024
b5e2ef2
Sync with main branch
Victor49152 Apr 30, 2024
601b4b1
Merge remote-tracking branch 'origin/mingyuanm/add_back_fp8_support_t…
Victor49152 Apr 30, 2024
Changes from all commits
@@ -49,8 +49,8 @@ model:
precision: ${trainer.precision}
# specify micro_batch_size, global_batch_size, and model parallelism
# gradient accumulation will be done automatically based on data_parallel_size
- micro_batch_size: 1 # limited by GPU memory
- global_batch_size: 1 # will use more micro batches to reach global batch size
+ micro_batch_size: 16 # limited by GPU memory
+ global_batch_size: 16 # will use more micro batches to reach global batch size
native_amp_init_scale: 65536.0 # Init scale for grad scaler used at fp16


@@ -97,15 +97,15 @@ model:
unet_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel
from_pretrained: #/ckpts/nemo-v1-2.ckpt
- from_NeMo: True #Must be specified when from pretrained is not None, False means loading unet from HF ckpt
+ from_NeMo: False #Must be specified when from pretrained is not None, False means loading unet from HF ckpt
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
@@ -121,6 +121,7 @@ model:
use_flash_attention: True
unet_precision: fp32
resblock_gn_groups: 32
+ use_te_fp8: False

first_stage_config:
_target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL
@@ -140,30 +141,30 @@ model:
- 4
- 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity

cond_stage_config:
- _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder
- restore_from_path: /ckpts/openai.nemo
+ _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
+ version: openai/clip-vit-large-patch14
device: cuda
freeze: True
layer: "last"
- # For compatibility of history version that uses HF clip model
- # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
- # version: openai/clip-vit-large-patch14
- # device: cuda
- # max_length: 77
+ max_length: 77
+ # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder
+ # restore_from_path: /ckpts/openai-old.nemo
+ # device: cuda
+ # freeze: True
+ # layer: "last"



# miscellaneous
seed: 1234
resume_from_checkpoint: null # manually set the checkpoint file to load from
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
- ddp_overlap: True # True for using PyTorch DDP overlap.
+ ddp_overlap: False # True for using PyTorch DDP overlap.

optim:
name: fused_adam
@@ -191,7 +192,7 @@ model:
synthetic_data_length: 10000
train:
dataset_path:
- - /datasets/coyo/test.pkl
+ - /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl
augmentations:
resize_smallest_side: 512
center_crop_h_w: 512, 512
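Of the config changes above, use_te_fp8 is the new switch and it defaults to False, so existing configs keep the non-Transformer-Engine path. A minimal sketch of flipping it programmatically (the YAML file name is a placeholder; the keys mirror the hunks above):

from omegaconf import OmegaConf

cfg = OmegaConf.load("sd_train.yaml")    # placeholder name for the SD training config diffed above
cfg.model.unet_config.use_te_fp8 = True  # route the UNet's attention/FFN projections through TE FP8 layers
print(OmegaConf.to_yaml(cfg.model.unet_config))

If the training entry point is Hydra-based (as NeMo scripts typically are), the same override can be passed on the command line as model.unet_config.use_te_fp8=True.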
@@ -28,6 +28,9 @@ def model_cfg_modifier(model_cfg):
model_cfg.unet_config.use_flash_attention = False
model_cfg.unet_config.from_pretrained = None
model_cfg.first_stage_config.from_pretrained = None
+ model_cfg.first_stage_config._target_ = (
+     'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL'
+ )

torch.backends.cuda.matmul.allow_tf32 = True
trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference(
@@ -1674,7 +1674,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# megatron_amp_O2 is not yet supported in diffusion models
self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)

- if self.cfg.precision in ['16', 16, 'bf16']:
+ if self.megatron_amp_O2 and self.cfg.precision in ['16', 16, 'bf16']:
self.model_parallel_config.enable_autocast = False
if not hasattr(self.cfg.unet_config, 'unet_precision') or not '16' in str(
self.cfg.unet_config.unet_precision
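The guard above means the UNet precision handling only kicks in when megatron_amp_O2 is enabled; the FP8 path itself comes from Transformer Engine, whose modules execute FP8 GEMMs only inside an fp8_autocast region. A rough, standalone illustration of that mechanism (not NeMo code; the te.Linear stand-in and recipe settings are assumptions, and FP8 execution requires an FP8-capable GPU such as Hopper):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Stand-in layer; in this PR the TE modules are LayerNormLinear/LayerNormMLP inside the UNet blocks.
layer = te.Linear(320, 320, bias=False).cuda()
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

x = torch.randn(16, 77, 320, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)  # the GEMM runs in FP8; module parameters stay in higher precision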
66 changes: 53 additions & 13 deletions nemo/collections/multimodal/modules/stable_diffusion/attention.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
+ import os
from inspect import isfunction

import torch
@@ -21,6 +22,13 @@
from torch import einsum, nn
from torch._dynamo import disable

+ if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1":
+     from nemo.gn_native import GroupNormNormlization as GroupNorm
+ else:
+     from apex.contrib.group_norm import GroupNorm
+
+ from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP

from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
AdapterName,
@@ -96,13 +104,19 @@ def forward(self, x):


class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
- project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
- self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
+ if use_te:
+     activation = 'gelu' if not glu else 'geglu'
+     # TODO: more parameters to be confirmed, dropout, seq_length
+     self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,)
+ else:
+     norm = nn.LayerNorm(dim)
+     project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+     self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))

def forward(self, x):
return self.net(x)
@@ -225,10 +239,15 @@ def __init__(
dropout=0.0,
use_flash_attention=False,
lora_network_alpha=None,
+ use_te=False,
):
super().__init__()

self.inner_dim = dim_head * heads
+ if context_dim is None:
+     self.is_self_attn = True
+ else:
+     self.is_self_attn = False  # cross-attention
context_dim = default(context_dim, query_dim)
# make attention part be aware of self-attention/cross-attention
self.context_dim = context_dim
@@ -238,10 +257,19 @@ def __init__(
self.scale = dim_head ** -0.5
self.heads = heads

- self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)

+ self.use_te = use_te
+ if use_te:
+     return_layernorm_output = True if self.is_self_attn else False
+     self.norm_to_q = LayerNormLinear(
+         query_dim, self.inner_dim, bias=False, return_layernorm_output=return_layernorm_output
+     )
+ else:
+     self.norm = nn.LayerNorm(query_dim)
+     self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False)

self.to_out = nn.Sequential(
LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout)
)
@@ -262,8 +290,18 @@ def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_cr
# add additional token
x = torch.cat([additional_tokens, x], dim=1)

- q = self.to_q(x)
- context = default(context, x)
+ if self.use_te:
+     q_out = self.norm_to_q(x)
+     if self.is_self_attn:
+         q, ln_out = q_out
+         context = default(context, ln_out)
+     else:
+         q = q_out
+         context = default(context, x)
+ else:
+     x = self.norm(x)
+     q = self.to_q(x)
+     context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)

@@ -351,6 +389,7 @@ def __init__(
use_flash_attention=False,
disable_self_attn=False,
lora_network_alpha=None,
+ use_te=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
@@ -362,8 +401,9 @@ def __init__(
use_flash_attention=use_flash_attention,
context_dim=context_dim if self.disable_self_attn else None,
lora_network_alpha=lora_network_alpha,
+ use_te=use_te,
) # is a self-attention
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
@@ -372,10 +412,8 @@ def __init__(
dropout=dropout,
use_flash_attention=use_flash_attention,
lora_network_alpha=lora_network_alpha,
+ use_te=use_te,
) # is self-attn if context is none
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
- self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint

def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
Expand All @@ -397,15 +435,15 @@ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_at
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
x = (
self.attn1(
- self.norm1(x),
+ x,
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
)
+ x
)
- x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
- x = self.ff(self.norm3(x)) + x
+ x = self.attn2(x, context=context, additional_tokens=additional_tokens) + x
+ x = self.ff(x) + x
return x


@@ -431,6 +469,7 @@ def __init__(
use_checkpoint=False,
use_flash_attention=False,
lora_network_alpha=None,
+ use_te=False,
):
super().__init__()
logging.info(
@@ -473,6 +512,7 @@ def __init__(
use_flash_attention=use_flash_attention,
disable_self_attn=disable_self_attn,
lora_network_alpha=lora_network_alpha,
+ use_te=use_te,
)
for d in range(depth)
]
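The CrossAttention changes above fold the pre-attention LayerNorm into the q projection (TE's LayerNormLinear with return_layernorm_output) and, in the self-attention case, reuse the returned normalized tensor for the k/v projections instead of normalizing twice; the separate norm1/norm2/norm3 layers in BasicTransformerBlock then become unnecessary. A pure-PyTorch sketch of that pattern (illustrative only, not the NeMo or Transformer Engine implementation; the names are made up):

import torch
from torch import nn

class NormLinear(nn.Module):
    # Toy stand-in for a fused "LayerNormLinear": optionally returns the normalized
    # input alongside the projection so callers can reuse it.
    def __init__(self, dim_in, dim_out, return_layernorm_output=False):
        super().__init__()
        self.norm = nn.LayerNorm(dim_in)
        self.proj = nn.Linear(dim_in, dim_out, bias=False)
        self.return_layernorm_output = return_layernorm_output

    def forward(self, x):
        ln = self.norm(x)
        out = self.proj(ln)
        return (out, ln) if self.return_layernorm_output else out

# Self-attention: q comes from the fused block, and the same normalized tensor
# feeds the k/v projections, so no standalone nn.LayerNorm is needed.
x = torch.randn(2, 77, 320)
norm_to_q = NormLinear(320, 320, return_layernorm_output=True)
to_k = nn.Linear(320, 320, bias=False)
to_v = nn.Linear(320, 320, bias=False)
q, ln_out = norm_to_q(x)
k, v = to_k(ln_out), to_v(ln_out)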