From d34bcbdb90d86e1f89af56b393af66c4e5df1d3d Mon Sep 17 00:00:00 2001
From: bilzard <36561962+bilzard@users.noreply.github.com>
Date: Thu, 3 Feb 2022 08:07:54 +0900
Subject: [PATCH 1/4] Load checkpoint on CPU instead of on GPU

---
 train.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 2a973fb7164b..ca48d4a79a59 100644
--- a/train.py
+++ b/train.py
@@ -120,7 +120,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if pretrained:
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=device)  # load checkpoint
+        ckpt = torch.load(
+            weights, map_location=lambda storage, _: storage
+        )  # Load all tensors onto the CPU, using a function to avoid memory leak
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32

From e0f62c35c0a9aafb72066ee34a703d892f765d28 Mon Sep 17 00:00:00 2001
From: bilzard <36561962+bilzard@users.noreply.github.com>
Date: Fri, 4 Feb 2022 12:18:13 +0900
Subject: [PATCH 2/4] refactor: simplify code

---
 train.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/train.py b/train.py
index ca48d4a79a59..63310bbb128e 100644
--- a/train.py
+++ b/train.py
@@ -120,9 +120,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if pretrained:
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(
-            weights, map_location=lambda storage, _: storage
-        )  # Load all tensors onto the CPU, using a function to avoid memory leak
+        ckpt = torch.load(weights, map_location=torch.device("cpu"))  # load all tensors onto the CPU to avoid memory leak
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32

From 57e7aec820499982fdf49923c7844a4af98500fc Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 4 Feb 2022 18:29:46 +0100
Subject: [PATCH 3/4] Cleanup

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 63310bbb128e..21ef6b6b7d69 100644
--- a/train.py
+++ b/train.py
@@ -120,7 +120,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if pretrained:
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=torch.device("cpu"))  # load all tensors onto the CPU to avoid memory leak
+        ckpt = torch.load(weights, map_location='cpu')  # load to CPU to avoid CUDA memory leak
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32

From 6a30669b7546caab49f595670fc532f548c0cfed Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 4 Feb 2022 18:30:12 +0100
Subject: [PATCH 4/4] Update train.py

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 21ef6b6b7d69..56103b8d4202 100644
--- a/train.py
+++ b/train.py
@@ -120,7 +120,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if pretrained:
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location='cpu')  # load to CPU to avoid CUDA memory leak
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
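For reference, the pattern this series converges on, isolated from train.py, is sketched below. This is a minimal illustration, not part of the patch: the function name, default weights path, and device argument are hypothetical, and Model construction is elided.

import torch

def load_checkpoint(weights='yolov5s.pt', device='cuda:0'):
    # Hypothetical helper illustrating the final pattern of this series.
    # map_location='cpu' deserializes every tensor into host memory first,
    # sidestepping the CUDA memory leak seen when a checkpoint saved on one
    # GPU is mapped straight onto a device at load time.
    ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU
    model = ckpt['model'].float()  # checkpoint stores the full model; cast FP16 -> FP32
    return model.to(device)  # move to the target device only after loading

The design point is that the checkpoint lands on CPU unconditionally and only the reconstructed model is moved to the training device, so no stale CUDA allocations from the saving process are replayed on load.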