diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml index 8ce009d5458f..dff963590864 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml @@ -49,8 +49,8 @@ model: precision: ${trainer.precision} # specify micro_batch_size, global_batch_size, and model parallelism # gradient accumulation will be done automatically based on data_parallel_size - micro_batch_size: 1 # limited by GPU memory - global_batch_size: 1 # will use more micro batches to reach global batch size + micro_batch_size: 16 # limited by GPU memory + global_batch_size: 16 # will use more micro batches to reach global batch size native_amp_init_scale: 65536.0 # Init scale for grad scaler used at fp16 @@ -97,15 +97,15 @@ model: unet_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel from_pretrained: #/ckpts/nemo-v1-2.ckpt - from_NeMo: True #Must be specified when from pretrained is not None, False means loading unet from HF ckpt + from_NeMo: False #Must be specified when from pretrained is not None, False means loading unet from HF ckpt image_size: 32 # unused in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: - - 4 - - 2 - - 1 + - 4 + - 2 + - 1 num_res_blocks: 2 channel_mult: - 1 @@ -121,6 +121,7 @@ model: use_flash_attention: True unet_precision: fp32 resblock_gn_groups: 32 + use_te_fp8: False first_stage_config: _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL @@ -140,22 +141,22 @@ model: - 4 - 4 num_res_blocks: 2 - attn_resolutions: [] + attn_resolutions: [ ] dropout: 0.0 lossconfig: target: torch.nn.Identity cond_stage_config: - _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder - restore_from_path: /ckpts/openai.nemo + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + version: openai/clip-vit-large-patch14 device: cuda - freeze: True - layer: "last" - # For compatibility of history version that uses HF clip model - # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder - # version: openai/clip-vit-large-patch14 - # device: cuda - # max_length: 77 + max_length: 77 + # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder + # restore_from_path: /ckpts/openai-old.nemo + # device: cuda + # freeze: True + # layer: "last" + # miscellaneous @@ -163,7 +164,7 @@ model: resume_from_checkpoint: null # manually set the checkpoint file to load from apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) - ddp_overlap: True # True for using PyTorch DDP overlap. + ddp_overlap: False # True for using PyTorch DDP overlap. optim: name: fused_adam @@ -191,7 +192,7 @@ model: synthetic_data_length: 10000 train: dataset_path: - - /datasets/coyo/test.pkl + - /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl augmentations: resize_smallest_side: 512 center_crop_h_w: 512, 512 diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py index f1e5e2872ea7..58e9e6e64470 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py @@ -28,6 +28,9 @@ def model_cfg_modifier(model_cfg): model_cfg.unet_config.use_flash_attention = False model_cfg.unet_config.from_pretrained = None model_cfg.first_stage_config.from_pretrained = None + model_cfg.first_stage_config._target_ = ( + 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL' + ) torch.backends.cuda.matmul.allow_tf32 = True trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py index 7023f57652b5..6ea4314ab71f 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -1674,7 +1674,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # megatron_amp_O2 is not yet supported in diffusion models self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) - if self.cfg.precision in ['16', 16, 'bf16']: + if self.megatron_amp_O2 and self.cfg.precision in ['16', 16, 'bf16']: self.model_parallel_config.enable_autocast = False if not hasattr(self.cfg.unet_config, 'unet_precision') or not '16' in str( self.cfg.unet_config.unet_precision diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py index c92980d904f6..3fcab2127f4f 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import os from inspect import isfunction import torch @@ -21,6 +22,13 @@ from torch import einsum, nn from torch._dynamo import disable +if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1": + from nemo.gn_native import GroupNormNormlization as GroupNorm +else: + from apex.contrib.group_norm import GroupNorm + +from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP + from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, @@ -96,13 +104,19 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) - self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out)) + if use_te: + activation = 'gelu' if not glu else 'geglu' + # TODO: more parameters to be confirmed, dropout, seq_length + self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,) + else: + norm = nn.LayerNorm(dim) + project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -225,10 +239,15 @@ def __init__( dropout=0.0, use_flash_attention=False, lora_network_alpha=None, + use_te=False, ): super().__init__() self.inner_dim = dim_head * heads + if context_dim is None: + self.is_self_attn = True + else: + self.is_self_attn = False # cross-attention context_dim = default(context_dim, query_dim) # make attention part be aware of self-attention/cross-attention self.context_dim = context_dim @@ -238,10 +257,19 @@ def __init__( self.scale = dim_head ** -0.5 self.heads = heads - self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha) self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha) self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha) + self.use_te = use_te + if use_te: + return_layernorm_output = True if self.is_self_attn else False + self.norm_to_q = LayerNormLinear( + query_dim, self.inner_dim, bias=False, return_layernorm_output=return_layernorm_output + ) + else: + self.norm = nn.LayerNorm(query_dim) + self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False) + self.to_out = nn.Sequential( LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout) ) @@ -262,8 +290,18 @@ def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_cr # add additional token x = torch.cat([additional_tokens, x], dim=1) - q = self.to_q(x) - context = default(context, x) + if self.use_te: + q_out = self.norm_to_q(x) + if self.is_self_attn: + q, ln_out = q_out + context = default(context, ln_out) + else: + q = q_out + context = default(context, x) + else: + x = self.norm(x) + q = self.to_q(x) + context = default(context, x) k = self.to_k(context) v = self.to_v(context) @@ -351,6 +389,7 @@ def __init__( use_flash_attention=False, disable_self_attn=False, lora_network_alpha=None, + use_te=False, ): super().__init__() self.disable_self_attn = disable_self_attn @@ -362,8 +401,9 @@ def __init__( use_flash_attention=use_flash_attention, context_dim=context_dim if self.disable_self_attn else None, lora_network_alpha=lora_network_alpha, + use_te=use_te, ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te) self.attn2 = CrossAttention( query_dim=dim, context_dim=context_dim, @@ -372,10 +412,8 @@ def __init__( dropout=dropout, use_flash_attention=use_flash_attention, lora_network_alpha=lora_network_alpha, + use_te=use_te, ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) self.use_checkpoint = use_checkpoint def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): @@ -397,15 +435,15 @@ def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_at def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): x = ( self.attn1( - self.norm1(x), + x, context=context if self.disable_self_attn else None, additional_tokens=additional_tokens, n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, ) + x ) - x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x - x = self.ff(self.norm3(x)) + x + x = self.attn2(x, context=context, additional_tokens=additional_tokens) + x + x = self.ff(x) + x return x @@ -431,6 +469,7 @@ def __init__( use_checkpoint=False, use_flash_attention=False, lora_network_alpha=None, + use_te=False, ): super().__init__() logging.info( @@ -473,6 +512,7 @@ def __init__( use_flash_attention=use_flash_attention, disable_self_attn=disable_self_attn, lora_network_alpha=lora_network_alpha, + use_te=use_te, ) for d in range(depth) ] diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 5ff0f6aa8a8a..b610f921a22a 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import os +import re from abc import abstractmethod from collections.abc import Iterable +from contextlib import nullcontext from functools import partial from typing import Iterable @@ -22,6 +25,9 @@ import torch as th import torch.nn as nn import torch.nn.functional as F + +# FP8 related import +import transformer_engine from apex.contrib.group_norm import GroupNorm from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer @@ -62,6 +68,34 @@ def convert_module_to_fp32(module, enable_norm_layers=False): convert_module_to_dtype(module, torch.float32, enable_norm_layers) +def convert_module_to_fp8(model): + def _set_module(model, submodule_key, module): + tokens = submodule_key.split('.') + sub_tokens = tokens[:-1] + cur_mod = model + for s in sub_tokens: + cur_mod = getattr(cur_mod, s) + setattr(cur_mod, tokens[-1], module) + + import copy + + from transformer_engine.pytorch.module import Linear as te_Linear + + for n, v in model.named_modules(): + if isinstance(v, torch.nn.Linear): + # if n in ['class_embed', 'bbox_embed.layers.0', 'bbox_embed.layers.1', 'bbox_embed.layers.2']: continue + logging.info(f'[INFO] Replace Linear: {n}, weight: {v.weight.shape}') + if v.bias is None: + is_bias = False + else: + is_bias = True + newlinear = te_Linear(v.in_features, v.out_features, bias=is_bias) + newlinear.weight = copy.deepcopy(v.weight) + if v.bias is not None: + newlinear.bias = copy.deepcopy(v.bias) + _set_module(model, n, newlinear) + + class AttentionPool2d(nn.Module): """ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py @@ -553,6 +587,7 @@ def __init__( unet_precision: str = "fp32", lora_network_alpha=None, timesteps=1000, + use_te_fp8: bool = False, ): super().__init__() from omegaconf.listconfig import ListConfig @@ -663,6 +698,7 @@ def __init__( input_block_chans = [model_channels] ch = model_channels ds = 1 + self.use_te_fp8 = use_te_fp8 for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ @@ -713,6 +749,7 @@ def __init__( use_checkpoint=use_checkpoint, use_flash_attention=use_flash_attention, lora_network_alpha=lora_network_alpha, + use_te=self.use_te_fp8, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -778,6 +815,7 @@ def __init__( use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, use_flash_attention=use_flash_attention, + use_te=self.use_te_fp8, lora_network_alpha=lora_network_alpha, ), ResBlock( @@ -844,6 +882,7 @@ def __init__( use_checkpoint=use_checkpoint, use_flash_attention=use_flash_attention, lora_network_alpha=lora_network_alpha, + use_te=self.use_te_fp8, ) ) if level and i == self.num_res_blocks[level]: @@ -899,6 +938,34 @@ def __init__( self.convert_to_fp16() elif unet_precision == 'fp16': self.convert_to_fp16(enable_norm_layers=True) + elif self.use_te_fp8: + assert unet_precision != 'fp16', "fp8 training can't work with fp16 O2 amp recipe" + convert_module_to_fp8(self) + + fp8_margin = int(os.getenv("FP8_MARGIN", '0')) + fp8_interval = int(os.getenv("FP8_INTERVAL", '1')) + fp8_format = os.getenv("FP8_FORMAT", "hybrid") + fp8_amax_history_len = int(os.getenv("FP8_HISTORY_LEN", '1024')) + fp8_amax_compute_algo = os.getenv("FP8_COMPUTE_ALGO", 'max') + fp8_wgrad = os.getenv("FP8_WGRAD", '1') == '1' + + fp8_format_dict = { + 'hybrid': transformer_engine.common.recipe.Format.HYBRID, + 'e4m3': transformer_engine.common.recipe.Format.E4M3, + } + fp8_format = fp8_format_dict[fp8_format] + + self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=fp8_margin, + interval=fp8_interval, + fp8_format=fp8_format, + amax_history_len=fp8_amax_history_len, + amax_compute_algo=fp8_amax_compute_algo, + override_linear_precision=(False, False, not fp8_wgrad), + ) + old_state_dict = self.state_dict() + new_state_dict = self.te_fp8_key_mapping(old_state_dict) + self.load_state_dict(new_state_dict, strict=False) self.unet_precision = unet_precision @@ -1000,8 +1067,65 @@ def _sdxl_embedding_mapping(self, sdxl_dict): res_dict[new_key_] = value_ return res_dict + def _legacy_unet_ckpt_mapping(self, unet_dict): + new_dict = {} + key_map = { + 'transformer_blocks.0.norm1.weight': 'transformer_blocks.0.attn1.norm.weight', + 'transformer_blocks.0.norm1.bias': 'transformer_blocks.0.attn1.norm.bias', + 'transformer_blocks.0.norm2.weight': 'transformer_blocks.0.attn2.norm.weight', + 'transformer_blocks.0.norm2.bias': 'transformer_blocks.0.attn2.norm.bias', + 'transformer_blocks.0.norm3.weight': 'transformer_blocks.0.ff.net.0.weight', + 'transformer_blocks.0.norm3.bias': 'transformer_blocks.0.ff.net.0.bias', + 'transformer_blocks.0.ff.net.0.proj.weight': 'transformer_blocks.0.ff.net.1.proj.weight', + 'transformer_blocks.0.ff.net.0.proj.bias': 'transformer_blocks.0.ff.net.1.proj.bias', + 'transformer_blocks.0.ff.net.2.weight': 'transformer_blocks.0.ff.net.3.weight', + 'transformer_blocks.0.ff.net.2.bias': 'transformer_blocks.0.ff.net.3.bias', + } + + pattern = re.compile(r'(input_blocks|output_blocks)\.[\d\w]+\.[\d\w]+\.') + pattern_middle_block = re.compile(r'middle_block\.[\d\w]+\.') + for old_key, value in unet_dict.items(): + match = pattern.match(old_key) + match_middle = pattern_middle_block.match(old_key) + if match or match_middle: + prefix = match.group(0) if match else match_middle.group(0) + suffix = old_key.split('.', 3)[-1] if match else old_key.split('.', 2)[-1] + if suffix in key_map: + new_key = prefix + key_map[suffix] + new_dict[new_key] = value + else: + new_dict[old_key] = value + else: + new_dict[old_key] = value + + return new_dict + + def te_fp8_key_mapping(self, unet_dict): + new_state_dict = {} + for key in unet_dict.keys(): + if 'extra_state' in key: + continue + + ### LayerNormLinear + # norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias} + # norm_to_q.weight -> to_q.weight + new_key = key.replace('attn1.norm.', 'attn1.norm_to_q.layer_norm_') + new_key = new_key.replace('attn1.to_q.weight', 'attn1.norm_to_q.weight',) + new_key = new_key.replace('attn2.norm.', 'attn2.norm_to_q.layer_norm_') + new_key = new_key.replace('attn2.to_q.weight', 'attn2.norm_to_q.weight',) + + ### LayerNormMLP + # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias} + # ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias} + # ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias} + new_key = new_key.replace('ff.net.0.', 'ff.net.layer_norm_') + new_key = new_key.replace('ff.net.1.proj.', 'ff.net.fc1_') + new_key = new_key.replace('ff.net.3.', 'ff.net.fc2_') + + new_state_dict[new_key] = unet_dict[key] + return new_state_dict + def _state_key_mapping(self, state_dict: dict): - import re res_dict = {} input_dict = {} @@ -1027,13 +1151,7 @@ def _state_key_mapping(self, state_dict: dict): mid_dict = self._mid_blocks_mapping(mid_dict) other_dict = self._other_blocks_mapping(other_dict) sdxl_dict = self._sdxl_embedding_mapping(sdxl_dict) - # key_list = state_dict.keys() - # key_str = " ".join(key_list) - # for key_, val_ in state_dict.items(): - # key_ = key_.replace("down_blocks", "input_blocks")\ - # .replace("up_blocks", 'output_blocks') - # res_dict[key_] = val_ res_dict.update(input_dict) res_dict.update(output_dict) res_dict.update(mid_dict) @@ -1046,6 +1164,7 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from state_dict = self._strip_unet_key_prefix(state_dict) if not from_NeMo: state_dict = self._state_key_mapping(state_dict) + state_dict = self._legacy_unet_ckpt_mapping(state_dict) model_state_dict = self.state_dict() loaded_keys = [k for k in state_dict.keys()] @@ -1151,7 +1270,7 @@ def convert_to_fp16(self, enable_norm_layers=False): """ self.apply(lambda module: convert_module_to_fp16(module=module, enable_norm_layers=enable_norm_layers)) - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + def _forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. @@ -1170,7 +1289,6 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - if self.unet_precision == "fp16-mixed" or self.unet_precision == "fp16": x = x.type(torch.float16) if context is not None: @@ -1197,6 +1315,13 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): else: return self.out(h) + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe, + ) if self.use_te_fp8 else nullcontext(): + out = self._forward(x, timesteps, context, y, **kwargs) + return out + class EncoderUNetModel(nn.Module): """ diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index a6f68f0666b5..0a030759fe9b 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -73,6 +73,7 @@ try: from apex.transformer.pipeline_parallel.utils import get_num_microbatches + from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam HAVE_APEX = True @@ -1057,6 +1058,31 @@ def should_process(key): new_state_dict[key_] = state_dict[key_] state_dict = new_state_dict + if conf.get('unet_config') and conf.get('unet_config').get('use_te_fp8') == False: + # Mapping potential fp8 ckpt to fp16 model + # remove _extra_state in fp8 if there is. + new_state_dict = {} + for key in state_dict.keys(): + if 'extra_state' in key: + continue + + ### LayerNormLinear + # norm_to_q.layer_norm_{weight|bias} -> norm.{weight|bias} + # norm_to_q.weight -> to_q.weight + new_key = key.replace('norm_to_q.layer_norm_', 'norm.') + new_key = new_key.replace('norm_to_q.weight', 'to_q.weight') + + ### LayerNormMLP + # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias} + # ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias} + # ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias} + new_key = new_key.replace('ff.net.layer_norm_', 'ff.net.0.') + new_key = new_key.replace('ff.net.fc1_', 'ff.net.1.proj.') + new_key = new_key.replace('ff.net.fc2_', 'ff.net.3.') + + new_state_dict[new_key] = state_dict[key] + state_dict = new_state_dict + return state_dict def _load_state_dict_from_disk(self, model_weights, map_location=None):