[zero3] how to get the model reconstructed for saving? #872

Closed · stas00 opened this issue Mar 18, 2021 · 4 comments · Fixed by #893

stas00 (Collaborator) commented Mar 18, 2021

Under zero2 self.model.state_dict() returns an fp16 version of the model, but under zero3 it returns placeholders, with every weight being just tensor([1.]), so how can we get the trained model out of deepspeed?
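
For illustration, a minimal way to observe this (a hypothetical snippet; engine stands for the object returned by deepspeed.initialize):

import torch.distributed as dist

# under ZeRO-3 the partitioned parameters show up as tiny placeholder tensors
# rather than the real weights; printing the shapes makes this visible
if dist.get_rank() == 0:
    for name, tensor in engine.module.state_dict().items():
        print(name, tuple(tensor.shape), tensor.dtype)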

This is related to #800, but there, under zero2, we at least had the fp16 version saveable; now there is no way at all. This is a total lock-in, unless I'm missing some API that was added with zero3.

Ideally it'd be awesome if it could reconstruct the model directly on disk, since that would ensure there is enough memory to do so.

To summarize different requests so far - users have 3 different needs:

  1. being able to leave DeepSpeed after training with it, i.e. take a DeepSpeed checkpoint and recover the full consolidated fp32 model in a single file
  2. same as (1) but fp16. Surely one can go from (1) to (2) easily, but I think it'd be much faster to get to (2) directly - I could be wrong. If I'm wrong then (1) is enough.
  3. being able to call deepspeed.consolidate_weights() in the rank0 process, which would give users the full non-partitioned weights back (perhaps with a bool arg for whether they want the fp16 or fp32 version). They could then just save the model as they do with any other pytorch tools (a usage sketch follows below). This would only be practical for small-ish models. The key here is that while this would be somewhat costly, users would be able to use their code almost w/o any change whether they train with deepspeed or not. I think this must happen on cpu, since it's unlikely gpus will have the memory for that. It would probably have to return a copy of the model with the consolidated weights, so that the user can continue training the original model. So probably something along the lines of Save ZeRO3 (partitioned) fp16 weights #882, but in addition the partitioning will need to be removed too. The result of this call would give users the equivalent of what they have under zero2 at this moment (if it's fp16).

I think all 3 would be more or less the same code, with just different ways of using it - (3) using the existing deepspeed engine and not needing to access the filesystem, (1) and (2) w/o needing an engine and relying exclusively on the filesystem.
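
A minimal sketch of how need (3) could look from the user's side (hypothetical: consolidate_weights() is the API proposed in this issue, not an existing DeepSpeed function, and engine stands for the initialized DeepSpeed engine):

import torch
import deepspeed

# hypothetical call: returns a copy of the model with full, non-partitioned weights
# on rank 0, leaving the original partitioned model free to keep training
full_model = deepspeed.consolidate_weights(engine, fp16=True)
if torch.distributed.get_rank() == 0:
    torch.save(full_model.state_dict(), "pytorch_model.bin")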

Thank you!

@jeffra, @tjruwase

tjruwase (Contributor) commented:

@stas00, do you have repro steps?

stas00 (Collaborator, Author) commented Mar 19, 2021

Here we go:

git clone https://github.com/huggingface/transformers
cd transformers

and then:

cat <<'EOT' > ds_config_zero3.json
{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 3,
        "cpu_offload": true,
        "cpu_offload_params": true,
        "cpu_offload_use_pin_memory" : true,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e8,
        "stage3_prefetch_bucket_size": 2e5,
        "stage3_param_persitance_threshold": 1e5,
        "reduce_bucket_size": 3e6,
        "prefetch_bucket_size": 3e6,
        "sub_group_size": 1e6
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-5,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 500
        }
    },

    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
EOT

and finally:

PYTHONPATH=src deepspeed --num_gpus 2 examples/seq2seq/run_translation.py --model_name_or_path t5-small --output_dir /tmp/zero3 --overwrite_output_dir --max_train_samples 10 --max_val_samples 10 --max_source_length 12 --max_target_length 12 --val_max_target_length 12 --do_train --num_train_epochs 1 --per_device_train_batch_size 2 --learning_rate 3e-3 --warmup_steps 8 --predict_with_generate --logging_steps 0 --save_steps 2 --eval_steps 1 --group_by_length  --adafactor --dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --deepspeed  ds_config_zero3.json

the results:

ls -l /tmp/zero3/checkpoint-2/
total 868K
-rw-rw-r-- 1 stas stas 1.4K Mar 18 17:27 config.json
drwxrwxr-x 2 stas stas 4.0K Mar 18 17:27 global_step2/
-rw-rw-r-- 1 stas stas   12 Mar 18 17:27 latest
-rw-rw-r-- 1 stas stas  61K Mar 18 17:27 pytorch_model.bin
-rw-rw-r-- 1 stas stas 1.8K Mar 18 17:27 special_tokens_map.json
-rw-rw-r-- 1 stas stas 774K Mar 18 17:27 spiece.model
-rw-rw-r-- 1 stas stas 1.9K Mar 18 17:27 tokenizer_config.json
-rw-rw-r-- 1 stas stas  346 Mar 18 17:27 trainer_state.json
-rw-rw-r-- 1 stas stas 2.4K Mar 18 17:27 training_args.bin

pytorch_model.bin needs to be ~116MB for fp16, and twice as big for fp32 (this is t5-small)
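
(For reference, t5-small has roughly 60M parameters, so a full fp16 copy is about 60e6 × 2 bytes ≈ 115 MiB, in line with the ~116MB figure above, and fp32 is double that.)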

stas00 (Collaborator, Author) commented Mar 25, 2021

Some variation of this will eventually be part of the DeepSpeed API, but if you need it sooner, here is how you get the consolidated fp16 state_dict, which, once saved, you can load as a normal pre-trained model (except it'll be fp16 only).

self is the deepspeed engine object.

def zero3_consolidated_fp16_state_dict(self):
    """
    Restores a full non-partitioned state_dict with fp16 weights.

    Similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:

    1. consolidates the weights from the different partitions on gpu0
    2. works on one layer at a time to require as little gpu0 memory as possible, by
       moving the already consolidated weights to cpu
    3. takes care to keep shared params shared

    Returns:
        a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
    """

    import torch
    import deepspeed
    from collections import OrderedDict

    if not self.zero_optimization_partition_weights():
        raise ValueError("this function requires ZeRO-3 mode")

    state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
    shared_weights = {}

    def get_layer_state_dict(module, prefix=""):
        # gather one layer at a time to be memory-efficient
        with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False))):
            if torch.distributed.get_rank() == 0:
                for name, param in module.named_parameters(recurse=False):
                    if param is None:
                        continue
                    key = prefix + name
                    # for shared weights we want to make sure not to unshare them when copying to cpu
                    data_ptr_id = param.storage().data_ptr()
                    if data_ptr_id in shared_weights:
                        # shared weights: point at the already-copied entry
                        # print(f"`{key}` is shared with `{shared_weights[data_ptr_id]}`")
                        state_dict[key] = state_dict[shared_weights[data_ptr_id]]
                    else:
                        state_dict[key] = param.detach().cpu()
                        shared_weights[data_ptr_id] = key

                # now buffers - not sure if we need to take care of potentially shared weights here
                for name, buf in module.named_buffers(recurse=False):
                    if buf is not None and name not in module._non_persistent_buffers_set:
                        state_dict[prefix + name] = buf.detach().cpu()

        for name, child in module.named_children():
            if child is not None:
                get_layer_state_dict(child, prefix + name + ".")

    # see_memory_usage("before get_layer_state_dict", force=True)
    # XXX: not sure about the starting prefix? see pretrained load
    get_layer_state_dict(self.module, prefix="")
    # see_memory_usage("after get_layer_state_dict", force=True)

    return state_dict
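
A minimal usage sketch (assuming engine is the object returned by deepspeed.initialize and the function above is called with the engine passed in place of self):

import torch

# build the consolidated fp16 state_dict (non-None only on rank 0) and save it
state_dict = zero3_consolidated_fp16_state_dict(engine)  # engine passed in place of `self`
if torch.distributed.get_rank() == 0:
    torch.save(state_dict, "pytorch_model.bin")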

stas00 (Collaborator, Author) commented Mar 26, 2021
