Initial implementations
rosinality committed Dec 14, 2019
1 parent 4708d72 commit 677b91e
Showing 8 changed files with 1,121 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

wandb/
*.lmdb/
1 change: 1 addition & 0 deletions checkpoint/.gitignore
@@ -0,0 +1 @@
*.pt
40 changes: 40 additions & 0 deletions dataset.py
@@ -0,0 +1,40 @@
from io import BytesIO

import lmdb
from PIL import Image
from torch.utils.data import Dataset


class MultiResolutionDataset(Dataset):
    def __init__(self, path, transform, resolution=256):
        # Open the LMDB environment read-only; locking, readahead, and
        # memory initialization are disabled for faster concurrent reads.
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        # The number of samples is stored under the 'length' key.
        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Encoded images are keyed as '<resolution>-<zero-padded index>'.
        with self.env.begin(write=False) as txn:
            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)

        # Decode the stored bytes with PIL and apply the user transform.
        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        return img
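
A minimal usage sketch (not part of the commit): assuming an LMDB file prepared with resolution-keyed image entries and a 'length' key, the dataset plugs into a standard DataLoader. The path 'data.lmdb', the transform chain, and the batch size below are illustrative placeholders.

from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import MultiResolutionDataset

# Typical GAN-style preprocessing; the exact chain is an assumption.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
])

# 'data.lmdb' is a placeholder path to a prepared LMDB dataset.
dataset = MultiResolutionDataset('data.lmdb', transform, resolution=256)
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

img = next(iter(loader))
print(img.shape)  # expected: torch.Size([16, 3, 256, 256])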
114 changes: 114 additions & 0 deletions distributed.py
@@ -0,0 +1,114 @@
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def get_rank():
    # Rank of the current process; 0 when not running distributed.
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    # Barrier across all processes; a no-op outside distributed runs.
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def reduce_sum(tensor):
    # Sum a tensor over all processes; returns the input unchanged
    # when distributed is unavailable or uninitialized.
    if not dist.is_available():
        return tensor

    if not dist.is_initialized():
        return tensor

    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return tensor


def all_gather(data):
    # Gather arbitrary picklable objects from all processes.
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    # Serialize the object into a CUDA byte tensor.
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # Exchange payload sizes so every rank can pad to the maximum.
    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    # Trim the padding and unpickle each rank's payload.
    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_loss_dict(loss_dict):
    # Average a dict of scalar loss tensors across processes;
    # only rank 0 receives the averaged values.
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        # Sort keys so every rank reduces tensors in the same order.
        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
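
A brief sketch of how these helpers compose in a training loop (illustrative, not from the commit): each rank builds a dict of scalar loss tensors, reduce_loss_dict averages them so rank 0 logs values aggregated over all processes, and synchronize acts as a barrier before rank-0-only side effects such as checkpointing. The loss names and values below are placeholders, and a CUDA device is assumed for the NCCL backend.

import torch

from distributed import get_rank, reduce_loss_dict, synchronize

# Placeholder per-rank losses; in practice these come from the models.
# NCCL collectives require CUDA tensors.
loss_dict = {
    'd': torch.tensor(0.7, device='cuda'),
    'g': torch.tensor(1.2, device='cuda'),
}

# Average across ranks; only rank 0 receives the reduced values.
loss_reduced = reduce_loss_dict(loss_dict)

if get_rank() == 0:
    d_val = loss_reduced['d'].mean().item()
    g_val = loss_reduced['g'].mean().item()
    print(f'd: {d_val:.4f}; g: {g_val:.4f}')

# Keep ranks aligned before rank-0-only work like saving checkpoints.
synchronize()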