Skip to content

Commit

Permalink
part-2 cherry from: add paddle.async_save to reduce time cost by chec…
Browse files Browse the repository at this point in the history
…kpoint saving (PaddlePaddle#55115)

* add paddle.async_save to reduce time cost by checkpoint saving

* adapt save_for_auto_inference to paddle.async_save

* modify UT

* modify UT

* fix on cpu only version

* revert commit on save_auto_inference

* fix threading
  • Loading branch information
SylarTiaNII authored and wentaoyu committed Nov 23, 2023
1 parent 324ff7b commit 6eddfff
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@

from .tensor.einsum import einsum

from .framework import async_save, clear_async_save_task_queue # noqa: F401

from .framework.random import (
seed,
get_cuda_rng_state,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
IPUPlace,
XPUPlace,
)
from .io import async_save, clear_async_save_task_queue # noqa: F401
from ..base.dygraph import base, to_variable # noqa: F401
from ..base.dygraph.base import disable_dygraph as enable_static # noqa: F401
from ..base.dygraph.base import enable_dygraph as disable_static # noqa: F401
Expand Down
75 changes: 75 additions & 0 deletions python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import pickle
import sys
import threading
import warnings
from collections.abc import Iterable

Expand Down Expand Up @@ -48,6 +49,80 @@
)

__all__ = []
async_save_queue = []


def clear_async_save_task_queue():
'''
wait until all async save task to be done.
'''
while len(async_save_queue) > 0:
task = async_save_queue.pop()
if task and task.is_alive():
task.join()


def async_save(obj, path, protocol=4, sync_other_task=False, **configs):
'''
async version of paddle.save.
Note:
currently only support dygraph mode.
Note:
any argument passed through configs will be overrided by default setting.
Args:
obj(Object) : The object to be saved.
path(str|BytesIO) : The path/buffer of the object to be saved.
If saved in the current directory, the input path string will be used as the file name.
protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 4
sync_other_task(bool) : Determine whether to wait other async save task to be finished before this one be put in queue.
**configs(dict, optional): compatible argument to paddle.save, but will be overrided by default setting.
Examples:
.. code-block:: python
:name: code-example-1
import paddle
emb = paddle.nn.Embedding(10, 10)
layer_state_dict = emb.state_dict()
# call paddle.async_save with the same style of paddle.save
paddle.async_save(layer_state_dict, "emb.pdparams")
for i in range(10):
# do some calculations here
# wait if any async_save task has not been done
paddle.clear_async_task_queue()
'''
if in_dygraph_mode():
raise ValueError(
"async_save currently is not supported in static mode."
)
if len(configs) > 0:
warnings.warn(
"configs are not supported in async mode, will be overided by default settings."
)

# TODO: make this part async
def move_state_dict_to_cpu(sd):
for k, v in sd.items():
if isinstance(v, dict):
move_state_dict_to_cpu(v)
elif isinstance(v, core.eager.Tensor):
sd[k] = v.pin_memory() if core.is_compiled_with_cuda() else v

if isinstance(obj, dict):
move_state_dict_to_cpu(obj)
elif isinstance(obj, core.eager.Tensor):
obj = obj.pin_memory() if core.is_compiled_with_cuda() else obj
else:
# other types are currently not supported
raise TypeError(
f"currently async_save does not support this type: {type(obj)}"
)
if sync_other_task:
clear_async_save_task_queue()
t = threading.Thread(target=save, args=(obj, path, protocol))
t.start()
async_save_queue.append(t)


def _build_saved_state_dict(state_dict):
Expand Down
82 changes: 82 additions & 0 deletions test/legacy_test/test_paddle_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,88 @@ def test_save_load(self):
)


class TestAsyncSaveLoad(unittest.TestCase):
def setUp(self):
# enable dygraph mode
paddle.disable_static()

# config seed
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
self.temp_dir = tempfile.TemporaryDirectory()

def tearDown(self):
self.temp_dir.cleanup()

def build_and_train_model(self):
# create network
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()

adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())

# create data loader
# TODO: using new DataLoader cause unknown Timeout on windows, replace it
loader = random_batch_reader()

# train
train(layer, loader, loss_fn, adam)

return layer, adam

def check_load_state_dict(self, orig_dict, load_dict):
for var_name, value in orig_dict.items():
load_value = (
load_dict[var_name].numpy()
if hasattr(load_dict[var_name], 'numpy')
else np.array(load_dict[var_name])
)
np.testing.assert_array_equal(value.numpy(), load_value)

def test_async_save_load(self):
layer, opt = self.build_and_train_model()

# save
layer_save_path = os.path.join(
self.temp_dir.name, "test_paddle_async_save_load.linear.pdparams"
)
opt_save_path = os.path.join(
self.temp_dir.name, "test_paddle_async_save_load.linear.pdopt"
)
layer_state_dict = layer.state_dict()
opt_state_dict = opt.state_dict()

paddle.async_save(
layer_state_dict, layer_save_path, sync_other_task=True
)
paddle.async_save(opt_state_dict, opt_save_path)
paddle.clear_async_save_task_queue()

# load
load_layer_state_dict = paddle.load(layer_save_path)
load_opt_state_dict = paddle.load(opt_save_path)

self.check_load_state_dict(layer_state_dict, load_layer_state_dict)
self.check_load_state_dict(opt_state_dict, load_opt_state_dict)

# test assertion on illegal object
some_tuple_obj = (1, 2, 3)
tuple_save_path = os.path.join(
self.temp_dir.name, "test_paddle_async_save_load.tuple.pdparams"
)
with self.assertRaises(TypeError):
paddle.async_save(some_tuple_obj, tuple_save_path)

# test assertion on static graph
paddle.enable_static()
static_save_path = os.path.join(
self.temp_dir.name,
"static_mode_test/test_paddle_async_save_load.linear.pdparams",
)
with self.assertRaises(ValueError):
paddle.async_save(layer_state_dict, static_save_path)


class TestSaveLoadProgram(unittest.TestCase):
def test_save_load_program(self):
paddle.enable_static()
Expand Down

0 comments on commit 6eddfff

Please sign in to comment.