forked from rosinality/stylegan2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4708d72
commit 677b91e
Showing
8 changed files
with
1,121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,3 +127,6 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
wandb/ | ||
*.lmdb/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from io import BytesIO | ||
|
||
import lmdb | ||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class MultiResolutionDataset(Dataset): | ||
def __init__(self, path, transform, resolution=256): | ||
self.env = lmdb.open( | ||
path, | ||
max_readers=32, | ||
readonly=True, | ||
lock=False, | ||
readahead=False, | ||
meminit=False, | ||
) | ||
|
||
if not self.env: | ||
raise IOError('Cannot open lmdb dataset', path) | ||
|
||
with self.env.begin(write=False) as txn: | ||
self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) | ||
|
||
self.resolution = resolution | ||
self.transform = transform | ||
|
||
def __len__(self): | ||
return self.length | ||
|
||
def __getitem__(self, index): | ||
with self.env.begin(write=False) as txn: | ||
key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') | ||
img_bytes = txn.get(key) | ||
|
||
buffer = BytesIO(img_bytes) | ||
img = Image.open(buffer) | ||
img = self.transform(img) | ||
|
||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import math | ||
import pickle | ||
|
||
import torch | ||
from torch import distributed as dist | ||
from torch.utils.data.sampler import Sampler | ||
|
||
|
||
def get_rank(): | ||
if not dist.is_available(): | ||
return 0 | ||
|
||
if not dist.is_initialized(): | ||
return 0 | ||
|
||
return dist.get_rank() | ||
|
||
|
||
def synchronize(): | ||
if not dist.is_available(): | ||
return | ||
|
||
if not dist.is_initialized(): | ||
return | ||
|
||
world_size = dist.get_world_size() | ||
|
||
if world_size == 1: | ||
return | ||
|
||
dist.barrier() | ||
|
||
|
||
def get_world_size(): | ||
if not dist.is_available(): | ||
return 1 | ||
|
||
if not dist.is_initialized(): | ||
return 1 | ||
|
||
return dist.get_world_size() | ||
|
||
|
||
def reduce_sum(tensor): | ||
if not dist.is_available(): | ||
return tensor | ||
|
||
if not dist.is_initialized(): | ||
return tensor | ||
|
||
tensor = tensor.clone() | ||
dist.all_reduce(tensor, op=dist.ReduceOp.SUM) | ||
|
||
return tensor | ||
|
||
|
||
def all_gather(data): | ||
world_size = get_world_size() | ||
|
||
if world_size == 1: | ||
return [data] | ||
|
||
buffer = pickle.dumps(data) | ||
storage = torch.ByteStorage.from_buffer(buffer) | ||
tensor = torch.ByteTensor(storage).to('cuda') | ||
|
||
local_size = torch.IntTensor([tensor.numel()]).to('cuda') | ||
size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] | ||
dist.all_gather(size_list, local_size) | ||
size_list = [int(size.item()) for size in size_list] | ||
max_size = max(size_list) | ||
|
||
tensor_list = [] | ||
for _ in size_list: | ||
tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) | ||
|
||
if local_size != max_size: | ||
padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') | ||
tensor = torch.cat((tensor, padding), 0) | ||
|
||
dist.all_gather(tensor_list, tensor) | ||
|
||
data_list = [] | ||
|
||
for size, tensor in zip(size_list, tensor_list): | ||
buffer = tensor.cpu().numpy().tobytes()[:size] | ||
data_list.append(pickle.loads(buffer)) | ||
|
||
return data_list | ||
|
||
|
||
def reduce_loss_dict(loss_dict): | ||
world_size = get_world_size() | ||
|
||
if world_size < 2: | ||
return loss_dict | ||
|
||
with torch.no_grad(): | ||
keys = [] | ||
losses = [] | ||
|
||
for k in sorted(loss_dict.keys()): | ||
keys.append(k) | ||
losses.append(loss_dict[k]) | ||
|
||
losses = torch.stack(losses, 0) | ||
dist.reduce(losses, dst=0) | ||
|
||
if dist.get_rank() == 0: | ||
losses /= world_size | ||
|
||
reduced_losses = {k: v for k, v in zip(keys, losses)} | ||
|
||
return reduced_losses |
Oops, something went wrong.