From 39f673469d67aef49627397e8445ce833c13835a Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Fri, 28 Jul 2023 07:57:21 +0000 Subject: [PATCH 1/9] added 2d flow, manual weights to losses, l2 diffusion loss, same hp for 2d and 3d flow experiments --- projects/super_res/config.py | 6 +- projects/super_res/config_mod_flow.py | 43 + .../super_res/model/autoreg_diffusion_mod.py | 10 +- .../model/autoreg_diffusion_mod_flow.py | 1491 +++++++++++++++++ projects/super_res/trainer_mod_flow.py | 66 + 5 files changed, 1608 insertions(+), 8 deletions(-) create mode 100644 projects/super_res/config_mod_flow.py create mode 100644 projects/super_res/model/autoreg_diffusion_mod_flow.py create mode 100755 projects/super_res/trainer_mod_flow.py diff --git a/projects/super_res/config.py b/projects/super_res/config.py index 80ebd24c08..a70ccced96 100644 --- a/projects/super_res/config.py +++ b/projects/super_res/config.py @@ -9,13 +9,13 @@ config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 config.diffusion_steps = 1500 -config.sampling_steps = 6 -config.loss = "l1" +config.sampling_steps = 20 +config.loss = "l2" config.objective = "pred_v" config.lr = 8e-5 config.steps = 5000000 config.grad_acc = 1 -config.val_num_of_batch = 1 +config.val_num_of_batch = 2 config.save_and_sample_every = 5000 config.ema_decay = 0.995 config.amp = False diff --git a/projects/super_res/config_mod_flow.py b/projects/super_res/config_mod_flow.py new file mode 100644 index 0000000000..daa24191d1 --- /dev/null +++ b/projects/super_res/config_mod_flow.py @@ -0,0 +1,43 @@ +from ml_collections import config_dict + +#batch_size = 4 +config = config_dict.ConfigDict() + +config.dim = 64 +config.dim_mults = (1, 1, 2, 2, 3, 4) +config.learned_sinusoidal_cond = True, +config.random_fourier_features = True, +config.learned_sinusoidal_dim = 32 +config.diffusion_steps = 1500 +config.sampling_steps = 20 +config.loss = "l2" +config.objective = "pred_v" +config.lr = 8e-5 +config.steps = 5000000 +config.grad_acc = 1 +config.val_num_of_batch = 2 +config.save_and_sample_every = 5000 +config.ema_decay = 0.995 +config.amp = False +config.split_batches = True +config.additional_note = "mod_flow" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 1 + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 7, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "quick": True +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/model/autoreg_diffusion_mod.py b/projects/super_res/model/autoreg_diffusion_mod.py index 7090bf50be..b000fd0eff 100644 --- a/projects/super_res/model/autoreg_diffusion_mod.py +++ b/projects/super_res/model/autoreg_diffusion_mod.py @@ -1031,7 +1031,7 @@ def p_losses(self, stack, hres, lres, ures, t, noise = None): loss2 = self.loss_fn(x_start, warped, reduction = 'none') loss2 = reduce(loss2, 'b ... 
-> b (...)', 'mean') - return loss.mean() + loss1.mean() + loss2.mean() + return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*1.0 def forward(self, lres, hres, *args, **kwargs): @@ -1325,13 +1325,13 @@ def train(self): ax1.set_ylabel("Density") ax1.set_yscale("log") - flow_d = np.zeros((1, num_samples, 3, img_size, img_size)) - for m in range(num_samples): + flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) + for m in range(num_frames): flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) - flow_s = np.zeros((1, num_samples, 3, img_size, img_size)) + flow_s = np.zeros((1, num_frames, 3, img_size, img_size)) sm = smap(None, fcmap) - for m in range(num_samples): + for m in range(num_frames): flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) diff --git a/projects/super_res/model/autoreg_diffusion_mod_flow.py b/projects/super_res/model/autoreg_diffusion_mod_flow.py new file mode 100644 index 0000000000..3a7e24a663 --- /dev/null +++ b/projects/super_res/model/autoreg_diffusion_mod_flow.py @@ -0,0 +1,1491 @@ +import os +import math +from pathlib import Path +from random import random +from functools import partial +from collections import namedtuple +from joblib import Parallel, delayed + +import numpy as np + +import torch +from torch import nn, einsum +import torch.nn.functional as F +import wandb + +import piq + +from kornia import filters +from torch.optim import Adam + +from einops import rearrange, reduce +from einops.layers.torch import Rearrange + +from PIL import Image + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.cm import ScalarMappable as smap + +from tqdm.auto import tqdm +from ema_pytorch import EMA + +import flow_vis + +from accelerate import Accelerator + +from .network_swinir import SwinIR as context_net + +# constants + +ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) + +# helpers functions + +def calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size): + truth_cdf = np.zeros((256, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + for i in range(256): + truth_cdf[i, :, :, :, :, :, :] = (truth <= i).astype('uint8') + pred_cdf = np.zeros((256, num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + for j in range(256): + pred_cdf[j, :, :, :, :, :, :, :] = (pred <= j).astype('uint8') + red_pred_cdf = pred_cdf.mean(1) + temp = np.square(red_pred_cdf - truth_cdf) + temp_dz = temp.sum(0) + temp_dz_dd = temp_dz.mean(axis = (3, 4, 5)) + temp_dz_dd_dt = temp_dz_dd.mean(2) + return temp_dz_dd_dt.mean() + +def save_image(tensor, path): + im = Image.fromarray((tensor[:,:,:3] * 255).astype(np.uint8)) + im.save(path) + return None + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def identity(t, *args, **kwargs): + return t + +def cycle(dl): + while True: + for data in dl: + yield data + +def has_int_squareroot(num): + return (math.sqrt(num) ** 2) == num + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return 
arr
+
+def convert_image_to_fn(img_type, image):
+    if image.mode != img_type:
+        return image.convert(img_type)
+    return image
+
+# normalization functions
+
+def normalize_to_neg_one_to_one(img):
+    return img * 2 - 1
+
+def unnormalize_to_zero_to_one(t):
+    return (t + 1) * 0.5
+
+# flow modules
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='border', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'border'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, True is used as the default.
+
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    n, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device),
+                                    torch.arange(0, w, dtype=x.dtype, device=x.device))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+
+    # scale grid to [-1,1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+
+    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+    return output
+
+# small helper modules
+
+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *args, **kwargs):
+        return self.fn(x, *args, **kwargs) + x
+
+def Upsample(dim, dim_out = None):
+    return nn.Sequential(
+        nn.Upsample(scale_factor = 2, mode = 'nearest'),
+        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
+    )
+
+def Downsample(dim, dim_out = None):
+    return nn.Sequential(
+        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
+        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
+    )
+
+class WeightStandardizedConv2d(nn.Conv2d):
+    """
+    https://arxiv.org/abs/1903.10520
+    weight standardization purportedly works synergistically with group normalization
+    """
+    def forward(self, x):
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+        weight = self.weight
+        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+        var = reduce(weight, 'o ... 
-> o 1 1 1', partial(torch.var, unbiased = False)) + normalized_weight = (weight - mean) * (var + eps).rsqrt() + + return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * (var + eps).rsqrt() * self.g + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + +# sinusoidal positional embeds + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class RandomOrLearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim, is_random = False): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + +# building block modules + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups = 8): + super().__init__() + self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + +class LinearAttention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + + self.to_out = nn.Sequential( + nn.Conv2d(hidden_dim, dim, 1), + LayerNorm(dim) + ) + + def forward(self, x): + b, c, h, w = x.shape + qkv = 
self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q.softmax(dim = -2) + k = k.softmax(dim = -1) + + q = q * self.scale + v = v / (h * w) + + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + return self.to_out(out) + +class Attention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q * self.scale + + sim = einsum('b h d i, b h d j -> b h i j', q, k) + attn = sim.softmax(dim = -1) + out = einsum('b h i j, b h d j -> b h i d', attn, v) + + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + return self.to_out(out) + +# model + +class Unet(nn.Module): + def __init__( + self, + dim, + init_dim = None, + out_dim = None, + dim_mults=(1, 2, 4, 8), + channels = 3, + self_condition = False, + resnet_block_groups = 8, + learned_variance = False, + learned_sinusoidal_cond = False, + random_fourier_features = False, + learned_sinusoidal_dim = 16 + ): + super().__init__() + + # determine dimensions + + self.channels = channels + self.self_condition = self_condition + input_channels = channels * (2 if self_condition else 1) + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # time embeddings + + time_dim = dim * 4 + + self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features + + if self.random_or_learned_sinusoidal_cond: + sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) + fourier_dim = learned_sinusoidal_dim + 1 + else: + sinu_pos_emb = SinusoidalPosEmb(dim) + fourier_dim = dim + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), + block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + 
Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + default_out_dim = channels * (1 if not learned_variance else 2) + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv = nn.Conv2d(dim, self.out_dim, 1) + + def forward(self, x, time, context, x_self_cond = None): + if self.self_condition: + x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x_self_cond, x), dim = 1) + + x = self.init_conv(x) + r = x.clone() + + t = self.time_mlp(time) + + h = [] + + count = 0 + + for block1, block2, attn, downsample in self.downs: + x = torch.cat((x, context[count]), dim = 1) + count += 1 + x = block1(x, t) + h.append(x) + + x = torch.cat((x, context[count]), dim = 1) + count += 1 + x = block2(x, t) + x = attn(x) + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x, t) + + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x, t) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim = 1) + + x = self.final_res_block(x, t) + return self.final_conv(x) + +class Flow(nn.Module): + def __init__( + self, + dim, + init_dim = None, + out_dim = None, + dim_mults=(1, 2, 4, 8), + channels = 3, + resnet_block_groups = 8, + ): + super().__init__() + + # determine dimensions + + self.channels = channels + input_channels = channels + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(dim_in, dim_in), + block_klass(dim_in, dim_in), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out), + block_klass(dim_out + dim_in, dim_out), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + default_out_dim = channels + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim) + self.final_conv = nn.Conv2d(dim, self.out_dim, 1) + + def forward(self, x): + + x = self.init_conv(x) + r = x.clone() + + h = [] + context = [] + for block1, block2, attn, downsample in self.downs: + x = block1(x) + h.append(x) + context.append(x) + x = block2(x) + x = attn(x) + h.append(x) + context.append(x) + x = downsample(x) + + x = self.mid_block1(x) + x = self.mid_attn(x) + x = self.mid_block2(x) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 
1) + x = block1(x) + + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim = 1) + + x = self.final_res_block(x) + x = F.tanh(self.final_conv(x)) + + return x, context + #return self.final_conv(x), context + +# gaussian diffusion trainer class + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def linear_beta_schedule(timesteps): + """ + linear schedule, proposed in original ddpm paper + """ + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) + +def cosine_beta_schedule(timesteps, s = 0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps + alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + +def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5): + """ + sigmoid schedule + proposed in https://arxiv.org/abs/2212.11972 - Figure 8 + better for images > 64x64, when used during training + """ + steps = timesteps + 1 + t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps + v_start = torch.tensor(start / tau).sigmoid() + v_end = torch.tensor(end / tau).sigmoid() + alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + +class GaussianDiffusion(nn.Module): + def __init__( + self, + model, + flow, + *, + image_size, + timesteps = 1200, + sampling_timesteps = None, + loss_type = 'l1', + objective = 'pred_noise', + beta_schedule = 'sigmoid', + schedule_fn_kwargs = dict(), + p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. 
is recommended + p2_loss_weight_k = 1, + ddim_sampling_eta = 0., + auto_normalize = True + ): + super().__init__() + #assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) + #assert not model.random_or_learned_sinusoidal_cond + + self.model = model + + self.umodel = context_net(upscale=8, in_chans=1, img_size=48, window_size=8, + img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200, + num_heads=[8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv') + + self.flow = flow + self.upsample = nn.UpsamplingBilinear2d(scale_factor=8) + + self.channels = self.model.channels + self.self_condition = self.model.self_condition + + self.image_size = image_size + + self.objective = objective + + assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' + + if beta_schedule == 'linear': + beta_schedule_fn = linear_beta_schedule + elif beta_schedule == 'cosine': + beta_schedule_fn = cosine_beta_schedule + elif beta_schedule == 'sigmoid': + beta_schedule_fn = sigmoid_beta_schedule + else: + raise ValueError(f'unknown beta schedule {beta_schedule}') + + betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs) + + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # sampling related parameters + + self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training + + assert self.sampling_timesteps <= timesteps + self.is_ddim_sampling = self.sampling_timesteps < timesteps + self.ddim_sampling_eta = ddim_sampling_eta + + # helper function to register buffer from float64 to float32 + + register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) + register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) + + # calculate p2 reweighting + + register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) + + # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False + + self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity + self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_noise_from_start(self, x_t, t, x0): + return ( + (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + ) + + def predict_v(self, x_start, t, noise): + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start + ) + + def predict_start_from_v(self, x_t, t, v): + return ( + extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + #def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): + def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_start = False): + + #model_output = self.model(x, t, x_self_cond) + #print(x.shape, l_cond.shape) + model_output = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) + + maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity + + if self.objective == 'pred_noise': + pred_noise = model_output + x_start = self.predict_start_from_noise(x, t, pred_noise) + x_start = maybe_clip(x_start) + + elif self.objective == 'pred_x0': + x_start = model_output + x_start = maybe_clip(x_start) + pred_noise = self.predict_noise_from_start(x, t, x_start) + + elif self.objective == 'pred_v': + v = model_output + x_start = self.predict_start_from_v(x, t, v) + x_start = maybe_clip(x_start) + pred_noise = self.predict_noise_from_start(x, t, x_start) + + return ModelPrediction(pred_noise, x_start) + + #def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): + def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = True): + + #preds = self.model_predictions(x, t, x_self_cond) + preds = self.model_predictions(x, t, context, x_self_cond) + x_start = preds.pred_x_start + + if clip_denoised: + x_start.clamp_(-1., 1.) 
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) + return model_mean, posterior_variance, posterior_log_variance, x_start + + @torch.no_grad() + #def p_sample(self, x, t: int, x_self_cond = None): + def p_sample(self, x, t: int, context, x_self_cond = None): + + b, *_, device = *x.shape, x.device + batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) + #model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True) + model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, context = context, x_self_cond = x_self_cond, clip_denoised = True) + noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0 + pred_img = model_mean + (0.5 * model_log_variance).exp() * noise + return pred_img, x_start + + @torch.no_grad() + #def p_sample_loop(self, shape, return_all_timesteps = False): + def p_sample_loop(self, shape, context, return_all_timesteps = False): + + batch, device = shape[0], self.betas.device + + img = torch.randn(shape, device = device) + imgs = [img] + + x_start = None + + for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): + self_cond = x_start if self.self_condition else None + #img, x_start = self.p_sample(img, t, self_cond) + img, x_start = self.p_sample(img, t, context, self_cond) + imgs.append(img) + + ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) + + #ret = self.unnormalize(ret) + return ret + + @torch.no_grad() + #def ddim_sample(self, shape, return_all_timesteps = False): + def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): + print('here!!!') + batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective + + times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps + times = list(reversed(times.int().tolist())) + time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] + + img = torch.randn(shape, device = device) + imgs = [img] + + x_start = None + + for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): + time_cond = torch.full((batch,), time, device = device, dtype = torch.long) + self_cond = x_start if self.self_condition else None + #pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True) + pred_noise, x_start, *_ = self.model_predictions(img, time_cond, l_cond, context, self_cond, clip_x_start = True) + + imgs.append(img) + + if time_next < 0: + img = x_start + continue + + alpha = self.alphas_cumprod[time] + alpha_next = self.alphas_cumprod[time_next] + + sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() + c = (1 - alpha_next - sigma ** 2).sqrt() + + noise = torch.randn_like(img) + + img = x_start * alpha_next.sqrt() + \ + c * pred_noise + \ + sigma * noise + + imgs.append(img) + ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) + + #ret = self.unnormalize(ret) + return ret + + @torch.no_grad() + def sample(self, lres, return_all_timesteps = False): + + b, f, c, h, w = lres.shape + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + ures = rearrange(ures, '(b t) c h w -> b t 
c h w', b = b) + + lres = self.normalize(lres) + ures = self.normalize(ures) + sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample + + l = ures.clone() + r = torch.roll(l, -1, 1) + ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') + l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') + + m = lres.clone() + m1 = rearrange(m, 'b t c h w -> (b t) c h w') + m1 = self.upsample(m1) + m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) + m1 = torch.roll(m1, -2, 1) + #m1 = torch.roll(l, -2, 1) + + stack = torch.cat((l, r, m1), 2) + stack = stack[:, :(f-2), :, :, :] + stack = rearrange(stack, 'b t c h w -> (b t) c h w') + + flow, context = self.flow(stack) + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) + + res = sample_fn((b*(f-2),c,8*h,8*w), l_cond, context, return_all_timesteps = return_all_timesteps) + sres = warped + res + sres = rearrange(sres, '(b t) c h w -> b t c h w', b = b) + + warped = rearrange(warped, '(b t) c h w -> b t c h w', b = b) + res = rearrange(res, '(b t) c h w -> b t c h w', b = b) + flow = rearrange(flow, '(b t) c h w -> b t c h w', t = f-2) + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), flow + + @torch.no_grad() + def interpolate(self, x1, x2, t = None, lam = 0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device = device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + @property + def loss_fn(self): + if self.loss_type == 'l1': + return F.l1_loss + elif self.loss_type == 'l2': + return F.mse_loss + else: + raise ValueError(f'invalid loss type {self.loss_type}') + + def p_losses(self, stack, hres, lres, ures, t, noise = None): + + b, f, c, h, w = hres.shape + + stack = rearrange(stack, 'b t c h w -> (b t) c h w') + ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') + + flow, context = self.flow(stack) + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) + + x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') + x_start = x_start - warped + + l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') + + b, c, h, w = x_start.shape + + del f + + noise = default(noise, lambda: torch.randn_like(x_start)) + + # noise sample + + x = self.q_sample(x_start = x_start, t = t, noise = noise) + + # if doing self-conditioning, 50% of the time, predict x_start from current set of times + # and condition with unet with that + # this technique will slow down training by 25%, but seems to lower FID significantly + + x_self_cond = None + if self.self_condition and random() < 0.5: + with torch.no_grad(): + x_self_cond = self.model_predictions(x, 
t).pred_x_start + x_self_cond.detach_() + + # predict and take gradient step + + model_out = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) + + if self.objective == 'pred_noise': + target = noise + elif self.objective == 'pred_x0': + target = x_start + elif self.objective == 'pred_v': + v = self.predict_v(x_start, t, noise) + target = v + else: + raise ValueError(f'unknown objective {self.objective}') + + loss = self.loss_fn(model_out, target, reduction = 'none') + loss = reduce(loss, 'b ... -> b (...)', 'mean') + + loss = loss * extract(self.p2_loss_weight, t, loss.shape) + + loss1 = self.loss_fn(ures, hres, reduction = 'none') + loss1 = reduce(loss1, 'b ... -> b (...)', 'mean') + + loss2 = self.loss_fn(x_start, warped, reduction = 'none') + loss2 = reduce(loss2, 'b ... -> b (...)', 'mean') + + return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*0.3 + + def forward(self, lres, hres, *args, **kwargs): + + b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size + + assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + + t = torch.randint(0, self.num_timesteps, (b*(f-2),), device=device).long() + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) + + lres = self.normalize(lres) + hres = self.normalize(hres) + ures = self.normalize(ures) + + l = ures.clone() + r = torch.roll(l, -1, 1) + + m = lres.clone() + m1 = rearrange(m, 'b t c h w -> (b t) c h w') + m1 = self.upsample(m1) + m1 = rearrange(m1, '(b t) c h w -> b t c h w', b = b) + m1 = torch.roll(m1, -2, 1) + #m1 = torch.roll(l, -2, 1) + + stack = torch.cat((l, r, m1), 2) + stack = stack[:, :(f-2), :, :, :] + + return self.p_losses(stack, hres, lres, ures, t, *args, **kwargs) + +# trainer class + +class Trainer(object): + def __init__( + self, + diffusion_model, + train_dl, + val_dl, + config, + *, + train_batch_size = 16, + gradient_accumulate_every = 1, + #augment_horizontal_flip = True, + train_lr = 1e-4, + train_num_steps = 100000, + ema_update_every = 1, + ema_decay = 0.995, + adam_betas = (0.9, 0.99), + save_and_sample_every = 1, + #num_samples = 25, + eval_folder = './evaluate', + results_folder = './results', + #tensorboard_dir = './tensorboard', + val_num_of_batch = 2, + amp = False, + fp16 = False, + #fp16 = True, + split_batches = True, + #split_batches = False, + convert_image_to = None + ): + super().__init__() + + self.accelerator = Accelerator( + split_batches = split_batches, + mixed_precision = 'fp16' if fp16 else 'no', + log_with = 'wandb', + ) + self.accelerator.init_trackers("vsr-orig-autoreg-hres", + init_kwargs={ + "wandb": { + "notes": "Use VSR to improve precipitation forecasting.", + # Change "name" to set the name of the run. 
+ "name": None, + } + }, + ) + self.config = config + self.accelerator.native_amp = amp + + self.model = diffusion_model + + self.save_and_sample_every = save_and_sample_every + + self.batch_size = train_batch_size + self.gradient_accumulate_every = gradient_accumulate_every + + self.train_num_steps = train_num_steps + self.image_size = diffusion_model.image_size + + self.val_num_of_batch = val_num_of_batch + + # optimizer + + self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) + + # for logging results in a folder periodically + + if self.accelerator.is_main_process: + self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) + + self.results_folder = Path(results_folder) + + self.results_folder.mkdir(exist_ok=True, parents=True) + + self.eval_folder = eval_folder + + # step counter state + + self.step = 0 + + # prepare model, dataloader, optimizer with accelerator + + self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) + self.train_dl = cycle(train_dl) + self.val_dl = cycle(val_dl) + + def save(self, milestone): + if not self.accelerator.is_local_main_process: + return + + data = { + 'step': self.step, + 'model': self.accelerator.get_state_dict(self.model), + 'opt': self.opt.state_dict(), + 'ema': self.ema.state_dict(), + 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, + #'version': __version__ + } + + torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) + + def load(self, milestone): + accelerator = self.accelerator + device = accelerator.device + + data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) + + model = self.accelerator.unwrap_model(self.model) + model.load_state_dict(data['model']) + + self.step = data['step'] + #self.opt.load_state_dict(data['opt']) + self.ema.load_state_dict(data['ema']) + + #if 'version' in data: + # print(f"loading from version {data['version']}") + + if exists(self.accelerator.scaler) and exists(data['scaler']): + self.accelerator.scaler.load_state_dict(data['scaler']) + + def train(self): + + accelerator = self.accelerator + device = accelerator.device + + cmap = mpl.colormaps['RdBu_r'] + + c384_min = np.load('data/only_precip/c384_min.npy') + c384_max = np.load('data/only_precip/c384_max.npy') + c384_logmin = np.load('data/only_precip/c384_logmin.npy') + + c48_min = np.load('data/only_precip/c48_min.npy') + c48_max = np.load('data/only_precip/c48_max.npy') + c48_logmin = np.load('data/only_precip/c48_logmin.npy') + + with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: + + while self.step < self.train_num_steps: + + total_loss = 0. 
+ + for _ in range(self.gradient_accumulate_every): + + #data = next(self.dl).to(device) + data = next(self.train_dl) + lres = data['LR'].to(device) + hres = data['HR'].to(device) + + with self.accelerator.autocast(): + + #loss = self.model(data) + loss = self.model(lres, hres) + loss = loss / self.gradient_accumulate_every + total_loss += loss.item() + + self.accelerator.backward(loss) + + accelerator.clip_grad_norm_(self.model.parameters(), 1.0) + pbar.set_description(f'loss: {total_loss:.4f}') + + #self.writer.add_scalar("loss", total_loss, self.step) + accelerator.log({"loss": total_loss}, step = self.step) + + accelerator.wait_for_everyone() + + self.opt.step() + self.opt.zero_grad() + + accelerator.wait_for_everyone() + + self.step += 1 + if accelerator.is_main_process: + self.ema.to(device) + self.ema.update() + + if self.step != 0 and self.step % self.save_and_sample_every == 0: + self.ema.ema_model.eval() + + with torch.no_grad(): + + for i, batch in enumerate(self.val_dl): + + lres = batch['LR'].to(device) + hres = batch['HR'].to(device) + + if i >= self.val_num_of_batch: + break + + num_samples = 5 + num_videos_per_batch = 1 + num_frames = 5 + img_size = 384 + img_channels = 1 + + truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + truth[0,:,:,:,:,:] = (hres[:,2:,:,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) + for k in range(num_samples): + videos, base, res, flows = self.ema.ema_model.sample(lres) + pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) + + crps_index = calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size) + psnr_index = piq.psnr(hres[:,2:,0:1,:,:], videos.clamp(0.0, 1.0)[:,:,0:1,:,:], data_range=1., reduction='none') + + videos_time_mean = videos.mean(dim = 1) + hres_time_mean = hres[:,2:,:,:,:].mean(dim = 1) + bias = videos_time_mean - hres_time_mean + norm = mpl.colors.Normalize(vmin = bias.min(), vmax = bias.max()) + sm = smap(norm, cmap) + b_c = [] + for l in range(num_videos_per_batch): + b_c.append(sm.to_rgba(bias[l,0,:,:].cpu().numpy())) + bias_color = np.stack(b_c, axis = 0) + + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + target = np.exp(target) + c384_logmin - 1e-14 + + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = np.exp(output) + c384_logmin - 1e-14 + + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + coarse = np.exp(coarse) + c48_logmin - 1e-14 + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = 
np.linspace(vmin1, vmax1, 100 + 1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + ) + ax1.hist( + target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) + for m in range(num_frames): + flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) + + accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pattern_bias": wandb.Image((bias_color*255).astype(np.uint8), mode = 'RGBA')}, step=self.step) + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"psnr": psnr_index.mean()}, step=self.step) + accelerator.log({"crps": crps_index}, step=self.step) + + milestone = self.step // self.save_and_sample_every + + self.save(milestone) + + pbar.update(1) + + accelerator.print('training complete') + + def sample(self): + + accelerator = self.accelerator + device = accelerator.device + + self.ema.ema_model.eval() + + cmap = mpl.colormaps['viridis'] + sm = smap(None, cmap) + + with torch.no_grad(): + + for k, batch in enumerate(self.val_dl): + + lres = batch['LR'].to(device) + hres = batch['HR'].to(device) + + if k >= self.val_num_of_batch: + break + + limit = lres.shape[1] + if limit < 8: + + #videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres, True) + videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen.pt") + torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr.pt") + torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr.pt") + + for i, b in enumerate(videos.clamp(0, 1)): + if not os.path.isdir(os.path.join(self.eval_folder, "generated")): + os.makedirs(os.path.join(self.eval_folder, "generated")) + Parallel(n_jobs=4)( + delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") + for j, f in enumerate(b.cpu()) + ) + + #videos = torch.log(videos.clamp(0.0, 1.0) + 1) + #hres = torch.log(hres + 1) + + #for i, b in enumerate(videos.clamp(0, 1)): + # for i, b in enumerate(videos): + # if not os.path.isdir(os.path.join(self.eval_folder, "generated")): + # os.makedirs(os.path.join(self.eval_folder, "generated")) + # Parallel(n_jobs=4)( + # delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") + # for j, f in 
enumerate(b.cpu()) + # ) + +# for i, b in enumerate(nsteps.clamp(0, 1)): +# #for i, b in enumerate(sampled): +# if not os.path.isdir(os.path.join(self.eval_folder, "residual")): +# os.makedirs(os.path.join(self.eval_folder, "residual")) +# Parallel(n_jobs=4)( +# delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") +# for j, f in enumerate(b.cpu()) +# ) + +# for i, b in enumerate(base.clamp(0, 1)): +# #for i, b in enumerate(sampled): +# if not os.path.isdir(os.path.join(self.eval_folder, "warped")): +# os.makedirs(os.path.join(self.eval_folder, "warped")) +# Parallel(n_jobs=4)( +# delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") +# for j, f in enumerate(b.cpu()) +# ) + +# for i, b in enumerate(flows.clamp(0, 1)): +# #for i, b in enumerate(sampled): +# if not os.path.isdir(os.path.join(self.eval_folder, "flows")): +# os.makedirs(os.path.join(self.eval_folder, "flows")) +# Parallel(n_jobs=4)( +# delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") +# for j, f in enumerate(b.cpu()) +# ) + + for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): + if not os.path.isdir(os.path.join(self.eval_folder, "truth")): + os.makedirs(os.path.join(self.eval_folder, "truth")) + Parallel(n_jobs=4)( + delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") + for j, f in enumerate(b.cpu()) + ) + +# else: + +# videos, base, nsteps, flows = self.ema.ema_model.sample(lres[:,:7,:,:], hres[:,:7,:,:], True) + +# st = 5 +# ed = st + 7 + +# while ed < limit: + +# vi, ba, ns, fl = self.ema.ema_model.sample(lres[:,st:ed,:,:], hres[:,st:ed,:,:], True) +# st += 5 +# ed += 5 +# videos = torch.cat((videos, vi), 1) +# #base = torch.cat((base, ba), 1) +# #nsteps = torch.cat((nsteps, ns), 1) +# #flows = torch.cat((flows, fl), 1) + +# for i, b in enumerate(videos.clamp(0, 1)): +# #for i, b in enumerate(sampled): +# if not os.path.isdir(os.path.join(self.eval_folder, "generated")): +# os.makedirs(os.path.join(self.eval_folder, "generated")) +# Parallel(n_jobs=4)( +# delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") +# for j, f in enumerate(b.cpu()) +# ) + +# for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): +# #for i, b in enumerate(sampled): +# if not os.path.isdir(os.path.join(self.eval_folder, "truth")): +# os.makedirs(os.path.join(self.eval_folder, "truth")) +# Parallel(n_jobs=4)( +# delayed(save_image)(f, os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") +# for j, f in enumerate(b.cpu()) +# ) + +# # for i, b in enumerate(nsteps.clamp(0, 1)): +# # #for i, b in enumerate(sampled): +# # if not os.path.isdir(os.path.join(self.eval_folder, "residual")): +# # os.makedirs(os.path.join(self.eval_folder, "residual")) +# # Parallel(n_jobs=4)( +# # delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") +# # for j, f in enumerate(b.cpu()) +# # ) + +# # for i, b in enumerate(base.clamp(0, 1)): +# # #for i, b in enumerate(sampled): +# # if not os.path.isdir(os.path.join(self.eval_folder, "warped")): +# # os.makedirs(os.path.join(self.eval_folder, "warped")) +# # Parallel(n_jobs=4)( +# # delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") +# # for j, f in enumerate(b.cpu()) +# # ) + +# # for i, b in enumerate(flows.clamp(0, 1)): +# # #for i, b in enumerate(sampled): +# # if not os.path.isdir(os.path.join(self.eval_folder, "flows")): +# # 
os.makedirs(os.path.join(self.eval_folder, "flows")) +# # Parallel(n_jobs=4)( +# # delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") +# # for j, f in enumerate(b.cpu()) +# # ) + +# for i, b in enumerate(flows.clamp(0, 1)): +# #for i, b in enumerate(sampled): +# if not os.path.isdir(os.path.join(self.eval_folder, "flows_d")): +# os.makedirs(os.path.join(self.eval_folder, "flows_d")) +# Parallel(n_jobs=4)( +# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_d") + f"/{k}-{i}-{j}.png", flow_vis.flow_to_color(f.permute(1,2,0).cpu().numpy()[:,:,:2], convert_to_bgr = False)) +# for j, f in enumerate(b.cpu()) +# ) +# for i, b in enumerate(flows.clamp(0, 1)): +# #for i, b in enumerate(sampled): +# if not os.path.isdir(os.path.join(self.eval_folder, "flows_s")): +# os.makedirs(os.path.join(self.eval_folder, "flows_s")) +# Parallel(n_jobs=4)( +# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_s") + f"/{k}-{i}-{j}.png", f.permute(1,2,0).cpu().numpy()[:,:,2], cmap = 'gray_r') +# for j, f in enumerate(b.cpu()) +# ) \ No newline at end of file diff --git a/projects/super_res/trainer_mod_flow.py b/projects/super_res/trainer_mod_flow.py new file mode 100755 index 0000000000..ec1971e97b --- /dev/null +++ b/projects/super_res/trainer_mod_flow.py @@ -0,0 +1,66 @@ +import os + +from model.autoreg_diffusion_mod_flow import Unet, Flow, GaussianDiffusion, Trainer +from data.load_data import load_data +from config_mod_flow import config + +def main(): + model = Unet( + dim = config.dim, + channels = 2 * config.data_config["img_channel"], + out_dim = config.data_config["img_channel"], + dim_mults = config.dim_mults, + learned_sinusoidal_cond = config.learned_sinusoidal_cond, + random_fourier_features = config.random_fourier_features, + learned_sinusoidal_dim = config.learned_sinusoidal_dim + ).cuda() + + flow = Flow( + dim = config.dim, + channels = 3 * config.data_config["img_channel"], + out_dim = 2, + dim_mults = config.dim_mults + ).cuda() + + diffusion = GaussianDiffusion( + model, + flow, + image_size = config.data_config["img_size"], + timesteps = config.diffusion_steps, + sampling_timesteps = config.sampling_steps, + loss_type = config.loss, + objective = config.objective + ).cuda() + + train_dl, val_dl = load_data( + config.data_config, + config.batch_size, + pin_memory = True, + num_workers = 4, + ) + + trainer = Trainer( + diffusion, + train_dl, + val_dl, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config + #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), + ) + + trainer.train() + + +if __name__ == "__main__": + print(config) + main() From 57dad55871d5af344aed9c150a9b0c3ab6c9150d Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Wed, 2 Aug 2023 18:35:27 +0000 Subject: [PATCH 2/9] rollout code --- projects/super_res/config_infer.py | 21 +- .../super_res/model/autoreg_diffusion_mod.py | 181 +++--------------- projects/super_res/sampler.py | 13 +- 3 files changed, 41 insertions(+), 174 deletions(-) diff --git a/projects/super_res/config_infer.py 
b/projects/super_res/config_infer.py index 109103ceec..ecc5c3c22d 100644 --- a/projects/super_res/config_infer.py +++ b/projects/super_res/config_infer.py @@ -1,21 +1,21 @@ from ml_collections import config_dict +#batch_size = 4 config = config_dict.ConfigDict() - config.dim = 64 -config.dim_mults = (1, 1, 2, 2, 4, 4) +config.dim_mults = (1, 1, 2, 2, 3, 4) config.learned_sinusoidal_cond = True, config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 config.diffusion_steps = 1500 -config.sampling_steps = 6 -config.loss = "l1" +config.sampling_steps = 15 +config.loss = "l2" config.objective = "pred_v" config.lr = 8e-5 config.steps = 5000000 -config.grad_acc = 2 -config.val_num_of_batch = 5 +config.grad_acc = 1 +config.val_num_of_batch = 2 config.save_and_sample_every = 5000 config.ema_decay = 0.995 config.amp = False @@ -26,7 +26,7 @@ config.tensorboard_dir = "./tensorboard" config.milestone = 1 -config.batch_size = 4 +config.batch_size = 1 config.data_config = config_dict.ConfigDict({ "dataset_name": "c384", "length": 7, @@ -35,8 +35,9 @@ #"img_channel": 2, "img_channel": 1, "img_size": 384, - "logscale": True + "logscale": True, + "quick": True }) -data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" -model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/model/autoreg_diffusion_mod.py b/projects/super_res/model/autoreg_diffusion_mod.py index b000fd0eff..54131f1691 100644 --- a/projects/super_res/model/autoreg_diffusion_mod.py +++ b/projects/super_res/model/autoreg_diffusion_mod.py @@ -1031,7 +1031,7 @@ def p_losses(self, stack, hres, lres, ures, t, noise = None): loss2 = self.loss_fn(x_start, warped, reduction = 'none') loss2 = reduce(loss2, 'b ... 
-> b (...)', 'mean') - return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*1.0 + return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*0.3 def forward(self, lres, hres, *args, **kwargs): @@ -1138,7 +1138,9 @@ def __init__( self.results_folder.mkdir(exist_ok=True, parents=True) - self.eval_folder = eval_folder + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) # step counter state @@ -1362,159 +1364,30 @@ def sample(self): self.ema.ema_model.eval() - cmap = mpl.colormaps['viridis'] - sm = smap(None, cmap) + c384_norm= torch.from_numpy(np.load("data/only_precip/c384_norm.npy")) + c48_norm = torch.from_numpy(np.load("data/only_precip/c48_norm.npy")) with torch.no_grad(): - for k, batch in enumerate(self.val_dl): + for tile in range(6): - lres = batch['LR'].to(device) - hres = batch['HR'].to(device) - - if k >= self.val_num_of_batch: - break - - limit = lres.shape[1] - if limit < 8: - - #videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres, True) - videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres) - - torch.save(videos, os.path.join(self.eval_folder) + "/gen.pt") - torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr.pt") - torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr.pt") - - for i, b in enumerate(videos.clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - os.makedirs(os.path.join(self.eval_folder, "generated")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - - #videos = torch.log(videos.clamp(0.0, 1.0) + 1) - #hres = torch.log(hres + 1) - - #for i, b in enumerate(videos.clamp(0, 1)): - # for i, b in enumerate(videos): - # if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - # os.makedirs(os.path.join(self.eval_folder, "generated")) - # Parallel(n_jobs=4)( - # delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - # for j, f in enumerate(b.cpu()) - # ) - -# for i, b in enumerate(nsteps.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# os.makedirs(os.path.join(self.eval_folder, "residual")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(base.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# os.makedirs(os.path.join(self.eval_folder, "warped")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# os.makedirs(os.path.join(self.eval_folder, "flows")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - - for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "truth")): - os.makedirs(os.path.join(self.eval_folder, "truth")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - -# else: 
- -# videos, base, nsteps, flows = self.ema.ema_model.sample(lres[:,:7,:,:], hres[:,:7,:,:], True) - -# st = 5 -# ed = st + 7 - -# while ed < limit: - -# vi, ba, ns, fl = self.ema.ema_model.sample(lres[:,st:ed,:,:], hres[:,st:ed,:,:], True) -# st += 5 -# ed += 5 -# videos = torch.cat((videos, vi), 1) -# #base = torch.cat((base, ba), 1) -# #nsteps = torch.cat((nsteps, ns), 1) -# #flows = torch.cat((flows, fl), 1) - -# for i, b in enumerate(videos.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "generated")): -# os.makedirs(os.path.join(self.eval_folder, "generated")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "truth")): -# os.makedirs(os.path.join(self.eval_folder, "truth")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# # for i, b in enumerate(nsteps.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# # os.makedirs(os.path.join(self.eval_folder, "residual")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(base.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# # os.makedirs(os.path.join(self.eval_folder, "warped")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(flows.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# # os.makedirs(os.path.join(self.eval_folder, "flows")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_d")): -# os.makedirs(os.path.join(self.eval_folder, "flows_d")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_d") + f"/{k}-{i}-{j}.png", flow_vis.flow_to_color(f.permute(1,2,0).cpu().numpy()[:,:,:2], convert_to_bgr = False)) -# for j, f in enumerate(b.cpu()) -# ) -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_s")): -# os.makedirs(os.path.join(self.eval_folder, "flows_s")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_s") + f"/{k}-{i}-{j}.png", f.permute(1,2,0).cpu().numpy()[:,:,2], cmap = 'gray_r') -# for j, f in enumerate(b.cpu()) -# ) \ No newline at end of file + st = 0 + en = 27 + count = 0 + + while en < c48_norm.shape[1]: + + print(tile, st) + + lres = c48_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) + hres = c384_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + 
torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + + st += 25 + en += 25 + count += 1 \ No newline at end of file diff --git a/projects/super_res/sampler.py b/projects/super_res/sampler.py index 874a9e4fcb..08ecb72a63 100644 --- a/projects/super_res/sampler.py +++ b/projects/super_res/sampler.py @@ -1,6 +1,6 @@ import os -from model.autoreg_diffusion import Unet, Flow, GaussianDiffusion, Trainer +from model.autoreg_diffusion_mod import Unet, Flow, GaussianDiffusion, Trainer from data.load_data import load_data from config_infer import config @@ -31,17 +31,10 @@ objective = config.objective ).cuda() -train_dl, val_dl = load_data( - config.data_config, - config.batch_size, - pin_memory = True, - num_workers = 4, - ) - trainer = Trainer( diffusion, - train_dl, - val_dl, + None, + None, train_batch_size = config.batch_size, train_lr = config.lr, train_num_steps = config.steps, From 0e8c6af4e89ec2a352062ceddb485e6259395d4f Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Mon, 7 Aug 2023 06:28:22 +0000 Subject: [PATCH 3/9] logscale support for quick load, appropriate changes in val loop --- projects/super_res/config.py | 4 +- projects/super_res/config_mod_flow.py | 4 +- projects/super_res/data/vsrdata.py | 9 +- .../super_res/model/autoreg_diffusion_mod.py | 67 +++-- .../model/autoreg_diffusion_mod_flow.py | 239 +++++------------- 5 files changed, 128 insertions(+), 195 deletions(-) diff --git a/projects/super_res/config.py b/projects/super_res/config.py index a70ccced96..fd2719ab82 100644 --- a/projects/super_res/config.py +++ b/projects/super_res/config.py @@ -20,7 +20,7 @@ config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "" +config.additional_note = "no-logscale" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" @@ -35,7 +35,7 @@ #"img_channel": 2, "img_channel": 1, "img_size": 384, - "logscale": True, + "logscale": False, "quick": True }) diff --git a/projects/super_res/config_mod_flow.py b/projects/super_res/config_mod_flow.py index daa24191d1..b5df4624f5 100644 --- a/projects/super_res/config_mod_flow.py +++ b/projects/super_res/config_mod_flow.py @@ -20,7 +20,7 @@ config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "mod_flow" +config.additional_note = "mod_flow_no_logscale" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" @@ -35,7 +35,7 @@ #"img_channel": 2, "img_channel": 1, "img_size": 384, - "logscale": True, + "logscale": False, "quick": True }) diff --git a/projects/super_res/data/vsrdata.py b/projects/super_res/data/vsrdata.py index 1dacbfc851..7944dd2ce3 100644 --- a/projects/super_res/data/vsrdata.py +++ b/projects/super_res/data/vsrdata.py @@ -56,9 +56,12 @@ def __init__(self, channels, mode, length, logscale = False, quick = True): self.y = c384_norm[:, split:, :, :, :] else: - - c384_norm= np.load("data/only_precip/c384_norm.npy") - c48_norm = np.load("data/only_precip/c48_norm.npy") + if logscale: + c384_norm= np.load("data/only_precip/c384_lgnorm.npy") + c48_norm = np.load("data/only_precip/c48_lgnorm.npy") + else: + c384_norm= np.load("data/only_precip/c384_norm.npy") + c48_norm = np.load("data/only_precip/c48_norm.npy") # calculate split (80/20) split = int(c384_norm.shape[1] * 0.8) diff --git 
a/projects/super_res/model/autoreg_diffusion_mod.py b/projects/super_res/model/autoreg_diffusion_mod.py index 54131f1691..e6bb3d905f 100644 --- a/projects/super_res/model/autoreg_diffusion_mod.py +++ b/projects/super_res/model/autoreg_diffusion_mod.py @@ -4,7 +4,6 @@ from random import random from functools import partial from collections import namedtuple -from joblib import Parallel, delayed import numpy as np @@ -1194,13 +1193,20 @@ def train(self): cmap = mpl.colormaps['RdBu_r'] fcmap = mpl.colormaps['gray_r'] + + c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + c48_gmin = np.load('data/only_precip/c48_gmin.npy') + c384_min = np.load('data/only_precip/c384_min.npy') c384_max = np.load('data/only_precip/c384_max.npy') - c384_logmin = np.load('data/only_precip/c384_logmin.npy') c48_min = np.load('data/only_precip/c48_min.npy') c48_max = np.load('data/only_precip/c48_max.npy') - c48_logmin = np.load('data/only_precip/c48_logmin.npy') with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: @@ -1264,6 +1270,7 @@ def train(self): truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') truth[0,:,:,:,:,:] = (hres[:,2:,:,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) + for k in range(num_samples): videos, base, res, flows = self.ema.ema_model.sample(lres) pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) @@ -1281,14 +1288,20 @@ def train(self): b_c.append(sm.to_rgba(bias[l,0,:,:].cpu().numpy())) bias_color = np.stack(b_c, axis = 0) - target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min - target = np.exp(target) + c384_logmin - 1e-14 - - output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min - output = np.exp(output) + c384_logmin - 1e-14 - - coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min - coarse = np.exp(coarse) + c48_logmin - 1e-14 + if not self.config.data_config.logscale: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.config.data_config.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) diff_output = (output - nn_upscale).flatten() @@ -1336,13 +1349,31 @@ def train(self): for m in range(num_frames): flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) 
- accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) + if self.config.data_config.logscale: + accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) + + else: + accelerator.log({"true_high": wandb.Video((hres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"pattern_bias": wandb.Image((bias_color*255).astype(np.uint8), mode = 'RGBA')}, step=self.step) accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) diff --git a/projects/super_res/model/autoreg_diffusion_mod_flow.py b/projects/super_res/model/autoreg_diffusion_mod_flow.py index 3a7e24a663..ff6f2764ef 100644 --- a/projects/super_res/model/autoreg_diffusion_mod_flow.py +++ 
b/projects/super_res/model/autoreg_diffusion_mod_flow.py @@ -4,7 +4,6 @@ from random import random from functools import partial from collections import namedtuple -from joblib import Parallel, delayed import numpy as np @@ -15,7 +14,6 @@ import piq -from kornia import filters from torch.optim import Adam from einops import rearrange, reduce @@ -1169,13 +1167,19 @@ def train(self): cmap = mpl.colormaps['RdBu_r'] + c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + c48_gmin = np.load('data/only_precip/c48_gmin.npy') + c384_min = np.load('data/only_precip/c384_min.npy') c384_max = np.load('data/only_precip/c384_max.npy') - c384_logmin = np.load('data/only_precip/c384_logmin.npy') c48_min = np.load('data/only_precip/c48_min.npy') c48_max = np.load('data/only_precip/c48_max.npy') - c48_logmin = np.load('data/only_precip/c48_logmin.npy') with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: @@ -1239,6 +1243,7 @@ def train(self): truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') truth[0,:,:,:,:,:] = (hres[:,2:,:,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) + for k in range(num_samples): videos, base, res, flows = self.ema.ema_model.sample(lres) pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) @@ -1256,14 +1261,20 @@ def train(self): b_c.append(sm.to_rgba(bias[l,0,:,:].cpu().numpy())) bias_color = np.stack(b_c, axis = 0) - target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min - target = np.exp(target) + c384_logmin - 1e-14 - - output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min - output = np.exp(output) + c384_logmin - 1e-14 - - coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min - coarse = np.exp(coarse) + c48_logmin - 1e-14 + if not self.config.data_config.logscale: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.config.data_config.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) diff_output = (output - nn_upscale).flatten() @@ -1306,12 +1317,29 @@ def train(self): for m in range(num_frames): flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": 
wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + if self.config.data_config.logscale: + accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + + else: + accelerator.log({"true_high": wandb.Video((hres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pattern_bias": wandb.Image((bias_color*255).astype(np.uint8), mode = 'RGBA')}, step=self.step) accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) @@ -1333,159 +1361,30 @@ def sample(self): self.ema.ema_model.eval() - cmap = mpl.colormaps['viridis'] - sm = smap(None, cmap) + c384_norm= torch.from_numpy(np.load("data/only_precip/c384_norm.npy")) + c48_norm = torch.from_numpy(np.load("data/only_precip/c48_norm.npy")) with torch.no_grad(): - for k, batch in enumerate(self.val_dl): + for tile in range(6): - lres = batch['LR'].to(device) - hres = batch['HR'].to(device) - - if k >= self.val_num_of_batch: - break - - limit = 
lres.shape[1] - if limit < 8: - - #videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres, True) - videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres) - - torch.save(videos, os.path.join(self.eval_folder) + "/gen.pt") - torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr.pt") - torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr.pt") - - for i, b in enumerate(videos.clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - os.makedirs(os.path.join(self.eval_folder, "generated")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - - #videos = torch.log(videos.clamp(0.0, 1.0) + 1) - #hres = torch.log(hres + 1) - - #for i, b in enumerate(videos.clamp(0, 1)): - # for i, b in enumerate(videos): - # if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - # os.makedirs(os.path.join(self.eval_folder, "generated")) - # Parallel(n_jobs=4)( - # delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - # for j, f in enumerate(b.cpu()) - # ) - -# for i, b in enumerate(nsteps.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# os.makedirs(os.path.join(self.eval_folder, "residual")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(base.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# os.makedirs(os.path.join(self.eval_folder, "warped")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# os.makedirs(os.path.join(self.eval_folder, "flows")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - - for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "truth")): - os.makedirs(os.path.join(self.eval_folder, "truth")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - -# else: - -# videos, base, nsteps, flows = self.ema.ema_model.sample(lres[:,:7,:,:], hres[:,:7,:,:], True) - -# st = 5 -# ed = st + 7 - -# while ed < limit: - -# vi, ba, ns, fl = self.ema.ema_model.sample(lres[:,st:ed,:,:], hres[:,st:ed,:,:], True) -# st += 5 -# ed += 5 -# videos = torch.cat((videos, vi), 1) -# #base = torch.cat((base, ba), 1) -# #nsteps = torch.cat((nsteps, ns), 1) -# #flows = torch.cat((flows, fl), 1) - -# for i, b in enumerate(videos.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "generated")): -# os.makedirs(os.path.join(self.eval_folder, "generated")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not 
os.path.isdir(os.path.join(self.eval_folder, "truth")): -# os.makedirs(os.path.join(self.eval_folder, "truth")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# # for i, b in enumerate(nsteps.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# # os.makedirs(os.path.join(self.eval_folder, "residual")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(base.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# # os.makedirs(os.path.join(self.eval_folder, "warped")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(flows.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# # os.makedirs(os.path.join(self.eval_folder, "flows")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_d")): -# os.makedirs(os.path.join(self.eval_folder, "flows_d")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_d") + f"/{k}-{i}-{j}.png", flow_vis.flow_to_color(f.permute(1,2,0).cpu().numpy()[:,:,:2], convert_to_bgr = False)) -# for j, f in enumerate(b.cpu()) -# ) -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_s")): -# os.makedirs(os.path.join(self.eval_folder, "flows_s")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_s") + f"/{k}-{i}-{j}.png", f.permute(1,2,0).cpu().numpy()[:,:,2], cmap = 'gray_r') -# for j, f in enumerate(b.cpu()) -# ) \ No newline at end of file + st = 0 + en = 27 + count = 0 + + while en < c48_norm.shape[1]: + + print(tile, st) + + lres = c48_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) + hres = c384_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + + st += 25 + en += 25 + count += 1 \ No newline at end of file From 40989d3671bf4435933f7f0a495675bb25fff9a6 Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Fri, 11 Aug 2023 23:21:23 +0000 Subject: [PATCH 4/9] refactor : 2d, 3d flow in one file, switches for logscale, multichannel, minipatch, partial rollout --- projects/super_res/config.py | 17 +- projects/super_res/config_infer.py | 17 +- projects/super_res/config_mod_flow.py | 43 - projects/super_res/data/load_dataset.py | 7 +- projects/super_res/data/vsrdata.py | 81 +- projects/super_res/model/autoreg_diffusion.py | 1409 ----------------- .../super_res/model/autoreg_diffusion_mod.py | 314 +++- 
.../model/autoreg_diffusion_mod_flow.py | 1390 ---------------- projects/super_res/model/network_swinir.py | 5 +- projects/super_res/sampler.py | 126 +- projects/super_res/trainer.py | 33 +- projects/super_res/trainer_mod_flow.py | 66 - 12 files changed, 393 insertions(+), 3115 deletions(-) delete mode 100644 projects/super_res/config_mod_flow.py delete mode 100644 projects/super_res/model/autoreg_diffusion.py delete mode 100644 projects/super_res/model/autoreg_diffusion_mod_flow.py delete mode 100755 projects/super_res/trainer_mod_flow.py diff --git a/projects/super_res/config.py b/projects/super_res/config.py index fd2719ab82..a3328f2125 100644 --- a/projects/super_res/config.py +++ b/projects/super_res/config.py @@ -1,10 +1,9 @@ from ml_collections import config_dict -#batch_size = 4 config = config_dict.ConfigDict() -config.dim = 64 -config.dim_mults = (1, 1, 2, 2, 3, 4) +config.dim = 128 +config.dim_mults = (1, 2, 2, 2, 4, 4) config.learned_sinusoidal_cond = True, config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 @@ -20,23 +19,25 @@ config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "no-logscale" +config.additional_note = "multichannel_minipatch" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" config.milestone = 1 +config.rollout = None +config.rollout_batch = None config.batch_size = 1 config.data_config = config_dict.ConfigDict({ "dataset_name": "c384", "length": 7, - #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], "channels": ["PRATEsfc_coarse"], - #"img_channel": 2, "img_channel": 1, "img_size": 384, - "logscale": False, - "quick": True + "logscale": True, + "multi": True, + "flow": "2d", + "minipatch": False }) config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" diff --git a/projects/super_res/config_infer.py b/projects/super_res/config_infer.py index ecc5c3c22d..849ddc8a4f 100644 --- a/projects/super_res/config_infer.py +++ b/projects/super_res/config_infer.py @@ -1,15 +1,14 @@ from ml_collections import config_dict -#batch_size = 4 config = config_dict.ConfigDict() -config.dim = 64 -config.dim_mults = (1, 1, 2, 2, 3, 4) +config.dim = 128 +config.dim_mults = (1, 2, 2, 2, 4, 4) config.learned_sinusoidal_cond = True, config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 config.diffusion_steps = 1500 -config.sampling_steps = 15 +config.sampling_steps = 20 config.loss = "l2" config.objective = "pred_v" config.lr = 8e-5 @@ -20,23 +19,25 @@ config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "" +config.additional_note = "multichannel_minipatch" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" config.milestone = 1 +config.rollout = "partial" +config.rollout_batch = 25 config.batch_size = 1 config.data_config = config_dict.ConfigDict({ "dataset_name": "c384", "length": 7, - #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], "channels": ["PRATEsfc_coarse"], - #"img_channel": 2, "img_channel": 1, "img_size": 384, "logscale": True, - "quick": True + "multi": True, + "flow": "2d", + "minipatch": False }) config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" diff 
--git a/projects/super_res/config_mod_flow.py b/projects/super_res/config_mod_flow.py deleted file mode 100644 index b5df4624f5..0000000000 --- a/projects/super_res/config_mod_flow.py +++ /dev/null @@ -1,43 +0,0 @@ -from ml_collections import config_dict - -#batch_size = 4 -config = config_dict.ConfigDict() - -config.dim = 64 -config.dim_mults = (1, 1, 2, 2, 3, 4) -config.learned_sinusoidal_cond = True, -config.random_fourier_features = True, -config.learned_sinusoidal_dim = 32 -config.diffusion_steps = 1500 -config.sampling_steps = 20 -config.loss = "l2" -config.objective = "pred_v" -config.lr = 8e-5 -config.steps = 5000000 -config.grad_acc = 1 -config.val_num_of_batch = 2 -config.save_and_sample_every = 5000 -config.ema_decay = 0.995 -config.amp = False -config.split_batches = True -config.additional_note = "mod_flow_no_logscale" -config.eval_folder = "./evaluate" -config.results_folder = "./results" -config.tensorboard_dir = "./tensorboard" -config.milestone = 1 - -config.batch_size = 1 -config.data_config = config_dict.ConfigDict({ - "dataset_name": "c384", - "length": 7, - #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], - "channels": ["PRATEsfc_coarse"], - #"img_channel": 2, - "img_channel": 1, - "img_size": 384, - "logscale": False, - "quick": True -}) - -config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" -config.model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/data/load_dataset.py b/projects/super_res/data/load_dataset.py index 6678d27bdd..e3246963f9 100644 --- a/projects/super_res/data/load_dataset.py +++ b/projects/super_res/data/load_dataset.py @@ -2,14 +2,13 @@ def load_dataset(data_config): - channels = data_config["channels"] length = data_config["length"] logscale = data_config["logscale"] - quick = data_config["quick"] + multi = data_config["multi"] train, val = None, None - train = VSRDataset(channels, 'train', length, logscale, quick) - val = VSRDataset(channels, 'val', length, logscale, quick) + train = VSRDataset('train', length, logscale, multi) + val = VSRDataset('val', length, logscale, multi) return train, val \ No newline at end of file diff --git a/projects/super_res/data/vsrdata.py b/projects/super_res/data/vsrdata.py index 7944dd2ce3..9a90a4153d 100644 --- a/projects/super_res/data/vsrdata.py +++ b/projects/super_res/data/vsrdata.py @@ -1,10 +1,9 @@ -import xarray as xr import numpy as np from torch.utils.data import Dataset class VSRDataset(Dataset): - def __init__(self, channels, mode, length, logscale = False, quick = True): + def __init__(self, mode, length, logscale = False, multi = False): ''' Args: channels (list): list of channels to use @@ -20,61 +19,43 @@ def __init__(self, channels, mode, length, logscale = False, quick = True): # mode self.mode = mode - if not quick: - # load data from bucket - # shape : (tile, time, y, x) - c384 = xr.open_zarr("gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt_coarse": "x", "grid_yt_coarse": "y"}) - c48 = xr.open_zarr("gs://vcm-ml-intermediate/2021-10-12-PIRE-c48-post-spinup-verification/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) - - # convert to numpy - # shape : (tile, time, channel, y, x) - c384_np = 
np.stack([c384[channel].values for channel in channels], axis = 2) - c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + # data shape : (num_tiles, num_frames, num_channels, height, width) + # num_tiles = 6; num_frames = 2920, num_channels = 1 + if logscale: - if logscale: - c384_np = np.log(c384_np - c384_np.min() + 1e-14) - c48_np = np.log(c48_np - c48_np.min() + 1e-14) + c384_norm= np.load("data/only_precip/c384_lgnorm.npy") + c48_norm = np.load("data/only_precip/c48_lgnorm.npy") - # calculate split (80/20) - split = int(c384_np.shape[1] * 0.8) + else: - # compute statistics on training set - c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + c384_norm= np.load("data/only_precip/c384_norm.npy") + c48_norm = np.load("data/only_precip/c48_norm.npy") + + t, f, c, h, w = c384_norm.shape - # normalize - c384_norm= (c384_np - c384_min) / (c384_max - c384_min) - c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + if multi: - if mode == 'train': - - self.X = c48_norm[:, :split, :, :, :] - self.y = c384_norm[:, :split, :, :, :] - - elif mode == 'val': - - self.X = c48_norm[:, split:, :, :, :] - self.y = c384_norm[:, split:, :, :, :] + # load more channels, order : ("UGRD10m_coarse", "VGRD10m_coarse", "tsfc_coarse", "CPRATEsfc_coarse") + c48_norm_more = np.load("data/more_channels/c48_norm.npy") + c48_norm = np.concatenate((c48_norm, c48_norm_more), axis = 2) - else: - if logscale: - c384_norm= np.load("data/only_precip/c384_lgnorm.npy") - c48_norm = np.load("data/only_precip/c48_lgnorm.npy") - else: - c384_norm= np.load("data/only_precip/c384_norm.npy") - c48_norm = np.load("data/only_precip/c48_norm.npy") + # load topography, shape : (num_tiles, height, width) + # reshaping to match data shape + topo384 = np.repeat(np.load("data/topography/topo384_norm.npy").reshape((t, 1, c, 384, 384)), f, axis = 1) + c384_norm = np.concatenate((c384_norm, topo384), axis = 2) - # calculate split (80/20) - split = int(c384_norm.shape[1] * 0.8) + # calculate split (80/20) + split = int(c384_norm.shape[1] * 0.8) - if mode == 'train': - - self.X = c48_norm[:, :split, :, :, :] - self.y = c384_norm[:, :split, :, :, :] - - elif mode == 'val': - - self.X = c48_norm[:, split:, :, :, :] - self.y = c384_norm[:, split:, :, :, :] + if mode == 'train': + + self.X = c48_norm[:, :split, :, :, :] + self.y = c384_norm[:, :split, :, :, :] + + elif mode == 'val': + + self.X = c48_norm[:, split:, :, :, :] + self.y = c384_norm[:, split:, :, :, :] def __len__(self): @@ -83,13 +64,13 @@ def __len__(self): def __getitem__(self, idx): # load a random tile index - if self.mode == 'train': tile = np.random.randint(0, self.X.shape[0]) elif self.mode == 'val': tile = 0 + # tensor shape : (length, num_channels, height, width) lowres = self.X[tile, idx:idx+self.length, :, :, :] highres = self.y[tile, idx:idx+self.length, :, :, :] diff --git a/projects/super_res/model/autoreg_diffusion.py b/projects/super_res/model/autoreg_diffusion.py deleted file mode 100644 index 3d8eabc370..0000000000 --- a/projects/super_res/model/autoreg_diffusion.py +++ /dev/null @@ -1,1409 +0,0 @@ -import os -import math -from pathlib import Path -from random import random -from functools import partial -from collections import namedtuple -from joblib import Parallel, delayed - -import numpy as np - -import torch -from torch import nn, einsum -import torch.nn.functional as F -import wandb - -import piq - -from kornia 
import filters -from torch.optim import Adam - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange - -from PIL import Image - -import matplotlib as mpl -from matplotlib.cm import ScalarMappable as smap - -from tqdm.auto import tqdm -from ema_pytorch import EMA - -import flow_vis - -from accelerate import Accelerator - -# constants - -ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) - -# helpers functions - -def save_image(tensor, path): - im = Image.fromarray((tensor[:,:,:3] * 255).astype(np.uint8)) - im.save(path) - return None - -def exists(x): - return x is not None - -def default(val, d): - if exists(val): - return val - return d() if callable(d) else d - -def identity(t, *args, **kwargs): - return t - -def cycle(dl): - while True: - for data in dl: - yield data - -def has_int_squareroot(num): - return (math.sqrt(num) ** 2) == num - -def num_to_groups(num, divisor): - groups = num // divisor - remainder = num % divisor - arr = [divisor] * groups - if remainder > 0: - arr.append(remainder) - return arr - -def convert_image_to_fn(img_type, image): - if image.mode != img_type: - return image.convert(img_type) - return image - -# normalization functions - -def normalize_to_neg_one_to_one(img): - return img * 2 - 1 - -def unnormalize_to_zero_to_one(t): - return (t + 1) * 0.5 - -# ssf modules - -def gaussian_pyramids(input, base_sigma = 1, m = 5): - - output = [input] - N, C, H, W = input.shape - kernel = filters.get_gaussian_kernel2d((5, 5), (base_sigma, base_sigma)) - - for i in range(m): - - input = filters.filter2d(input, kernel) - - if i == 0: - - output.append(input) - - else: - - tmp = input - - for j in range(i): - - tmp = F.interpolate(tmp, scale_factor = 2., mode = 'bilinear', align_corners = True) - - output.append(tmp) - - input = F.interpolate(input, scale_factor = 0.5) - - return torch.stack(output, 2) - -def scale_space_warp(input, flow): - - N, C, H, W = input.shape - - assert flow.shape == (N, 3, H, W) - - flow = flow.unsqueeze(0) - #multi_scale = gaussian_pyramids(input, self.base_scale, self.gaussian_dim) - multi_scale = gaussian_pyramids(input, 1.0, 5) - - h = torch.arange(H, device=input.device, dtype=input.dtype) - w = torch.arange(W, device=input.device, dtype=input.dtype) - d = torch.zeros(1, device=input.device, dtype=input.dtype) - - grid = torch.stack(torch.meshgrid(d, h, w)[::-1], -1).unsqueeze(0) - grid = grid.expand(N, -1, -1, -1, -1) - flow = flow.permute(1, 0, 3, 4, 2) # N, 1, H, W, 3 - - # reparameterization - # var_channel = (flow[..., -1].exp())**2 - # var_space = [0.] 
+ [(2.**i * self.base_scale)**2 for i in range(self.gaussian_dim)] - # d_offset = var_to_position(var_channel, var_space).unsqueeze(-1) - d_offset = flow[..., -1].clamp(min=-1.0, max=1.0).unsqueeze(-1) - - flow = torch.cat((flow[..., :2], d_offset), -1) - flow_grid = flow + grid - flow_grid[..., 0] = 2.0 * flow_grid[..., 0] / max(W - 1.0, 1.0) - 1.0 - flow_grid[..., 1] = 2.0 * flow_grid[..., 1] / max(H - 1.0, 1.0) - 1.0 - - warped = F.grid_sample(multi_scale, flow_grid, padding_mode = "border", align_corners = True).squeeze(2) - - return warped - -# small helper modules - -class Residual(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - - def forward(self, x, *args, **kwargs): - return self.fn(x, *args, **kwargs) + x - -def Upsample(dim, dim_out = None): - return nn.Sequential( - nn.Upsample(scale_factor = 2, mode = 'nearest'), - nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) - ) - -def Downsample(dim, dim_out = None): - return nn.Sequential( - Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), - nn.Conv2d(dim * 4, default(dim_out, dim), 1) - ) - -class WeightStandardizedConv2d(nn.Conv2d): - """ - https://arxiv.org/abs/1903.10520 - weight standardization purportedly works synergistically with group normalization - """ - def forward(self, x): - eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - - weight = self.weight - mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') - var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False)) - normalized_weight = (weight - mean) * (var + eps).rsqrt() - - return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - -class LayerNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - - def forward(self, x): - eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - var = torch.var(x, dim = 1, unbiased = False, keepdim = True) - mean = torch.mean(x, dim = 1, keepdim = True) - return (x - mean) * (var + eps).rsqrt() * self.g - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.fn = fn - self.norm = LayerNorm(dim) - - def forward(self, x): - x = self.norm(x) - return self.fn(x) - -# sinusoidal positional embeds - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - -class RandomOrLearnedSinusoidalPosEmb(nn.Module): - """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ - """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ - - def __init__(self, dim, is_random = False): - super().__init__() - assert (dim % 2) == 0 - half_dim = dim // 2 - self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) - - def forward(self, x): - x = rearrange(x, 'b -> b 1') - freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) - fouriered = torch.cat((x, fouriered), dim = -1) - return fouriered - -# building block modules - -class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): - super().__init__() - self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) - 
self.norm = nn.GroupNorm(groups, dim_out) - self.act = nn.SiLU() - - def forward(self, x, scale_shift = None): - x = self.proj(x) - x = self.norm(x) - - if exists(scale_shift): - scale, shift = scale_shift - x = x * (scale + 1) + shift - - x = self.act(x) - return x - -class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): - super().__init__() - self.mlp = nn.Sequential( - nn.SiLU(), - nn.Linear(time_emb_dim, dim_out * 2) - ) if exists(time_emb_dim) else None - - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) - self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - - def forward(self, x, time_emb = None): - - scale_shift = None - if exists(self.mlp) and exists(time_emb): - time_emb = self.mlp(time_emb) - time_emb = rearrange(time_emb, 'b c -> b c 1 1') - scale_shift = time_emb.chunk(2, dim = 1) - - h = self.block1(x, scale_shift = scale_shift) - - h = self.block2(h) - - return h + self.res_conv(x) - -class LinearAttention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32): - super().__init__() - self.scale = dim_head ** -0.5 - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - - self.to_out = nn.Sequential( - nn.Conv2d(hidden_dim, dim, 1), - LayerNorm(dim) - ) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) - - q = q.softmax(dim = -2) - k = k.softmax(dim = -1) - - q = q * self.scale - v = v / (h * w) - - context = torch.einsum('b h d n, b h e n -> b h d e', k, v) - - out = torch.einsum('b h d e, b h d n -> b h e n', context, q) - out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) - return self.to_out(out) - -class Attention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32): - super().__init__() - self.scale = dim_head ** -0.5 - self.heads = heads - hidden_dim = dim_head * heads - - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) - - q = q * self.scale - - sim = einsum('b h d i, b h d j -> b h i j', q, k) - attn = sim.softmax(dim = -1) - out = einsum('b h i j, b h d j -> b h i d', attn, v) - - out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) - return self.to_out(out) - -# model - -class Unet(nn.Module): - def __init__( - self, - dim, - init_dim = None, - out_dim = None, - dim_mults=(1, 2, 4, 8), - channels = 3, - self_condition = False, - resnet_block_groups = 8, - learned_variance = False, - learned_sinusoidal_cond = False, - random_fourier_features = False, - learned_sinusoidal_dim = 16 - ): - super().__init__() - - # determine dimensions - - self.channels = channels - self.self_condition = self_condition - input_channels = channels * (2 if self_condition else 1) - - init_dim = default(init_dim, dim) - self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) - - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - - # time embeddings - - time_dim = dim * 4 - - self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or 
random_fourier_features - - if self.random_or_learned_sinusoidal_cond: - sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) - fourier_dim = learned_sinusoidal_dim + 1 - else: - sinu_pos_emb = SinusoidalPosEmb(dim) - fourier_dim = dim - - self.time_mlp = nn.Sequential( - sinu_pos_emb, - nn.Linear(fourier_dim, time_dim), - nn.GELU(), - nn.Linear(time_dim, time_dim) - ) - - # layers - - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) - num_resolutions = len(in_out) - - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - - self.downs.append(nn.ModuleList([ - block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), - block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) - ])) - - mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - - for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): - is_last = ind == (len(in_out) - 1) - - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) - ])) - - default_out_dim = channels * (1 if not learned_variance else 2) - self.out_dim = default(out_dim, default_out_dim) - - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) - self.final_conv = nn.Conv2d(dim, self.out_dim, 1) - - def forward(self, x, time, context, x_self_cond = None): - if self.self_condition: - x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) - x = torch.cat((x_self_cond, x), dim = 1) - - x = self.init_conv(x) - r = x.clone() - - t = self.time_mlp(time) - - h = [] - - count = 0 - - for block1, block2, attn, downsample in self.downs: - x = torch.cat((x, context[count]), dim = 1) - count += 1 - x = block1(x, t) - h.append(x) - - x = torch.cat((x, context[count]), dim = 1) - count += 1 - x = block2(x, t) - x = attn(x) - h.append(x) - - x = downsample(x) - - x = self.mid_block1(x, t) - x = self.mid_attn(x) - x = self.mid_block2(x, t) - - for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) - x = block1(x, t) - - x = torch.cat((x, h.pop()), dim = 1) - x = block2(x, t) - x = attn(x) - - x = upsample(x) - - x = torch.cat((x, r), dim = 1) - - x = self.final_res_block(x, t) - return self.final_conv(x) - -class Flow(nn.Module): - def __init__( - self, - dim, - init_dim = None, - out_dim = None, - dim_mults=(1, 2, 4, 8), - channels = 3, - resnet_block_groups = 8, - ): - super().__init__() - - # determine dimensions - - self.channels = channels - input_channels = channels - - init_dim = default(init_dim, dim) - self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) - - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - - # layers - - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) - num_resolutions = len(in_out) - - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - - 
self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in), - block_klass(dim_in, dim_in), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) - ])) - - mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim) - self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim) - - for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): - is_last = ind == (len(in_out) - 1) - - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out), - block_klass(dim_out + dim_in, dim_out), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) - ])) - - default_out_dim = channels - self.out_dim = default(out_dim, default_out_dim) - - self.final_res_block = block_klass(dim * 2, dim) - self.final_conv = nn.Conv2d(dim, self.out_dim, 1) - - def forward(self, x): - - x = self.init_conv(x) - r = x.clone() - - h = [] - context = [] - for block1, block2, attn, downsample in self.downs: - x = block1(x) - h.append(x) - context.append(x) - x = block2(x) - x = attn(x) - h.append(x) - context.append(x) - x = downsample(x) - - x = self.mid_block1(x) - x = self.mid_attn(x) - x = self.mid_block2(x) - - for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) - x = block1(x) - - x = torch.cat((x, h.pop()), dim = 1) - x = block2(x) - x = attn(x) - - x = upsample(x) - - x = torch.cat((x, r), dim = 1) - - x = self.final_res_block(x) - return self.final_conv(x), context - -# gaussian diffusion trainer class - -def extract(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - -def linear_beta_schedule(timesteps): - """ - linear schedule, proposed in original ddpm paper - """ - scale = 1000 / timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) - -def cosine_beta_schedule(timesteps, s = 0.008): - """ - cosine schedule - as proposed in https://openreview.net/forum?id=-NEXDKk8gZ - """ - steps = timesteps + 1 - t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps - alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) - -def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5): - """ - sigmoid schedule - proposed in https://arxiv.org/abs/2212.11972 - Figure 8 - better for images > 64x64, when used during training - """ - steps = timesteps + 1 - t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps - v_start = torch.tensor(start / tau).sigmoid() - v_end = torch.tensor(end / tau).sigmoid() - alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) - -class GaussianDiffusion(nn.Module): - def __init__( - self, - model, - flow, - *, - image_size, - timesteps = 1200, - sampling_timesteps = None, - loss_type = 'l1', - objective = 'pred_noise', - beta_schedule = 'sigmoid', - schedule_fn_kwargs = dict(), - p2_loss_weight_gamma = 0., # p2 loss weight, from 
https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended - p2_loss_weight_k = 1, - ddim_sampling_eta = 0., - auto_normalize = True - ): - super().__init__() - #assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) - #assert not model.random_or_learned_sinusoidal_cond - - self.model = model - - self.flow = flow - self.upsample = nn.UpsamplingBilinear2d(scale_factor=8) - - self.channels = self.model.channels - self.self_condition = self.model.self_condition - - self.image_size = image_size - - self.objective = objective - - assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' - - if beta_schedule == 'linear': - beta_schedule_fn = linear_beta_schedule - elif beta_schedule == 'cosine': - beta_schedule_fn = cosine_beta_schedule - elif beta_schedule == 'sigmoid': - beta_schedule_fn = sigmoid_beta_schedule - else: - raise ValueError(f'unknown beta schedule {beta_schedule}') - - betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs) - - alphas = 1. - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.loss_type = loss_type - - # sampling related parameters - - self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training - - assert self.sampling_timesteps <= timesteps - self.is_ddim_sampling = self.sampling_timesteps < timesteps - self.ddim_sampling_eta = ddim_sampling_eta - - # helper function to register buffer from float64 to float32 - - register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) - - register_buffer('betas', betas) - register_buffer('alphas_cumprod', alphas_cumprod) - register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) - - # calculations for diffusion q(x_t | x_{t-1}) and others - - register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) - register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) - register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) - register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) - register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - - register_buffer('posterior_variance', posterior_variance) - - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - - register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) - register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) - register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) - - # calculate p2 reweighting - - register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) - - # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False - - self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity - self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def predict_noise_from_start(self, x_t, t, x0): - return ( - (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - ) - - def predict_v(self, x_start, t, noise): - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start - ) - - def predict_start_from_v(self, x_t, t, v): - return ( - extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - #def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): - def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_start = False): - - #model_output = self.model(x, t, x_self_cond) - #print(x.shape, l_cond.shape) - model_output = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) - - maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity - - if self.objective == 'pred_noise': - pred_noise = model_output - x_start = self.predict_start_from_noise(x, t, pred_noise) - x_start = maybe_clip(x_start) - - elif self.objective == 'pred_x0': - x_start = model_output - x_start = maybe_clip(x_start) - pred_noise = self.predict_noise_from_start(x, t, x_start) - - elif self.objective == 'pred_v': - v = model_output - x_start = self.predict_start_from_v(x, t, v) - x_start = maybe_clip(x_start) - pred_noise = self.predict_noise_from_start(x, t, x_start) - - return ModelPrediction(pred_noise, x_start) - - #def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): - def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = True): - - #preds = self.model_predictions(x, t, x_self_cond) - preds = self.model_predictions(x, t, context, x_self_cond) - x_start = preds.pred_x_start - - if clip_denoised: - x_start.clamp_(-1., 1.) 
- - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) - return model_mean, posterior_variance, posterior_log_variance, x_start - - @torch.no_grad() - #def p_sample(self, x, t: int, x_self_cond = None): - def p_sample(self, x, t: int, context, x_self_cond = None): - - b, *_, device = *x.shape, x.device - batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) - #model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True) - model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, context = context, x_self_cond = x_self_cond, clip_denoised = True) - noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0 - pred_img = model_mean + (0.5 * model_log_variance).exp() * noise - return pred_img, x_start - - @torch.no_grad() - #def p_sample_loop(self, shape, return_all_timesteps = False): - def p_sample_loop(self, shape, context, return_all_timesteps = False): - - batch, device = shape[0], self.betas.device - - img = torch.randn(shape, device = device) - imgs = [img] - - x_start = None - - for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): - self_cond = x_start if self.self_condition else None - #img, x_start = self.p_sample(img, t, self_cond) - img, x_start = self.p_sample(img, t, context, self_cond) - imgs.append(img) - - ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - - #ret = self.unnormalize(ret) - return ret - - @torch.no_grad() - #def ddim_sample(self, shape, return_all_timesteps = False): - def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): - - batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective - - times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps - times = list(reversed(times.int().tolist())) - time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] - - img = torch.randn(shape, device = device) - imgs = [img] - - x_start = None - - for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): - time_cond = torch.full((batch,), time, device = device, dtype = torch.long) - self_cond = x_start if self.self_condition else None - #pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True) - pred_noise, x_start, *_ = self.model_predictions(img, time_cond, l_cond, context, self_cond, clip_x_start = True) - - imgs.append(img) - - if time_next < 0: - img = x_start - continue - - alpha = self.alphas_cumprod[time] - alpha_next = self.alphas_cumprod[time_next] - - sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() - c = (1 - alpha_next - sigma ** 2).sqrt() - - noise = torch.randn_like(img) - - img = x_start * alpha_next.sqrt() + \ - c * pred_noise + \ - sigma * noise - - imgs.append(img) - ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - - #ret = self.unnormalize(ret) - return ret - - @torch.no_grad() - #def sample(self, batch_size = 16, return_all_timesteps = False): - def sample(self, lres, hres, return_all_timesteps = False): - - b, f, c, h, w, image_size, channels = *hres.shape, self.image_size, self.channels - 
print(b,f,c,h,w) - lres = self.normalize(lres) - hres = self.normalize(hres) - sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample - - l = hres.clone()[:, :1, :, :, :] - r = hres.clone()[:, 1:2, :, :, :] - hres_flow = rearrange(hres[:, 1:2, :, :, :], 'b t c h w -> (b t) c h w') - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - - m = lres.clone() - m1 = rearrange(m, 'b t c h w -> (b t) c h w') - m1 = self.upsample(m1) - m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) - m1 = torch.roll(m1, -2, 1) - m1 = m1[:, :(f-2), :, :, :] - - ans = [] - base = [] - nsteps = [] - flows = [] - - for i in range(f-2): - - stack = torch.cat((l, r, m1[:, i:i+1, :, :, :]), 2) - - stack = rearrange(stack, 'b t c h w -> (b t) c h w') - - flow, context = self.flow(stack) - - warped = scale_space_warp(hres_flow, flow) - batch_size = b - #res = sample_fn((batch_size, c, image_size, image_size), l_cond[i::(f-2), :, :, :], context, return_all_timesteps = return_all_timesteps) - - res = sample_fn((batch_size, c, h, w), l_cond[i::(f-2), :, :, :], context, return_all_timesteps = return_all_timesteps) - hres_flow = warped + res - - #hres_flow = warped + res[:, -1, :, :, :] - - l = r - r = rearrange(hres_flow, '(b t) c h w -> b t c h w', t = 1) - - ans.append(hres_flow) - base.append(warped) - #nsteps.append(torch.cat(torch.unbind(res, 1), 3)) - nsteps.append(res) - flows.append(flow) - - return self.unnormalize(torch.stack(ans, 1)), self.unnormalize(torch.stack(base, 1)), self.unnormalize(torch.stack(nsteps, 1)), self.unnormalize(torch.stack(flows, 1)) - #return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps) - - @torch.no_grad() - def interpolate(self, x1, x2, t = None, lam = 0.5): - b, *_, device = *x1.shape, x1.device - t = default(t, self.num_timesteps - 1) - - assert x1.shape == x2.shape - - t_batched = torch.stack([torch.tensor(t, device = device)] * b) - xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) - - img = (1 - lam) * xt1 + lam * xt2 - for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) - - return img - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - @property - def loss_fn(self): - if self.loss_type == 'l1': - return F.l1_loss - elif self.loss_type == 'l2': - return F.mse_loss - else: - raise ValueError(f'invalid loss type {self.loss_type}') - - #def p_losses(self, x_start, t, noise = None): - def p_losses(self, stack, hres, lres, t, noise = None): - - b, f, c, h, w = hres.shape - - stack = rearrange(stack, 'b t c h w -> (b t) c h w') - hres_flow = rearrange(hres[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') - - flow, context = self.flow(stack) - #print(flow.shape, hres_flow.shape) - warped = scale_space_warp(hres_flow, flow) - - x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') - x_start = x_start - warped - - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - - b, c, h, w = x_start.shape - - del f - - noise = default(noise, lambda: torch.randn_like(x_start)) - - # noise sample - - x = self.q_sample(x_start = x_start, t = t, noise = noise) - - # if doing self-conditioning, 
50% of the time, predict x_start from current set of times - # and condition with unet with that - # this technique will slow down training by 25%, but seems to lower FID significantly - - x_self_cond = None - if self.self_condition and random() < 0.5: - with torch.no_grad(): - x_self_cond = self.model_predictions(x, t).pred_x_start - x_self_cond.detach_() - - # predict and take gradient step - - model_out = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) - - if self.objective == 'pred_noise': - target = noise - elif self.objective == 'pred_x0': - target = x_start - elif self.objective == 'pred_v': - v = self.predict_v(x_start, t, noise) - target = v - else: - raise ValueError(f'unknown objective {self.objective}') - - loss = self.loss_fn(model_out, target, reduction = 'none') - loss = reduce(loss, 'b ... -> b (...)', 'mean') - - loss = loss * extract(self.p2_loss_weight, t, loss.shape) - return loss.mean() - - #def forward(self, data, *args, **kwargs): - def forward(self, lres, hres, *args, **kwargs): - - #b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size - b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size - - assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - - #t = torch.randint(0, self.num_timesteps, (b,), device=device).long() - t = torch.randint(0, self.num_timesteps, (b*(f-2),), device=device).long() - - #img = self.normalize(img) - lres = self.normalize(lres) - hres = self.normalize(hres) - - l = hres.clone() - r = torch.roll(l, -1, 1) - - m = lres.clone() - m1 = rearrange(m, 'b t c h w -> (b t) c h w') - m1 = self.upsample(m1) - m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) - m1 = torch.roll(m1, -2, 1) - - stack = torch.cat((l, r, m1), 2) - stack = stack[:, :(f-2), :, :, :] - - #return self.p_losses(img, t, *args, **kwargs) - return self.p_losses(stack, hres, lres, t, *args, **kwargs) - -# trainer class - -class Trainer(object): - def __init__( - self, - diffusion_model, - train_dl, - val_dl, - config, - *, - train_batch_size = 16, - gradient_accumulate_every = 1, - #augment_horizontal_flip = True, - train_lr = 1e-4, - train_num_steps = 100000, - ema_update_every = 1, - ema_decay = 0.995, - adam_betas = (0.9, 0.99), - save_and_sample_every = 10, - #num_samples = 25, - eval_folder = './evaluate', - results_folder = './results', - #tensorboard_dir = './tensorboard', - val_num_of_batch = 2, - amp = False, - fp16 = False, - #fp16 = True, - split_batches = True, - #split_batches = False, - convert_image_to = None - ): - super().__init__() - - self.accelerator = Accelerator( - split_batches = split_batches, - mixed_precision = 'fp16' if fp16 else 'no', - log_with = 'wandb', - ) - self.accelerator.init_trackers("vsr-orig-autoreg-hres", - init_kwargs={ - "wandb": { - "notes": "Use VSR to improve precipitation forecasting.", - # Change "name" to set the name of the run. 
- "name": None, - } - }, - ) - self.config = config - self.accelerator.native_amp = amp - - self.model = diffusion_model - - self.save_and_sample_every = save_and_sample_every - - self.batch_size = train_batch_size - self.gradient_accumulate_every = gradient_accumulate_every - - self.train_num_steps = train_num_steps - self.image_size = diffusion_model.image_size - - self.val_num_of_batch = val_num_of_batch - - # optimizer - - self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) - - # for logging results in a folder periodically - - if self.accelerator.is_main_process: - self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) - - self.results_folder = Path(results_folder) - - self.results_folder.mkdir(exist_ok=True, parents=True) - - self.eval_folder = eval_folder - - # step counter state - - self.step = 0 - - # prepare model, dataloader, optimizer with accelerator - - self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) - self.train_dl = cycle(train_dl) - self.val_dl = cycle(val_dl) - - def save(self, milestone): - if not self.accelerator.is_local_main_process: - return - - data = { - 'step': self.step, - 'model': self.accelerator.get_state_dict(self.model), - 'opt': self.opt.state_dict(), - 'ema': self.ema.state_dict(), - 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, - #'version': __version__ - } - - torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) - - def load(self, milestone): - accelerator = self.accelerator - device = accelerator.device - - data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) - - model = self.accelerator.unwrap_model(self.model) - model.load_state_dict(data['model']) - - self.step = data['step'] - #self.opt.load_state_dict(data['opt']) - self.ema.load_state_dict(data['ema']) - - #if 'version' in data: - # print(f"loading from version {data['version']}") - - if exists(self.accelerator.scaler) and exists(data['scaler']): - self.accelerator.scaler.load_state_dict(data['scaler']) - - def train(self): - - accelerator = self.accelerator - device = accelerator.device - - with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: - - while self.step < self.train_num_steps: - - total_loss = 0. 
- - for _ in range(self.gradient_accumulate_every): - - #data = next(self.dl).to(device) - data = next(self.train_dl) - lres = data['LR'].to(device) - hres = data['HR'].to(device) - - with self.accelerator.autocast(): - - #loss = self.model(data) - loss = self.model(lres, hres) - loss = loss / self.gradient_accumulate_every - total_loss += loss.item() - - self.accelerator.backward(loss) - - accelerator.clip_grad_norm_(self.model.parameters(), 1.0) - pbar.set_description(f'loss: {total_loss:.4f}') - - #self.writer.add_scalar("loss", total_loss, self.step) - accelerator.log({"loss": total_loss}, step = self.step) - - accelerator.wait_for_everyone() - - self.opt.step() - self.opt.zero_grad() - - accelerator.wait_for_everyone() - - self.step += 1 - if accelerator.is_main_process: - self.ema.to(device) - self.ema.update() - - if self.step != 0 and self.step % self.save_and_sample_every == 0: - self.ema.ema_model.eval() - - with torch.no_grad(): - - for i, batch in enumerate(self.val_dl): - - lres = batch['LR'].to(device) - hres = batch['HR'].to(device) - - if i >= self.val_num_of_batch: - break - - videos, base, res, flows = self.ema.ema_model.sample(lres, hres) - psnr_index = piq.psnr(hres[:,2:,0:1,:,:], videos.clamp(0.0, 1.0)[:,:,0:1,:,:], data_range=1., reduction='none') - - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flows": wandb.Video((flows.clamp(0.0, 1.0).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"psnr": psnr_index.mean()}, step=self.step) - - milestone = self.step // self.save_and_sample_every - - self.save(milestone) - - pbar.update(1) - - accelerator.print('training complete') - - def sample(self): - - accelerator = self.accelerator - device = accelerator.device - - self.ema.ema_model.eval() - - cmap = mpl.colormaps['viridis'] - sm = smap(None, cmap) - - with torch.no_grad(): - - for k, batch in enumerate(self.val_dl): - - lres = batch['LR'].to(device) - hres = batch['HR'].to(device) - - if k >= self.val_num_of_batch: - break - - limit = lres.shape[1] - if limit < 8: - - #videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres, True) - videos, base, nsteps, flows = self.ema.ema_model.sample(lres, hres) - - torch.save(videos, os.path.join(self.eval_folder) + "/gen.pt") - torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr.pt") - torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr.pt") - - for i, b in enumerate(videos.clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - os.makedirs(os.path.join(self.eval_folder, "generated")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - - #videos = torch.log(videos.clamp(0.0, 1.0) + 1) - #hres = 
torch.log(hres + 1) - - #for i, b in enumerate(videos.clamp(0, 1)): - # for i, b in enumerate(videos): - # if not os.path.isdir(os.path.join(self.eval_folder, "generated")): - # os.makedirs(os.path.join(self.eval_folder, "generated")) - # Parallel(n_jobs=4)( - # delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") - # for j, f in enumerate(b.cpu()) - # ) - -# for i, b in enumerate(nsteps.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# os.makedirs(os.path.join(self.eval_folder, "residual")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(base.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# os.makedirs(os.path.join(self.eval_folder, "warped")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# os.makedirs(os.path.join(self.eval_folder, "flows")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - - for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): - if not os.path.isdir(os.path.join(self.eval_folder, "truth")): - os.makedirs(os.path.join(self.eval_folder, "truth")) - Parallel(n_jobs=4)( - delayed(save_image)(sm.to_rgba(f[0,:,:]), os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") - for j, f in enumerate(b.cpu()) - ) - -# else: - -# videos, base, nsteps, flows = self.ema.ema_model.sample(lres[:,:7,:,:], hres[:,:7,:,:], True) - -# st = 5 -# ed = st + 7 - -# while ed < limit: - -# vi, ba, ns, fl = self.ema.ema_model.sample(lres[:,st:ed,:,:], hres[:,st:ed,:,:], True) -# st += 5 -# ed += 5 -# videos = torch.cat((videos, vi), 1) -# #base = torch.cat((base, ba), 1) -# #nsteps = torch.cat((nsteps, ns), 1) -# #flows = torch.cat((flows, fl), 1) - -# for i, b in enumerate(videos.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "generated")): -# os.makedirs(os.path.join(self.eval_folder, "generated")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "generated") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# for i, b in enumerate(hres[:,2:,:,:,:].clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "truth")): -# os.makedirs(os.path.join(self.eval_folder, "truth")) -# Parallel(n_jobs=4)( -# delayed(save_image)(f, os.path.join(self.eval_folder, "truth") + f"/{k}-{i}-{j}.png") -# for j, f in enumerate(b.cpu()) -# ) - -# # for i, b in enumerate(nsteps.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "residual")): -# # os.makedirs(os.path.join(self.eval_folder, "residual")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "residual") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(base.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "warped")): -# # 
os.makedirs(os.path.join(self.eval_folder, "warped")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "warped") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# # for i, b in enumerate(flows.clamp(0, 1)): -# # #for i, b in enumerate(sampled): -# # if not os.path.isdir(os.path.join(self.eval_folder, "flows")): -# # os.makedirs(os.path.join(self.eval_folder, "flows")) -# # Parallel(n_jobs=4)( -# # delayed(save_image)(f, os.path.join(self.eval_folder, "flows") + f"/{k}-{i}-{j}.png") -# # for j, f in enumerate(b.cpu()) -# # ) - -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_d")): -# os.makedirs(os.path.join(self.eval_folder, "flows_d")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_d") + f"/{k}-{i}-{j}.png", flow_vis.flow_to_color(f.permute(1,2,0).cpu().numpy()[:,:,:2], convert_to_bgr = False)) -# for j, f in enumerate(b.cpu()) -# ) -# for i, b in enumerate(flows.clamp(0, 1)): -# #for i, b in enumerate(sampled): -# if not os.path.isdir(os.path.join(self.eval_folder, "flows_s")): -# os.makedirs(os.path.join(self.eval_folder, "flows_s")) -# Parallel(n_jobs=4)( -# delayed(plt.imsave)(os.path.join(self.eval_folder, "flows_s") + f"/{k}-{i}-{j}.png", f.permute(1,2,0).cpu().numpy()[:,:,2], cmap = 'gray_r') -# for j, f in enumerate(b.cpu()) -# ) \ No newline at end of file diff --git a/projects/super_res/model/autoreg_diffusion_mod.py b/projects/super_res/model/autoreg_diffusion_mod.py index e6bb3d905f..a0ee69f27a 100644 --- a/projects/super_res/model/autoreg_diffusion_mod.py +++ b/projects/super_res/model/autoreg_diffusion_mod.py @@ -1,7 +1,7 @@ import os import math from pathlib import Path -from random import random +from random import random, randint from functools import partial from collections import namedtuple @@ -12,6 +12,8 @@ import torch.nn.functional as F import wandb +from torchvision.transforms.functional import crop + import piq from kornia import filters @@ -41,6 +43,14 @@ # helpers functions +def get_random_idx_with_difference(min_tx, max_tx, number_tx, diff): + times = [] + while len(times) < number_tx: + new_time = randint(min_tx, max_tx) + if all(abs(new_time - time) >= diff for time in times): + times.append(new_time) + return times + def calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size): truth_cdf = np.zeros((256, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') for i in range(256): @@ -100,7 +110,7 @@ def normalize_to_neg_one_to_one(img): def unnormalize_to_zero_to_one(t): return (t + 1) * 0.5 -# ssf modules +# flow modules def gaussian_pyramids(input, base_sigma = 1, m = 5): @@ -164,6 +174,41 @@ def scale_space_warp(input, flow): return warped +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='border', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + + Returns: + Tensor: Warped image or feature map. 
+ """ + n, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), + torch.arange(0, w, dtype=x.dtype, device=x.device)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + return output + # small helper modules class Residual(nn.Module): @@ -640,6 +685,7 @@ def __init__( flow, *, image_size, + in_ch, timesteps = 1200, sampling_timesteps = None, loss_type = 'l1', @@ -652,16 +698,13 @@ def __init__( auto_normalize = True ): super().__init__() - #assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) - #assert not model.random_or_learned_sinusoidal_cond self.model = model - self.umodel = context_net(upscale=8, in_chans=1, img_size=48, window_size=8, - img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200, - num_heads=[8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv') - + self.umodel = context_net(upscale = 8, in_chans = in_ch, out_chans = 1, img_size = 48, window_size = 8, + img_range = 1., depths = [6, 6, 6, 6, 6, 6, 6], embed_dim = 200, + num_heads = [8, 8, 8, 8, 8, 8, 8], + mlp_ratio = 2, upsampler = 'pixelshuffle', resi_connection = '3conv') self.flow = flow self.upsample = nn.UpsamplingBilinear2d(scale_factor=8) @@ -765,6 +808,7 @@ def predict_start_from_v(self, x_t, t, v): ) def q_posterior(self, x_start, x_t, t): + posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t @@ -773,11 +817,8 @@ def q_posterior(self, x_start, x_t, t): posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped - #def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_start = False): - #model_output = self.model(x, t, x_self_cond) - #print(x.shape, l_cond.shape) model_output = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) maybe_clip = partial(torch.clamp, min = -1., max = 1.) 
if clip_x_start else identity @@ -800,10 +841,8 @@ def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_st return ModelPrediction(pred_noise, x_start) - #def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = True): - #preds = self.model_predictions(x, t, x_self_cond) preds = self.model_predictions(x, t, context, x_self_cond) x_start = preds.pred_x_start @@ -814,22 +853,18 @@ def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = Tru return model_mean, posterior_variance, posterior_log_variance, x_start @torch.no_grad() - #def p_sample(self, x, t: int, x_self_cond = None): def p_sample(self, x, t: int, context, x_self_cond = None): - b, *_, device = *x.shape, x.device batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) - #model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True) model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, context = context, x_self_cond = x_self_cond, clip_denoised = True) noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0 pred_img = model_mean + (0.5 * model_log_variance).exp() * noise return pred_img, x_start @torch.no_grad() - #def p_sample_loop(self, shape, return_all_timesteps = False): def p_sample_loop(self, shape, context, return_all_timesteps = False): - batch, device = shape[0], self.betas.device + device = self.betas.device img = torch.randn(shape, device = device) imgs = [img] @@ -838,19 +873,16 @@ def p_sample_loop(self, shape, context, return_all_timesteps = False): for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): self_cond = x_start if self.self_condition else None - #img, x_start = self.p_sample(img, t, self_cond) img, x_start = self.p_sample(img, t, context, self_cond) imgs.append(img) ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - #ret = self.unnormalize(ret) return ret @torch.no_grad() - #def ddim_sample(self, shape, return_all_timesteps = False): def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): - print('here!!!') + batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps @@ -865,7 +897,6 @@ def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): time_cond = torch.full((batch,), time, device = device, dtype = torch.long) self_cond = x_start if self.self_condition else None - #pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True) pred_noise, x_start, *_ = self.model_predictions(img, time_cond, l_cond, context, self_cond, clip_x_start = True) imgs.append(img) @@ -889,33 +920,66 @@ def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): imgs.append(img) ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - #ret = self.unnormalize(ret) return ret @torch.no_grad() - def sample(self, lres, return_all_timesteps = False): + def sample(self, lres, hres, multi, flow_mode, 
return_all_timesteps = False): b, f, c, h, w = lres.shape - ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + if multi: + + topo = hres[:, :, 1:2, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(8*h, 8*w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) lres = self.normalize(lres) ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) + sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) + r = torch.roll(l, -1, 1) ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') + + if multi: + + l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + + else: + + l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) m = lres.clone() m1 = rearrange(m, 'b t c h w -> (b t) c h w') m1 = self.upsample(m1) m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + m1 = torch.roll(m1, -2, 1) - #m1 = torch.roll(l, -2, 1) stack = torch.cat((l, r, m1), 2) stack = stack[:, :(f-2), :, :, :] @@ -923,9 +987,15 @@ def sample(self, lres, return_all_timesteps = False): flow, context = self.flow(stack) - warped = scale_space_warp(ures_flow, flow) + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) + + elif flow_mode == '2d': + + warped = flow_warp(ures_flow, flow) - res = sample_fn((b*(f-2),c,8*h,8*w), l_cond, context, return_all_timesteps = return_all_timesteps) + res = sample_fn((b * (f - 2), 1, 8 * h, 8 * w), l_cond, context, return_all_timesteps = return_all_timesteps) sres = warped + res sres = rearrange(sres, '(b t) c h w -> b t c h w', b = b) @@ -934,9 +1004,11 @@ def sample(self, lres, return_all_timesteps = False): flow = rearrange(flow, '(b t) c h w -> b t c h w', b = b) return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), self.unnormalize(flow) + @torch.no_grad() def interpolate(self, x1, x2, t = None, lam = 0.5): + b, *_, device = *x1.shape, x1.device t = default(t, self.num_timesteps - 1) @@ -952,6 +1024,7 @@ def interpolate(self, x1, x2, t = None, lam = 0.5): return img def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) return ( @@ -961,6 +1034,7 @@ def q_sample(self, x_start, t, noise=None): @property def loss_fn(self): + if self.loss_type == 'l1': return F.l1_loss elif self.loss_type == 'l2': @@ -968,24 +1042,33 @@ def loss_fn(self): else: raise ValueError(f'invalid loss type {self.loss_type}') - def p_losses(self, stack, hres, lres, ures, t, noise = None): + def p_losses(self, stack, hres, lres, ures, t, multi, flow_mode, topo = None, noise = None): - b, f, c, h, w = hres.shape + f = hres.shape[1] stack = 
rearrange(stack, 'b t c h w -> (b t) c h w') - ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') + ures_flow = rearrange(ures[:, 1:(f - 1), :, :, :], 'b t c h w -> (b t) c h w') flow, context = self.flow(stack) - warped = scale_space_warp(ures_flow, flow) + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) + + elif flow_mode == '2d': + + warped = flow_warp(ures_flow, flow) x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') x_start = x_start - warped - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') + if multi: + + l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) - b, c, h, w = x_start.shape + else: + + l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) del f @@ -1032,35 +1115,68 @@ def p_losses(self, stack, hres, lres, ures, t, noise = None): return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*0.3 - def forward(self, lres, hres, *args, **kwargs): - - b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size + def forward(self, lres, hres, multi, flow_mode, *args, **kwargs): - assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + b, f, c, h, w, device = *hres.shape, hres.device t = torch.randint(0, self.num_timesteps, (b*(f-2),), device=device).long() - ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + if multi: + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h//8, w//8), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) lres = self.normalize(lres) hres = self.normalize(hres) ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) + r = torch.roll(l, -1, 1) m = lres.clone() m1 = rearrange(m, 'b t c h w -> (b t) c h w') m1 = self.upsample(m1) m1 = rearrange(m1, '(b t) c h w -> b t c h w', b = b) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + m1 = torch.roll(m1, -2, 1) - #m1 = torch.roll(l, -2, 1) stack = torch.cat((l, r, m1), 2) stack = stack[:, :(f-2), :, :, :] - return self.p_losses(stack, hres, lres, ures, t, *args, **kwargs) + if multi: + + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, topo, *args, **kwargs) + + else: + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, None, *args, **kwargs) # trainer class @@ -1074,24 +1190,18 @@ def __init__( *, train_batch_size = 16, gradient_accumulate_every = 1, - #augment_horizontal_flip = True, train_lr = 1e-4, train_num_steps = 100000, ema_update_every = 1, ema_decay = 0.995, adam_betas = (0.9, 0.99), save_and_sample_every = 1, - #num_samples = 25, eval_folder = './evaluate', results_folder = './results', - #tensorboard_dir = './tensorboard', val_num_of_batch = 2, amp = 
False, fp16 = False, - #fp16 = True, - split_batches = True, - #split_batches = False, - convert_image_to = None + split_batches = True ): super().__init__() @@ -1111,6 +1221,10 @@ def __init__( ) self.config = config self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] self.model = diffusion_model @@ -1193,7 +1307,6 @@ def train(self): cmap = mpl.colormaps['RdBu_r'] fcmap = mpl.colormaps['gray_r'] - c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') c384_gmin = np.load('data/only_precip/c384_gmin.npy') @@ -1216,15 +1329,20 @@ def train(self): for _ in range(self.gradient_accumulate_every): - #data = next(self.dl).to(device) data = next(self.train_dl) lres = data['LR'].to(device) hres = data['HR'].to(device) + if self.minipatch: + + x_st = randint(0, 36) + y_st = randint(0, 36) + lres = crop(lres, x_st, y_st, 12, 12) + hres = crop(hres, 8 * x_st, 8 * y_st, 96, 96) + with self.accelerator.autocast(): - #loss = self.model(data) - loss = self.model(lres, hres) + loss = self.model(lres, hres, self.multi, self.flow) loss = loss / self.gradient_accumulate_every total_loss += loss.item() @@ -1233,7 +1351,6 @@ def train(self): accelerator.clip_grad_norm_(self.model.parameters(), 1.0) pbar.set_description(f'loss: {total_loss:.4f}') - #self.writer.add_scalar("loss", total_loss, self.step) accelerator.log({"loss": total_loss}, step = self.step) accelerator.wait_for_everyone() @@ -1269,13 +1386,17 @@ def train(self): truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - truth[0,:,:,:,:,:] = (hres[:,2:,:,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) + truth[0,:,:,:,:,:] = (hres[:,2:,0:1,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) for k in range(num_samples): - videos, base, res, flows = self.ema.ema_model.sample(lres) + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] crps_index = calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size) + psnr_index = piq.psnr(hres[:,2:,0:1,:,:], videos.clamp(0.0, 1.0)[:,:,0:1,:,:], data_range=1., reduction='none') videos_time_mean = videos.mean(dim = 1) @@ -1341,15 +1462,20 @@ def train(self): ax1.set_yscale("log") flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) + for m in range(num_frames): + flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) flow_s = np.zeros((1, num_frames, 3, img_size, img_size)) sm = smap(None, fcmap) + for m in range(num_frames): + flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) if self.config.data_config.logscale: + accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) accelerator.log({"pred": 
wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) @@ -1359,8 +1485,9 @@ def train(self): accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) else: - accelerator.log({"true_high": wandb.Video((hres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) target = np.log(target - c384_gmin + 1e-14) @@ -1395,30 +1522,61 @@ def sample(self): self.ema.ema_model.eval() - c384_norm= torch.from_numpy(np.load("data/only_precip/c384_norm.npy")) - c48_norm = torch.from_numpy(np.load("data/only_precip/c48_norm.npy")) + c384_norm= torch.from_numpy(np.load("data/only_precip/c384_lgnorm.npy")) + c48_norm = torch.from_numpy(np.load("data/only_precip/c48_lgnorm.npy")) + + if self.multi: + + # refer to data/vsrdata.py for more info + + c48_norm_more = torch.from_numpy(np.load("data/more_channels/c48_norm.npy")) + c48_norm = torch.cat((c48_norm, c48_norm_more), 2) + + topo384 = torch.from_numpy(np.repeat(np.load("data/topography/topo384_norm.npy").reshape((6, 1, 1, 384, 384)), 2920, axis = 1)) + c384_norm = torch.cat((c384_norm, topo384), axis = 2) with torch.no_grad(): for tile in range(6): - st = 0 - en = 27 - count = 0 + if self.rollout == 'full': + + seq_len = self.rollout_batch + st = 0 + en = seq_len + 2 + count = 0 + + while en < c48_norm.shape[1]: + + print(tile, st) + + lres = c48_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) + hres = c384_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + count += 1 + + st += seq_len + en += seq_len + + if self.rollout == 'partial': - while en < c48_norm.shape[1]: + seq_len = self.rollout_batch + indices = get_random_idx_with_difference(0, c48_norm.shape[1] - (seq_len + 2), 250 // seq_len, seq_len + 2) # 250 samples per tile - print(tile, st) + for count, st in enumerate(indices): - lres = c48_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) - hres = c384_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) + print(tile, count) - videos, base, res, flows = self.ema.ema_model.sample(lres) + lres = c48_norm[tile,st:st+(seq_len+2),:,:,:].unsqueeze(0).to(device) + hres = c384_norm[tile,st:st+(seq_len+2),:,:,:].unsqueeze(0).to(device) - torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) - torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) - torch.save(lres[:,2:,:,:,:], 
os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi) - st += 25 - en += 25 - count += 1 \ No newline at end of file + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) \ No newline at end of file diff --git a/projects/super_res/model/autoreg_diffusion_mod_flow.py b/projects/super_res/model/autoreg_diffusion_mod_flow.py deleted file mode 100644 index ff6f2764ef..0000000000 --- a/projects/super_res/model/autoreg_diffusion_mod_flow.py +++ /dev/null @@ -1,1390 +0,0 @@ -import os -import math -from pathlib import Path -from random import random -from functools import partial -from collections import namedtuple - -import numpy as np - -import torch -from torch import nn, einsum -import torch.nn.functional as F -import wandb - -import piq - -from torch.optim import Adam - -from einops import rearrange, reduce -from einops.layers.torch import Rearrange - -from PIL import Image - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.cm import ScalarMappable as smap - -from tqdm.auto import tqdm -from ema_pytorch import EMA - -import flow_vis - -from accelerate import Accelerator - -from .network_swinir import SwinIR as context_net - -# constants - -ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) - -# helpers functions - -def calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size): - truth_cdf = np.zeros((256, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - for i in range(256): - truth_cdf[i, :, :, :, :, :, :] = (truth <= i).astype('uint8') - pred_cdf = np.zeros((256, num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - for j in range(256): - pred_cdf[j, :, :, :, :, :, :, :] = (pred <= j).astype('uint8') - red_pred_cdf = pred_cdf.mean(1) - temp = np.square(red_pred_cdf - truth_cdf) - temp_dz = temp.sum(0) - temp_dz_dd = temp_dz.mean(axis = (3, 4, 5)) - temp_dz_dd_dt = temp_dz_dd.mean(2) - return temp_dz_dd_dt.mean() - -def save_image(tensor, path): - im = Image.fromarray((tensor[:,:,:3] * 255).astype(np.uint8)) - im.save(path) - return None - -def exists(x): - return x is not None - -def default(val, d): - if exists(val): - return val - return d() if callable(d) else d - -def identity(t, *args, **kwargs): - return t - -def cycle(dl): - while True: - for data in dl: - yield data - -def has_int_squareroot(num): - return (math.sqrt(num) ** 2) == num - -def num_to_groups(num, divisor): - groups = num // divisor - remainder = num % divisor - arr = [divisor] * groups - if remainder > 0: - arr.append(remainder) - return arr - -def convert_image_to_fn(img_type, image): - if image.mode != img_type: - return image.convert(img_type) - return image - -# normalization functions - -def normalize_to_neg_one_to_one(img): - return img * 2 - 1 - -def unnormalize_to_zero_to_one(t): - return (t + 1) * 0.5 - -# flow modules - -def flow_warp(x, flow, interp_mode='bilinear', padding_mode='border', align_corners=True): - """Warp an image or feature map with optical flow. - - Args: - x (Tensor): Tensor with size (n, c, h, w). - flow (Tensor): Tensor with size (n, h, w, 2), normal value. 
- interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'. - padding_mode (str): 'zeros' or 'border' or 'reflection'. - Default: 'zeros'. - align_corners (bool): Before pytorch 1.3, the default value is - align_corners=True. After pytorch 1.3, the default value is - align_corners=False. Here, we use the True as default. - - - Returns: - Tensor: Warped image or feature map. - """ - n, _, h, w = x.size() - # create mesh grid - grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), - torch.arange(0, w, dtype=x.dtype, device=x.device)) - grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 - grid.requires_grad = False - - vgrid = grid + flow - - # scale grid to [-1,1] - vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 - vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 - vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) - - output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) - - return output - -# small helper modules - -class Residual(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - - def forward(self, x, *args, **kwargs): - return self.fn(x, *args, **kwargs) + x - -def Upsample(dim, dim_out = None): - return nn.Sequential( - nn.Upsample(scale_factor = 2, mode = 'nearest'), - nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) - ) - -def Downsample(dim, dim_out = None): - return nn.Sequential( - Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), - nn.Conv2d(dim * 4, default(dim_out, dim), 1) - ) - -class WeightStandardizedConv2d(nn.Conv2d): - """ - https://arxiv.org/abs/1903.10520 - weight standardization purportedly works synergistically with group normalization - """ - def forward(self, x): - eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - - weight = self.weight - mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') - var = reduce(weight, 'o ... 
-> o 1 1 1', partial(torch.var, unbiased = False)) - normalized_weight = (weight - mean) * (var + eps).rsqrt() - - return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - -class LayerNorm(nn.Module): - def __init__(self, dim): - super().__init__() - self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) - - def forward(self, x): - eps = 1e-5 if x.dtype == torch.float32 else 1e-3 - var = torch.var(x, dim = 1, unbiased = False, keepdim = True) - mean = torch.mean(x, dim = 1, keepdim = True) - return (x - mean) * (var + eps).rsqrt() * self.g - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.fn = fn - self.norm = LayerNorm(dim) - - def forward(self, x): - x = self.norm(x) - return self.fn(x) - -# sinusoidal positional embeds - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - -class RandomOrLearnedSinusoidalPosEmb(nn.Module): - """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ - """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ - - def __init__(self, dim, is_random = False): - super().__init__() - assert (dim % 2) == 0 - half_dim = dim // 2 - self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) - - def forward(self, x): - x = rearrange(x, 'b -> b 1') - freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) - fouriered = torch.cat((x, fouriered), dim = -1) - return fouriered - -# building block modules - -class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): - super().__init__() - self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) - self.norm = nn.GroupNorm(groups, dim_out) - self.act = nn.SiLU() - - def forward(self, x, scale_shift = None): - x = self.proj(x) - x = self.norm(x) - - if exists(scale_shift): - scale, shift = scale_shift - x = x * (scale + 1) + shift - - x = self.act(x) - return x - -class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): - super().__init__() - self.mlp = nn.Sequential( - nn.SiLU(), - nn.Linear(time_emb_dim, dim_out * 2) - ) if exists(time_emb_dim) else None - - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) - self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - - def forward(self, x, time_emb = None): - - scale_shift = None - if exists(self.mlp) and exists(time_emb): - time_emb = self.mlp(time_emb) - time_emb = rearrange(time_emb, 'b c -> b c 1 1') - scale_shift = time_emb.chunk(2, dim = 1) - - h = self.block1(x, scale_shift = scale_shift) - - h = self.block2(h) - - return h + self.res_conv(x) - -class LinearAttention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32): - super().__init__() - self.scale = dim_head ** -0.5 - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - - self.to_out = nn.Sequential( - nn.Conv2d(hidden_dim, dim, 1), - LayerNorm(dim) - ) - - def forward(self, x): - b, c, h, w = x.shape - qkv = 
self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) - - q = q.softmax(dim = -2) - k = k.softmax(dim = -1) - - q = q * self.scale - v = v / (h * w) - - context = torch.einsum('b h d n, b h e n -> b h d e', k, v) - - out = torch.einsum('b h d e, b h d n -> b h e n', context, q) - out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) - return self.to_out(out) - -class Attention(nn.Module): - def __init__(self, dim, heads = 4, dim_head = 32): - super().__init__() - self.scale = dim_head ** -0.5 - self.heads = heads - hidden_dim = dim_head * heads - - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x).chunk(3, dim = 1) - q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) - - q = q * self.scale - - sim = einsum('b h d i, b h d j -> b h i j', q, k) - attn = sim.softmax(dim = -1) - out = einsum('b h i j, b h d j -> b h i d', attn, v) - - out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) - return self.to_out(out) - -# model - -class Unet(nn.Module): - def __init__( - self, - dim, - init_dim = None, - out_dim = None, - dim_mults=(1, 2, 4, 8), - channels = 3, - self_condition = False, - resnet_block_groups = 8, - learned_variance = False, - learned_sinusoidal_cond = False, - random_fourier_features = False, - learned_sinusoidal_dim = 16 - ): - super().__init__() - - # determine dimensions - - self.channels = channels - self.self_condition = self_condition - input_channels = channels * (2 if self_condition else 1) - - init_dim = default(init_dim, dim) - self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) - - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - - # time embeddings - - time_dim = dim * 4 - - self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features - - if self.random_or_learned_sinusoidal_cond: - sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) - fourier_dim = learned_sinusoidal_dim + 1 - else: - sinu_pos_emb = SinusoidalPosEmb(dim) - fourier_dim = dim - - self.time_mlp = nn.Sequential( - sinu_pos_emb, - nn.Linear(fourier_dim, time_dim), - nn.GELU(), - nn.Linear(time_dim, time_dim) - ) - - # layers - - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) - num_resolutions = len(in_out) - - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - - self.downs.append(nn.ModuleList([ - block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), - block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) - ])) - - mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) - - for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): - is_last = ind == (len(in_out) - 1) - - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - 
Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) - ])) - - default_out_dim = channels * (1 if not learned_variance else 2) - self.out_dim = default(out_dim, default_out_dim) - - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) - self.final_conv = nn.Conv2d(dim, self.out_dim, 1) - - def forward(self, x, time, context, x_self_cond = None): - if self.self_condition: - x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) - x = torch.cat((x_self_cond, x), dim = 1) - - x = self.init_conv(x) - r = x.clone() - - t = self.time_mlp(time) - - h = [] - - count = 0 - - for block1, block2, attn, downsample in self.downs: - x = torch.cat((x, context[count]), dim = 1) - count += 1 - x = block1(x, t) - h.append(x) - - x = torch.cat((x, context[count]), dim = 1) - count += 1 - x = block2(x, t) - x = attn(x) - h.append(x) - - x = downsample(x) - - x = self.mid_block1(x, t) - x = self.mid_attn(x) - x = self.mid_block2(x, t) - - for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 1) - x = block1(x, t) - - x = torch.cat((x, h.pop()), dim = 1) - x = block2(x, t) - x = attn(x) - - x = upsample(x) - - x = torch.cat((x, r), dim = 1) - - x = self.final_res_block(x, t) - return self.final_conv(x) - -class Flow(nn.Module): - def __init__( - self, - dim, - init_dim = None, - out_dim = None, - dim_mults=(1, 2, 4, 8), - channels = 3, - resnet_block_groups = 8, - ): - super().__init__() - - # determine dimensions - - self.channels = channels - input_channels = channels - - init_dim = default(init_dim, dim) - self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) - - dims = [init_dim, *map(lambda m: dim * m, dim_mults)] - in_out = list(zip(dims[:-1], dims[1:])) - - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - - # layers - - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) - num_resolutions = len(in_out) - - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (num_resolutions - 1) - - self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in), - block_klass(dim_in, dim_in), - Residual(PreNorm(dim_in, LinearAttention(dim_in))), - Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) - ])) - - mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim) - self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim) - - for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): - is_last = ind == (len(in_out) - 1) - - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out), - block_klass(dim_out + dim_in, dim_out), - Residual(PreNorm(dim_out, LinearAttention(dim_out))), - Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) - ])) - - default_out_dim = channels - self.out_dim = default(out_dim, default_out_dim) - - self.final_res_block = block_klass(dim * 2, dim) - self.final_conv = nn.Conv2d(dim, self.out_dim, 1) - - def forward(self, x): - - x = self.init_conv(x) - r = x.clone() - - h = [] - context = [] - for block1, block2, attn, downsample in self.downs: - x = block1(x) - h.append(x) - context.append(x) - x = block2(x) - x = attn(x) - h.append(x) - context.append(x) - x = downsample(x) - - x = self.mid_block1(x) - x = self.mid_attn(x) - x = self.mid_block2(x) - - for block1, block2, attn, upsample in self.ups: - x = torch.cat((x, h.pop()), dim = 
1) - x = block1(x) - - x = torch.cat((x, h.pop()), dim = 1) - x = block2(x) - x = attn(x) - - x = upsample(x) - - x = torch.cat((x, r), dim = 1) - - x = self.final_res_block(x) - x = F.tanh(self.final_conv(x)) - - return x, context - #return self.final_conv(x), context - -# gaussian diffusion trainer class - -def extract(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - -def linear_beta_schedule(timesteps): - """ - linear schedule, proposed in original ddpm paper - """ - scale = 1000 / timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) - -def cosine_beta_schedule(timesteps, s = 0.008): - """ - cosine schedule - as proposed in https://openreview.net/forum?id=-NEXDKk8gZ - """ - steps = timesteps + 1 - t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps - alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) - -def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5): - """ - sigmoid schedule - proposed in https://arxiv.org/abs/2212.11972 - Figure 8 - better for images > 64x64, when used during training - """ - steps = timesteps + 1 - t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps - v_start = torch.tensor(start / tau).sigmoid() - v_end = torch.tensor(end / tau).sigmoid() - alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) - -class GaussianDiffusion(nn.Module): - def __init__( - self, - model, - flow, - *, - image_size, - timesteps = 1200, - sampling_timesteps = None, - loss_type = 'l1', - objective = 'pred_noise', - beta_schedule = 'sigmoid', - schedule_fn_kwargs = dict(), - p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. 
is recommended - p2_loss_weight_k = 1, - ddim_sampling_eta = 0., - auto_normalize = True - ): - super().__init__() - #assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) - #assert not model.random_or_learned_sinusoidal_cond - - self.model = model - - self.umodel = context_net(upscale=8, in_chans=1, img_size=48, window_size=8, - img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200, - num_heads=[8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv') - - self.flow = flow - self.upsample = nn.UpsamplingBilinear2d(scale_factor=8) - - self.channels = self.model.channels - self.self_condition = self.model.self_condition - - self.image_size = image_size - - self.objective = objective - - assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' - - if beta_schedule == 'linear': - beta_schedule_fn = linear_beta_schedule - elif beta_schedule == 'cosine': - beta_schedule_fn = cosine_beta_schedule - elif beta_schedule == 'sigmoid': - beta_schedule_fn = sigmoid_beta_schedule - else: - raise ValueError(f'unknown beta schedule {beta_schedule}') - - betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs) - - alphas = 1. - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.loss_type = loss_type - - # sampling related parameters - - self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training - - assert self.sampling_timesteps <= timesteps - self.is_ddim_sampling = self.sampling_timesteps < timesteps - self.ddim_sampling_eta = ddim_sampling_eta - - # helper function to register buffer from float64 to float32 - - register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) - - register_buffer('betas', betas) - register_buffer('alphas_cumprod', alphas_cumprod) - register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) - - # calculations for diffusion q(x_t | x_{t-1}) and others - - register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) - register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) - register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) - register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) - register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - - register_buffer('posterior_variance', posterior_variance) - - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - - register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) - register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) - register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) - - # calculate p2 reweighting - - register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) - - # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False - - self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity - self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def predict_noise_from_start(self, x_t, t, x0): - return ( - (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - ) - - def predict_v(self, x_start, t, noise): - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start - ) - - def predict_start_from_v(self, x_t, t, v): - return ( - extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - #def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False): - def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_start = False): - - #model_output = self.model(x, t, x_self_cond) - #print(x.shape, l_cond.shape) - model_output = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) - - maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity - - if self.objective == 'pred_noise': - pred_noise = model_output - x_start = self.predict_start_from_noise(x, t, pred_noise) - x_start = maybe_clip(x_start) - - elif self.objective == 'pred_x0': - x_start = model_output - x_start = maybe_clip(x_start) - pred_noise = self.predict_noise_from_start(x, t, x_start) - - elif self.objective == 'pred_v': - v = model_output - x_start = self.predict_start_from_v(x, t, v) - x_start = maybe_clip(x_start) - pred_noise = self.predict_noise_from_start(x, t, x_start) - - return ModelPrediction(pred_noise, x_start) - - #def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True): - def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = True): - - #preds = self.model_predictions(x, t, x_self_cond) - preds = self.model_predictions(x, t, context, x_self_cond) - x_start = preds.pred_x_start - - if clip_denoised: - x_start.clamp_(-1., 1.) 
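A quick sanity check on the v-parameterization used above: predict_v defines v = sqrt(a_bar)*noise - sqrt(1 - a_bar)*x_0, and predict_start_from_v inverts it exactly given x_t and the same cumulative alphas. A self-contained sketch with stand-in values (shapes are illustrative):

import torch

a_bar = torch.rand(4, 1, 1, 1)                 # stand-in for extract(alphas_cumprod, t, shape)
x0 = torch.randn(4, 1, 8, 8)                   # clean image
eps = torch.randn_like(x0)                     # noise

x_t = a_bar.sqrt() * x0 + (1. - a_bar).sqrt() * eps    # forward noising, as in q_sample
v = a_bar.sqrt() * eps - (1. - a_bar).sqrt() * x0      # the pred_v target
x0_rec = a_bar.sqrt() * x_t - (1. - a_bar).sqrt() * v  # predict_start_from_v
print(torch.allclose(x0_rec, x0, atol=1e-5))           # True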
- - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) - return model_mean, posterior_variance, posterior_log_variance, x_start - - @torch.no_grad() - #def p_sample(self, x, t: int, x_self_cond = None): - def p_sample(self, x, t: int, context, x_self_cond = None): - - b, *_, device = *x.shape, x.device - batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) - #model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True) - model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, context = context, x_self_cond = x_self_cond, clip_denoised = True) - noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0 - pred_img = model_mean + (0.5 * model_log_variance).exp() * noise - return pred_img, x_start - - @torch.no_grad() - #def p_sample_loop(self, shape, return_all_timesteps = False): - def p_sample_loop(self, shape, context, return_all_timesteps = False): - - batch, device = shape[0], self.betas.device - - img = torch.randn(shape, device = device) - imgs = [img] - - x_start = None - - for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): - self_cond = x_start if self.self_condition else None - #img, x_start = self.p_sample(img, t, self_cond) - img, x_start = self.p_sample(img, t, context, self_cond) - imgs.append(img) - - ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - - #ret = self.unnormalize(ret) - return ret - - @torch.no_grad() - #def ddim_sample(self, shape, return_all_timesteps = False): - def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): - print('here!!!') - batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective - - times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps - times = list(reversed(times.int().tolist())) - time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] - - img = torch.randn(shape, device = device) - imgs = [img] - - x_start = None - - for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): - time_cond = torch.full((batch,), time, device = device, dtype = torch.long) - self_cond = x_start if self.self_condition else None - #pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True) - pred_noise, x_start, *_ = self.model_predictions(img, time_cond, l_cond, context, self_cond, clip_x_start = True) - - imgs.append(img) - - if time_next < 0: - img = x_start - continue - - alpha = self.alphas_cumprod[time] - alpha_next = self.alphas_cumprod[time_next] - - sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() - c = (1 - alpha_next - sigma ** 2).sqrt() - - noise = torch.randn_like(img) - - img = x_start * alpha_next.sqrt() + \ - c * pred_noise + \ - sigma * noise - - imgs.append(img) - ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) - - #ret = self.unnormalize(ret) - return ret - - @torch.no_grad() - def sample(self, lres, return_all_timesteps = False): - - b, f, c, h, w = lres.shape - - ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) - ures = rearrange(ures, '(b t) c h w -> b t 
c h w', b = b) - - lres = self.normalize(lres) - ures = self.normalize(ures) - sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample - - l = ures.clone() - r = torch.roll(l, -1, 1) - ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') - - m = lres.clone() - m1 = rearrange(m, 'b t c h w -> (b t) c h w') - m1 = self.upsample(m1) - m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) - m1 = torch.roll(m1, -2, 1) - #m1 = torch.roll(l, -2, 1) - - stack = torch.cat((l, r, m1), 2) - stack = stack[:, :(f-2), :, :, :] - stack = rearrange(stack, 'b t c h w -> (b t) c h w') - - flow, context = self.flow(stack) - - flow = self.unnormalize(flow) - warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) - - res = sample_fn((b*(f-2),c,8*h,8*w), l_cond, context, return_all_timesteps = return_all_timesteps) - sres = warped + res - sres = rearrange(sres, '(b t) c h w -> b t c h w', b = b) - - warped = rearrange(warped, '(b t) c h w -> b t c h w', b = b) - res = rearrange(res, '(b t) c h w -> b t c h w', b = b) - flow = rearrange(flow, '(b t) c h w -> b t c h w', t = f-2) - - return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), flow - - @torch.no_grad() - def interpolate(self, x1, x2, t = None, lam = 0.5): - b, *_, device = *x1.shape, x1.device - t = default(t, self.num_timesteps - 1) - - assert x1.shape == x2.shape - - t_batched = torch.stack([torch.tensor(t, device = device)] * b) - xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) - - img = (1 - lam) * xt1 + lam * xt2 - for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) - - return img - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - @property - def loss_fn(self): - if self.loss_type == 'l1': - return F.l1_loss - elif self.loss_type == 'l2': - return F.mse_loss - else: - raise ValueError(f'invalid loss type {self.loss_type}') - - def p_losses(self, stack, hres, lres, ures, t, noise = None): - - b, f, c, h, w = hres.shape - - stack = rearrange(stack, 'b t c h w -> (b t) c h w') - ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') - - flow, context = self.flow(stack) - - flow = self.unnormalize(flow) - warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) - - x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') - x_start = x_start - warped - - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) - #l_cond = rearrange(ures[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') - - b, c, h, w = x_start.shape - - del f - - noise = default(noise, lambda: torch.randn_like(x_start)) - - # noise sample - - x = self.q_sample(x_start = x_start, t = t, noise = noise) - - # if doing self-conditioning, 50% of the time, predict x_start from current set of times - # and condition with unet with that - # this technique will slow down training by 25%, but seems to lower FID significantly - - x_self_cond = None - if self.self_condition and random() < 0.5: - with torch.no_grad(): - x_self_cond = self.model_predictions(x, 
t).pred_x_start - x_self_cond.detach_() - - # predict and take gradient step - - model_out = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) - - if self.objective == 'pred_noise': - target = noise - elif self.objective == 'pred_x0': - target = x_start - elif self.objective == 'pred_v': - v = self.predict_v(x_start, t, noise) - target = v - else: - raise ValueError(f'unknown objective {self.objective}') - - loss = self.loss_fn(model_out, target, reduction = 'none') - loss = reduce(loss, 'b ... -> b (...)', 'mean') - - loss = loss * extract(self.p2_loss_weight, t, loss.shape) - - loss1 = self.loss_fn(ures, hres, reduction = 'none') - loss1 = reduce(loss1, 'b ... -> b (...)', 'mean') - - loss2 = self.loss_fn(x_start, warped, reduction = 'none') - loss2 = reduce(loss2, 'b ... -> b (...)', 'mean') - - return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*0.3 - - def forward(self, lres, hres, *args, **kwargs): - - b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size - - assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - - t = torch.randint(0, self.num_timesteps, (b*(f-2),), device=device).long() - - ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) - ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) - - lres = self.normalize(lres) - hres = self.normalize(hres) - ures = self.normalize(ures) - - l = ures.clone() - r = torch.roll(l, -1, 1) - - m = lres.clone() - m1 = rearrange(m, 'b t c h w -> (b t) c h w') - m1 = self.upsample(m1) - m1 = rearrange(m1, '(b t) c h w -> b t c h w', b = b) - m1 = torch.roll(m1, -2, 1) - #m1 = torch.roll(l, -2, 1) - - stack = torch.cat((l, r, m1), 2) - stack = stack[:, :(f-2), :, :, :] - - return self.p_losses(stack, hres, lres, ures, t, *args, **kwargs) - -# trainer class - -class Trainer(object): - def __init__( - self, - diffusion_model, - train_dl, - val_dl, - config, - *, - train_batch_size = 16, - gradient_accumulate_every = 1, - #augment_horizontal_flip = True, - train_lr = 1e-4, - train_num_steps = 100000, - ema_update_every = 1, - ema_decay = 0.995, - adam_betas = (0.9, 0.99), - save_and_sample_every = 1, - #num_samples = 25, - eval_folder = './evaluate', - results_folder = './results', - #tensorboard_dir = './tensorboard', - val_num_of_batch = 2, - amp = False, - fp16 = False, - #fp16 = True, - split_batches = True, - #split_batches = False, - convert_image_to = None - ): - super().__init__() - - self.accelerator = Accelerator( - split_batches = split_batches, - mixed_precision = 'fp16' if fp16 else 'no', - log_with = 'wandb', - ) - self.accelerator.init_trackers("vsr-orig-autoreg-hres", - init_kwargs={ - "wandb": { - "notes": "Use VSR to improve precipitation forecasting.", - # Change "name" to set the name of the run. 
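The forward pass above builds its conditioning stack by rolling the time axis: each retained index pairs a SwinIR frame with the next SwinIR frame and with the upsampled low-res frame two steps ahead, and the wrapped-around tail is sliced off with [:, :(f-2)]. A toy sketch of that indexing on a hypothetical 5-frame sequence:

import torch

f = 5
x = torch.arange(f).view(1, f, 1, 1, 1).float()   # frame indices 0..4 stand in for real frames
r1 = torch.roll(x, -1, 1)                         # shift left by one: 1, 2, 3, 4, 0
r2 = torch.roll(x, -2, 1)                         # shift left by two: 2, 3, 4, 0, 1
stack = torch.cat((x, r1, r2), 2)[:, :f - 2]      # drop the entries that wrapped around
print(stack.flatten(2).squeeze().tolist())
# [[0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]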
- "name": None, - } - }, - ) - self.config = config - self.accelerator.native_amp = amp - - self.model = diffusion_model - - self.save_and_sample_every = save_and_sample_every - - self.batch_size = train_batch_size - self.gradient_accumulate_every = gradient_accumulate_every - - self.train_num_steps = train_num_steps - self.image_size = diffusion_model.image_size - - self.val_num_of_batch = val_num_of_batch - - # optimizer - - self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) - - # for logging results in a folder periodically - - if self.accelerator.is_main_process: - self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) - - self.results_folder = Path(results_folder) - - self.results_folder.mkdir(exist_ok=True, parents=True) - - self.eval_folder = eval_folder - - # step counter state - - self.step = 0 - - # prepare model, dataloader, optimizer with accelerator - - self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) - self.train_dl = cycle(train_dl) - self.val_dl = cycle(val_dl) - - def save(self, milestone): - if not self.accelerator.is_local_main_process: - return - - data = { - 'step': self.step, - 'model': self.accelerator.get_state_dict(self.model), - 'opt': self.opt.state_dict(), - 'ema': self.ema.state_dict(), - 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, - #'version': __version__ - } - - torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) - - def load(self, milestone): - accelerator = self.accelerator - device = accelerator.device - - data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) - - model = self.accelerator.unwrap_model(self.model) - model.load_state_dict(data['model']) - - self.step = data['step'] - #self.opt.load_state_dict(data['opt']) - self.ema.load_state_dict(data['ema']) - - #if 'version' in data: - # print(f"loading from version {data['version']}") - - if exists(self.accelerator.scaler) and exists(data['scaler']): - self.accelerator.scaler.load_state_dict(data['scaler']) - - def train(self): - - accelerator = self.accelerator - device = accelerator.device - - cmap = mpl.colormaps['RdBu_r'] - - c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') - c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') - c384_gmin = np.load('data/only_precip/c384_gmin.npy') - - c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') - c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') - c48_gmin = np.load('data/only_precip/c48_gmin.npy') - - c384_min = np.load('data/only_precip/c384_min.npy') - c384_max = np.load('data/only_precip/c384_max.npy') - - c48_min = np.load('data/only_precip/c48_min.npy') - c48_max = np.load('data/only_precip/c48_max.npy') - - with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: - - while self.step < self.train_num_steps: - - total_loss = 0. 
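The trainer above keeps an exponential moving average of the diffusion weights with ema_pytorch and later samples from the averaged copy. A minimal sketch of that pattern on a toy module (beta mirrors the configured ema_decay; everything else is illustrative):

import torch
from ema_pytorch import EMA

net = torch.nn.Linear(8, 2)
ema = EMA(net, beta=0.995, update_every=1)     # plays the role of ema_decay / ema_update_every

for _ in range(10):
    # ... an optimizer step on `net` would normally happen here ...
    ema.update()                               # fold the online weights into the average

ema.ema_model.eval()                           # the smoothed copy used for evaluation
with torch.no_grad():
    print(ema.ema_model(torch.randn(1, 8)).shape)   # torch.Size([1, 2])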
- - for _ in range(self.gradient_accumulate_every): - - #data = next(self.dl).to(device) - data = next(self.train_dl) - lres = data['LR'].to(device) - hres = data['HR'].to(device) - - with self.accelerator.autocast(): - - #loss = self.model(data) - loss = self.model(lres, hres) - loss = loss / self.gradient_accumulate_every - total_loss += loss.item() - - self.accelerator.backward(loss) - - accelerator.clip_grad_norm_(self.model.parameters(), 1.0) - pbar.set_description(f'loss: {total_loss:.4f}') - - #self.writer.add_scalar("loss", total_loss, self.step) - accelerator.log({"loss": total_loss}, step = self.step) - - accelerator.wait_for_everyone() - - self.opt.step() - self.opt.zero_grad() - - accelerator.wait_for_everyone() - - self.step += 1 - if accelerator.is_main_process: - self.ema.to(device) - self.ema.update() - - if self.step != 0 and self.step % self.save_and_sample_every == 0: - self.ema.ema_model.eval() - - with torch.no_grad(): - - for i, batch in enumerate(self.val_dl): - - lres = batch['LR'].to(device) - hres = batch['HR'].to(device) - - if i >= self.val_num_of_batch: - break - - num_samples = 5 - num_videos_per_batch = 1 - num_frames = 5 - img_size = 384 - img_channels = 1 - - truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - truth[0,:,:,:,:,:] = (hres[:,2:,:,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) - - for k in range(num_samples): - videos, base, res, flows = self.ema.ema_model.sample(lres) - pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) - - crps_index = calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size) - psnr_index = piq.psnr(hres[:,2:,0:1,:,:], videos.clamp(0.0, 1.0)[:,:,0:1,:,:], data_range=1., reduction='none') - - videos_time_mean = videos.mean(dim = 1) - hres_time_mean = hres[:,2:,:,:,:].mean(dim = 1) - bias = videos_time_mean - hres_time_mean - norm = mpl.colors.Normalize(vmin = bias.min(), vmax = bias.max()) - sm = smap(norm, cmap) - b_c = [] - for l in range(num_videos_per_batch): - b_c.append(sm.to_rgba(bias[l,0,:,:].cpu().numpy())) - bias_color = np.stack(b_c, axis = 0) - - if not self.config.data_config.logscale: - target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min - output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min - coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min - - else: - target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin - output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin - coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin - - if self.config.data_config.logscale: - target = np.exp(target) + c384_gmin - 1e-14 - output = np.exp(output) + c384_gmin - 1e-14 - coarse = np.exp(coarse) + c48_gmin - 1e-14 - - nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) - diff_output = (output - nn_upscale).flatten() - diff_target = (target - nn_upscale).flatten() - vmin = min(diff_output.min(), diff_target.min()) - vmax = max(diff_output.max(), diff_target.max()) - bins = np.linspace(vmin, vmax, 100 + 1) - - fig, ax = plt.subplots(1, 1, figsize=(6, 4)) - ax.hist( - diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True - ) - 
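When logscale is set, the validation block above converts fields back to physical precipitation by undoing the min-max scaling and then the log transform (np.exp(...) + c384_gmin - 1e-14). A small sketch of the full round trip with stand-in statistics (the real values come from the saved .npy files):

import numpy as np

gmin, lgmin, lgmax = 0.0, -32.0, -3.0        # stand-ins for c384_gmin, c384_lgmin, c384_lgmax
precip = np.random.rand(4, 4) * 1e-4         # toy physical precipitation rates

norm = (np.log(precip - gmin + 1e-14) - lgmin) / (lgmax - lgmin)   # forward (log, then min-max)
back = np.exp(norm * (lgmax - lgmin) + lgmin) + gmin - 1e-14       # inverse, as applied above
print(np.allclose(back, precip))                                   # True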
ax.hist( - diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True - ) - ax.set_xlim(vmin, vmax) - ax.legend() - ax.set_ylabel("Density") - ax.set_yscale("log") - - output1 = output.flatten() - target1 = target.flatten() - vmin1 = min(output1.min(), target1.min()) - vmax1 = max(output1.max(), target1.max()) - bins1 = np.linspace(vmin1, vmax1, 100 + 1) - - fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) - ax1.hist( - output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True - ) - ax1.hist( - target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True - ) - ax1.set_xlim(vmin1, vmax1) - ax1.legend() - ax1.set_ylabel("Density") - ax1.set_yscale("log") - - flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) - for m in range(num_frames): - flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) - - if self.config.data_config.logscale: - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) - - else: - accelerator.log({"true_high": wandb.Video((hres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,:,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - target = np.log(target - c384_gmin + 1e-14) - output = np.log(output - c384_gmin + 1e-14) - coarse = np.log(coarse - c48_gmin + 1e-14) - target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) - output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) - coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) - accelerator.log({"true_loghigh": wandb.Video((np.repeat(target, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) - accelerator.log({"logsamples": wandb.Video((np.repeat(output, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) - - accelerator.log({"pattern_bias": wandb.Image((bias_color*255).astype(np.uint8), mode = 'RGBA')}, step=self.step) - accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) - accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) - accelerator.log({"psnr": psnr_index.mean()}, step=self.step) - accelerator.log({"crps": crps_index}, step=self.step) - - milestone = self.step // self.save_and_sample_every - - 
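The flow videos logged above use flow_vis.flow_to_color, which maps a 2-channel displacement field to the standard optical-flow color wheel. A minimal sketch with a random field, using the same call as above:

import numpy as np
import flow_vis

flow_uv = np.random.randn(64, 64, 2).astype(np.float32)     # (h, w, 2) displacement field
rgb = flow_vis.flow_to_color(flow_uv, convert_to_bgr=True)   # same flag as in the logging code
print(rgb.shape, rgb.dtype)                                  # (64, 64, 3) uint8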
self.save(milestone) - - pbar.update(1) - - accelerator.print('training complete') - - def sample(self): - - accelerator = self.accelerator - device = accelerator.device - - self.ema.ema_model.eval() - - c384_norm= torch.from_numpy(np.load("data/only_precip/c384_norm.npy")) - c48_norm = torch.from_numpy(np.load("data/only_precip/c48_norm.npy")) - - with torch.no_grad(): - - for tile in range(6): - - st = 0 - en = 27 - count = 0 - - while en < c48_norm.shape[1]: - - print(tile, st) - - lres = c48_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) - hres = c384_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) - - videos, base, res, flows = self.ema.ema_model.sample(lres) - - torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) - torch.save(hres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) - torch.save(lres[:,2:,:,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) - - st += 25 - en += 25 - count += 1 \ No newline at end of file diff --git a/projects/super_res/model/network_swinir.py b/projects/super_res/model/network_swinir.py index 461fb354ce..8c5f8537c0 100644 --- a/projects/super_res/model/network_swinir.py +++ b/projects/super_res/model/network_swinir.py @@ -643,7 +643,7 @@ class SwinIR(nn.Module): resi_connection: The convolutional block before residual connection. '1conv'/'3conv' """ - def __init__(self, img_size=64, patch_size=1, in_chans=3, + def __init__(self, img_size=64, patch_size=1, in_chans=3, out_chans=3, embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, @@ -652,7 +652,7 @@ def __init__(self, img_size=64, patch_size=1, in_chans=3, **kwargs): super(SwinIR, self).__init__() num_in_ch = in_chans - num_out_ch = in_chans + num_out_ch = out_chans num_feat = 64 self.img_range = img_range if in_chans == 3: @@ -666,6 +666,7 @@ def __init__(self, img_size=64, patch_size=1, in_chans=3, ##################################################################################################### ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) ##################################################################################################### diff --git a/projects/super_res/sampler.py b/projects/super_res/sampler.py index 08ecb72a63..dd5a9eeb0b 100644 --- a/projects/super_res/sampler.py +++ b/projects/super_res/sampler.py @@ -1,56 +1,80 @@ import os from model.autoreg_diffusion_mod import Unet, Flow, GaussianDiffusion, Trainer -from data.load_data import load_data from config_infer import config -model = Unet( - dim = config.dim, - channels = 2 * config.data_config["img_channel"], - out_dim = config.data_config["img_channel"], - dim_mults = config.dim_mults, - learned_sinusoidal_cond = config.learned_sinusoidal_cond, - random_fourier_features = config.random_fourier_features, - learned_sinusoidal_dim = config.learned_sinusoidal_dim -).cuda() - -flow = Flow( - dim = config.dim, - channels = 3 * config.data_config["img_channel"], - out_dim = 3, - dim_mults = config.dim_mults -).cuda() - -diffusion = GaussianDiffusion( - model, - flow, - image_size = config.data_config["img_size"], - timesteps = config.diffusion_steps, - sampling_timesteps = config.sampling_steps, - loss_type = config.loss, - objective = config.objective -).cuda() - -trainer = Trainer( - diffusion, - 
None, - None, - train_batch_size = config.batch_size, - train_lr = config.lr, - train_num_steps = config.steps, - gradient_accumulate_every = config.grad_acc, - val_num_of_batch = config.val_num_of_batch, - save_and_sample_every = config.save_and_sample_every, - ema_decay = config.ema_decay, - amp = config.amp, - split_batches = config.split_batches, - #eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), - eval_folder = os.path.join(config.eval_folder, f"{config.data_name}/"), - results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), - config = config - #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), -) - -trainer.load(config.milestone) - -trainer.sample() \ No newline at end of file +def main(): + + if config.data_config["multi"]: + + in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + + else: + + in_ch_model = 2 * config.data_config["img_channel"] + in_ch_flow = 3 * config.data_config["img_channel"] + in_ch_isr = config.data_config["img_channel"] + + if config.data_config["flow"] == "3d": + + out_ch_flow = 3 + + elif config.data_config["flow"] == "2d": + + out_ch_flow = 2 + + model = Unet( + dim = config.dim, + channels = in_ch_model, + out_dim = config.data_config["img_channel"], + dim_mults = config.dim_mults, + learned_sinusoidal_cond = config.learned_sinusoidal_cond, + random_fourier_features = config.random_fourier_features, + learned_sinusoidal_dim = config.learned_sinusoidal_dim + ).cuda() + + flow = Flow( + dim = config.dim, + channels = in_ch_flow, + out_dim = out_ch_flow, + dim_mults = config.dim_mults + ).cuda() + + diffusion = GaussianDiffusion( + model, + flow, + image_size = config.data_config["img_size"], + in_ch = in_ch_isr, + timesteps = config.diffusion_steps, + sampling_timesteps = config.sampling_steps, + loss_type = config.loss, + objective = config.objective + ).cuda() + + trainer = Trainer( + diffusion, + None, + None, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.data_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config + ) + + trainer.load(config.milestone) + + trainer.sample() + +if __name__ == "__main__": + print(config) + main() \ No newline at end of file diff --git a/projects/super_res/trainer.py b/projects/super_res/trainer.py index 617c257d95..29ad8358b1 100755 --- a/projects/super_res/trainer.py +++ b/projects/super_res/trainer.py @@ -5,9 +5,30 @@ from config import config def main(): + + if config.data_config["multi"]: + + in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * 
(precip + multi + topo) + in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + + else: + + in_ch_model = 2 * config.data_config["img_channel"] + in_ch_flow = 3 * config.data_config["img_channel"] + in_ch_isr = config.data_config["img_channel"] + + if config.data_config["flow"] == "3d": + + out_ch_flow = 3 + + elif config.data_config["flow"] == "2d": + + out_ch_flow = 2 + model = Unet( dim = config.dim, - channels = 2 * config.data_config["img_channel"], + channels = in_ch_model, out_dim = config.data_config["img_channel"], dim_mults = config.dim_mults, learned_sinusoidal_cond = config.learned_sinusoidal_cond, @@ -17,15 +38,16 @@ def main(): flow = Flow( dim = config.dim, - channels = 3 * config.data_config["img_channel"], - out_dim = 3, + channels = in_ch_flow, + out_dim = out_ch_flow, dim_mults = config.dim_mults ).cuda() - + diffusion = GaussianDiffusion( model, flow, image_size = config.data_config["img_size"], + in_ch = in_ch_isr, timesteps = config.diffusion_steps, sampling_timesteps = config.sampling_steps, loss_type = config.loss, @@ -55,7 +77,6 @@ def main(): eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), config = config - #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), ) trainer.train() @@ -63,4 +84,4 @@ def main(): if __name__ == "__main__": print(config) - main() + main() \ No newline at end of file diff --git a/projects/super_res/trainer_mod_flow.py b/projects/super_res/trainer_mod_flow.py deleted file mode 100755 index ec1971e97b..0000000000 --- a/projects/super_res/trainer_mod_flow.py +++ /dev/null @@ -1,66 +0,0 @@ -import os - -from model.autoreg_diffusion_mod_flow import Unet, Flow, GaussianDiffusion, Trainer -from data.load_data import load_data -from config_mod_flow import config - -def main(): - model = Unet( - dim = config.dim, - channels = 2 * config.data_config["img_channel"], - out_dim = config.data_config["img_channel"], - dim_mults = config.dim_mults, - learned_sinusoidal_cond = config.learned_sinusoidal_cond, - random_fourier_features = config.random_fourier_features, - learned_sinusoidal_dim = config.learned_sinusoidal_dim - ).cuda() - - flow = Flow( - dim = config.dim, - channels = 3 * config.data_config["img_channel"], - out_dim = 2, - dim_mults = config.dim_mults - ).cuda() - - diffusion = GaussianDiffusion( - model, - flow, - image_size = config.data_config["img_size"], - timesteps = config.diffusion_steps, - sampling_timesteps = config.sampling_steps, - loss_type = config.loss, - objective = config.objective - ).cuda() - - train_dl, val_dl = load_data( - config.data_config, - config.batch_size, - pin_memory = True, - num_workers = 4, - ) - - trainer = Trainer( - diffusion, - train_dl, - val_dl, - train_batch_size = config.batch_size, - train_lr = config.lr, - train_num_steps = config.steps, - gradient_accumulate_every = config.grad_acc, - val_num_of_batch = config.val_num_of_batch, - save_and_sample_every = config.save_and_sample_every, - ema_decay = config.ema_decay, - amp = config.amp, - split_batches = config.split_batches, - eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), - results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), - config = config - #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), - ) - - trainer.train() - - -if __name__ == "__main__": - 
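For the multi-channel configuration, the channel counts computed above come from one precipitation channel, four additional coarse channels, and one topography channel. A quick sketch of the arithmetic, with labels following the comments above:

img_channel, extra, topo = 1, 4, 1             # precip + four extra coarse channels + topography

in_ch_model = 2 * img_channel + extra + topo   # all conditioning channels plus the noisy channel
in_ch_flow = 3 * (img_channel + extra + topo)  # three stacked frames, all channels each
in_ch_isr = img_channel + extra + topo         # current low-res channels (the in_ch passed to GaussianDiffusion)
print(in_ch_model, in_ch_flow, in_ch_isr)      # 7 18 6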
print(config) - main() From e504606d678223e8bd90e94ae368ce52539f1812 Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Fri, 11 Aug 2023 23:26:49 +0000 Subject: [PATCH 5/9] fixing default config params --- projects/super_res/config.py | 4 ++-- projects/super_res/config_infer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/projects/super_res/config.py b/projects/super_res/config.py index a3328f2125..10768bfdca 100644 --- a/projects/super_res/config.py +++ b/projects/super_res/config.py @@ -2,8 +2,8 @@ config = config_dict.ConfigDict() -config.dim = 128 -config.dim_mults = (1, 2, 2, 2, 4, 4) +config.dim = 64 +config.dim_mults = (1, 1, 2, 2, 3, 4) config.learned_sinusoidal_cond = True, config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 diff --git a/projects/super_res/config_infer.py b/projects/super_res/config_infer.py index 849ddc8a4f..2dc7d26641 100644 --- a/projects/super_res/config_infer.py +++ b/projects/super_res/config_infer.py @@ -2,8 +2,8 @@ config = config_dict.ConfigDict() -config.dim = 128 -config.dim_mults = (1, 2, 2, 2, 4, 4) +config.dim = 64 +config.dim_mults = (1, 1, 2, 2, 3, 4) config.learned_sinusoidal_cond = True, config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 From ce63134c9a663a51b8b05d1a309ae3b845eae170 Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Sat, 12 Aug 2023 00:00:14 +0000 Subject: [PATCH 6/9] scripts to get data from bucket --- projects/super_res/data/channel_data_gen.py | 29 +++++ projects/super_res/data/precip_data_gen.py | 50 +++++++++ projects/super_res/data/topo_data_gen.py | 112 ++++++++++++++++++++ 3 files changed, 191 insertions(+) create mode 100644 projects/super_res/data/channel_data_gen.py create mode 100644 projects/super_res/data/precip_data_gen.py create mode 100644 projects/super_res/data/topo_data_gen.py diff --git a/projects/super_res/data/channel_data_gen.py b/projects/super_res/data/channel_data_gen.py new file mode 100644 index 0000000000..0858e6b382 --- /dev/null +++ b/projects/super_res/data/channel_data_gen.py @@ -0,0 +1,29 @@ +import xarray as xr +import numpy as np +from pathlib import Path + +channel_folder = Path('./more_channels') +channel_folder.mkdir(exist_ok = True, parents = True) + +c384 = xr.open_zarr("gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt_coarse": "x", "grid_yt_coarse": "y"}) +c48 = xr.open_zarr("gs://vcm-ml-intermediate/2021-10-12-PIRE-c48-post-spinup-verification/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) + +channels = ["UGRD10m_coarse", "VGRD10m_coarse", "tsfc_coarse", "CPRATEsfc_coarse"] +c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +split = int(c384_np.shape[1] * 0.8) + +# compute statistics on training set +c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(axis=(0,1,3,4)).reshape(1,1,4,1,1), c384_np[:, :split, :, :, :].max(axis=(0,1,3,4)).reshape(1,1,4,1,1), c48_np[:, :split, :, :, :].min(axis=(0,1,3,4)).reshape(1,1,4,1,1), c48_np[:, :split, :, :, :].max(axis=(0,1,3,4)).reshape(1,1,4,1,1) + +# normalize +c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +np.save('c384_min.npy', c384_min) +np.save('c384_max.npy', c384_max) +np.save('c48_min.npy', c48_min) +np.save('c48_max.npy', c48_max) +np.save('c48_norm.npy', c48_norm) 
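channel_data_gen.py above computes per-channel min/max on the training portion of the time axis only, reshapes them for broadcasting, and normalizes the full (tile, time, channel, y, x) array. A toy sketch of the same normalization with random data:

import numpy as np

data = np.random.rand(6, 100, 4, 48, 48)                   # (tile, time, channel, y, x)
split = int(data.shape[1] * 0.8)                           # statistics from the first 80% only

cmin = data[:, :split].min(axis=(0, 1, 3, 4)).reshape(1, 1, 4, 1, 1)
cmax = data[:, :split].max(axis=(0, 1, 3, 4)).reshape(1, 1, 4, 1, 1)
norm = (data - cmin) / (cmax - cmin)                       # broadcasts one min/max per channel
print(norm.shape, float(norm[:, :split].min()), float(norm[:, :split].max()))
# (6, 100, 4, 48, 48) 0.0 1.0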
+np.save('c384_norm.npy', c384_norm) \ No newline at end of file diff --git a/projects/super_res/data/precip_data_gen.py b/projects/super_res/data/precip_data_gen.py new file mode 100644 index 0000000000..2d8c1709a3 --- /dev/null +++ b/projects/super_res/data/precip_data_gen.py @@ -0,0 +1,50 @@ +import xarray as xr +import numpy as np +from pathlib import Path + +precip_folder = Path('./only_precip') +precip_folder.mkdir(exist_ok = True, parents = True) + +c384 = xr.open_zarr("gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt_coarse": "x", "grid_yt_coarse": "y"}) +c48 = xr.open_zarr("gs://vcm-ml-intermediate/2021-10-12-PIRE-c48-post-spinup-verification/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) + +channels = ["PRATEsfc_coarse"] +c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +np.save('only_precip/c384_gmin.npy', c384_np.min()) +np.save('only_precip/c48_gmin.npy', c48_np.min()) + +# calculate split (80/20) +split = int(c384_np.shape[1] * 0.8) + +# compute statistics on training set +c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + +# normalize +c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +np.save('only_precip/c384_min.npy', c384_min) +np.save('only_precip/c384_max.npy', c384_max) +np.save('only_precip/c48_min.npy', c48_min) +np.save('only_precip/c48_max.npy', c48_max) +np.save('only_precip/c48_norm.npy', c48_norm) +np.save('only_precip/c384_norm.npy', c384_norm) + +c384_lnp = np.log(c384_np - c384_np.min() + 1e-14) +c48_lnp = np.log(c48_np - c48_np.min() + 1e-14) + +# compute statistics on training set +c384_lmin, c384_lmax, c48_lmin, c48_lmax = c384_lnp[:, :split, :, :, :].min(), c384_lnp[:, :split, :, :, :].max(), c48_lnp[:, :split, :, :, :].min(), c48_lnp[:, :split, :, :, :].max() + +# normalize +c384_lnorm= (c384_lnp - c384_lmin) / (c384_lmax - c384_lmin) +c48_lnorm = (c48_lnp - c48_lmin) / (c48_lmax - c48_lmin) + +np.save('only_precip/c384_lgmin.npy', c384_lmin) +np.save('only_precip/c384_lgmax.npy', c384_lmax) +np.save('only_precip/c48_lgmin.npy', c48_lmin) +np.save('only_precip/c48_lgmax.npy', c48_lmax) +np.save('only_precip/c48_lgnorm.npy', c48_lnorm) +np.save('only_precip/c384_lgnorm.npy', c384_lnorm) \ No newline at end of file diff --git a/projects/super_res/data/topo_data_gen.py b/projects/super_res/data/topo_data_gen.py new file mode 100644 index 0000000000..531fd3fbfe --- /dev/null +++ b/projects/super_res/data/topo_data_gen.py @@ -0,0 +1,112 @@ +import xarray as xr +import numpy as np +from typing import TypeVar, Union, Tuple, Hashable, Any, Callable +from pathlib import Path + +topo_folder = Path('./topography') +topo_folder.mkdir(exist_ok = True, parents = True) + +topo384 = xr.open_zarr('gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_static_coarse.zarr') + +wts = xr.open_zarr('gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/grid_spec_coarse.zarr') + +T_DataArray_or_Dataset = TypeVar("T_DataArray_or_Dataset", xr.DataArray, xr.Dataset) +CoordFunc = Callable[[Any, Union[int, Tuple[int]]], Any] + +def coarsen_coords_coord_func( + coordinate: np.ndarray, axis: Union[int, Tuple[int]] = -1 +) -> 
np.ndarray: + """xarray coarsen coord_func version of coarsen_coords. + + Note that xarray requires an axis argument for this to work, but it is not + used by this function. To coarsen dimension coordinates, xarray reshapes + the 1D coordinate into a 2D array, with the rows representing groups of + values to aggregate together in some way. The length of the rows + corresponds to the coarsening factor. The value of the coordinate sampled + every coarsening factor is just the first value in each row. + + Args: + coordinate: 2D array of coordinate values + axis: Axes to reduce along (not used) + + Returns: + np.array + """ + return ( + ((coordinate[:, 0] - 1) // coordinate.shape[1] + 1) + .astype(int) + .astype(np.float32) + ) + +def _propagate_attrs( + reference_obj: T_DataArray_or_Dataset, obj: T_DataArray_or_Dataset +) -> T_DataArray_or_Dataset: + """Propagate attributes from the reference object to another. + + Args: + reference_obj: input object + obj: output object + + Returns: + xr.DataArray or xr.Dataset + """ + if isinstance(reference_obj, xr.Dataset): + for variable in reference_obj: + obj[variable].attrs = reference_obj[variable].attrs + obj.attrs = reference_obj.attrs + return obj + + +def weighted_block_average( + obj: T_DataArray_or_Dataset, + weights: xr.DataArray, + coarsening_factor: int, + x_dim: Hashable = "xaxis_1", + y_dim: Hashable = "yaxis_2", + coord_func: Union[str, CoordFunc] = coarsen_coords_coord_func, +) -> T_DataArray_or_Dataset: + """Coarsen a DataArray or Dataset through weighted block averaging. + + Note that this function assumes that the x and y dimension names of the + input DataArray and weights are the same. + + Args: + obj: Input Dataset or DataArray. + weights: Weights (e.g. area or pressure thickness). + coarsening_factor: Integer coarsening factor to use. + x_dim: x dimension name (default 'xaxis_1'). + y_dim: y dimension name (default 'yaxis_1'). + coord_func: function that is applied to the coordinates, or a + mapping from coordinate name to function. See `xarray's coarsen + method for details + `_. + + Returns: + xr.Dataset or xr.DataArray. 
+ """ + coarsen_kwargs = {x_dim: coarsening_factor, y_dim: coarsening_factor} + numerator = (obj * weights).coarsen(coarsen_kwargs, coord_func=coord_func).sum() # type: ignore # noqa + denominator = weights.coarsen(coarsen_kwargs, coord_func=coord_func).sum() # type: ignore # noqa + result = numerator / denominator + + if isinstance(obj, xr.DataArray): + result = result.rename(obj.name) + + return _propagate_attrs(obj, result) + +topo48 = weighted_block_average(topo384, wts['area_coarse'], 8, 'grid_xt_coarse', 'grid_yt_coarse') + +topo384 = topo384['zsurf_coarse'].values +topo48 = topo48['zsurf_coarse'].values + +topo384_min, topo384_max, topo48_min, topo48_max = topo384.min(), topo384.max(), topo48.min(), topo48.max() + +topo384_norm = (topo384 - topo384_min) / (topo384_max - topo384_min) +topo48_norm = (topo48 - topo48_min) / (topo48_max - topo48_min) + +np.save('topography/topo384_norm.npy', topo384_norm) +np.save('topography/topo48_norm.npy', topo48_norm) +np.save('topography/topo384_min.npy', topo384_min) +np.save('topography/topo384_max.npy', topo384_max) +np.save('topography/topo48_min.npy', topo48_min) +np.save('topography/topo48_max.npy', topo48_max) \ No newline at end of file From 134c0033b2376a5b3da000cfdff37800dabfe4f6 Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Tue, 22 Aug 2023 22:33:04 +0000 Subject: [PATCH 7/9] fixing bugs after refactor --- projects/super_res/config_infer.py | 14 ++--- projects/super_res/data/channel_data_gen.py | 12 ++-- .../super_res/model/autoreg_diffusion_mod.py | 62 ++++++++++++------- 3 files changed, 52 insertions(+), 36 deletions(-) diff --git a/projects/super_res/config_infer.py b/projects/super_res/config_infer.py index 2dc7d26641..e0a19aa5d6 100644 --- a/projects/super_res/config_infer.py +++ b/projects/super_res/config_infer.py @@ -8,26 +8,26 @@ config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 config.diffusion_steps = 1500 -config.sampling_steps = 20 +config.sampling_steps = 10 config.loss = "l2" config.objective = "pred_v" config.lr = 8e-5 config.steps = 5000000 config.grad_acc = 1 -config.val_num_of_batch = 2 +config.val_num_of_batch = 1 config.save_and_sample_every = 5000 config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "multichannel_minipatch" +config.additional_note = "" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" config.milestone = 1 -config.rollout = "partial" +config.rollout = "full" config.rollout_batch = 25 -config.batch_size = 1 +config.batch_size = 2 config.data_config = config_dict.ConfigDict({ "dataset_name": "c384", "length": 7, @@ -35,8 +35,8 @@ "img_channel": 1, "img_size": 384, "logscale": True, - "multi": True, - "flow": "2d", + "multi": False, + "flow": "3d", "minipatch": False }) diff --git a/projects/super_res/data/channel_data_gen.py b/projects/super_res/data/channel_data_gen.py index 0858e6b382..145fc634eb 100644 --- a/projects/super_res/data/channel_data_gen.py +++ b/projects/super_res/data/channel_data_gen.py @@ -21,9 +21,9 @@ c384_norm= (c384_np - c384_min) / (c384_max - c384_min) c48_norm = (c48_np - c48_min) / (c48_max - c48_min) -np.save('c384_min.npy', c384_min) -np.save('c384_max.npy', c384_max) -np.save('c48_min.npy', c48_min) -np.save('c48_max.npy', c48_max) -np.save('c48_norm.npy', c48_norm) -np.save('c384_norm.npy', c384_norm) \ No newline at end of file +np.save('more_channels/c384_min.npy', c384_min) +np.save('more_channels/c384_max.npy', c384_max) 
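weighted_block_average in topo_data_gen.py above reduces each coarsening block to an area-weighted mean via xarray's coarsen (the real call uses a factor of 8 with the area_coarse weights). A toy 2x example with uniform weights:

import numpy as np
import xarray as xr

data = xr.DataArray(np.arange(16.0).reshape(4, 4), dims=("y", "x"))
wts = xr.DataArray(np.ones((4, 4)), dims=("y", "x"))        # uniform stand-in for the area weights

num = (data * wts).coarsen(x=2, y=2).sum()
den = wts.coarsen(x=2, y=2).sum()
print((num / den).values)                                   # block means of the 4x4 input
# [[ 2.5  4.5]
#  [10.5 12.5]]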
+np.save('more_channels/c48_min.npy', c48_min) +np.save('more_channels/c48_max.npy', c48_max) +np.save('more_channels/c48_norm.npy', c48_norm) +np.save('more_channels/c384_norm.npy', c384_norm) \ No newline at end of file diff --git a/projects/super_res/model/autoreg_diffusion_mod.py b/projects/super_res/model/autoreg_diffusion_mod.py index a0ee69f27a..886f36dbe8 100644 --- a/projects/super_res/model/autoreg_diffusion_mod.py +++ b/projects/super_res/model/autoreg_diffusion_mod.py @@ -197,7 +197,7 @@ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='border', align_corn torch.arange(0, w, dtype=x.dtype, device=x.device)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False - + vgrid = grid + flow # scale grid to [-1,1] @@ -993,7 +993,8 @@ def sample(self, lres, hres, multi, flow_mode, return_all_timesteps = False): elif flow_mode == '2d': - warped = flow_warp(ures_flow, flow) + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) res = sample_fn((b * (f - 2), 1, 8 * h, 8 * w), l_cond, context, return_all_timesteps = return_all_timesteps) sres = warped + res @@ -1003,8 +1004,13 @@ def sample(self, lres, hres, multi, flow_mode, return_all_timesteps = False): res = rearrange(res, '(b t) c h w -> b t c h w', b = b) flow = rearrange(flow, '(b t) c h w -> b t c h w', b = b) - return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), self.unnormalize(flow) - + if flow_mode == '2d': + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), flow + + elif flow_mode == '3d': + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), self.unnormalize(flow) @torch.no_grad() def interpolate(self, x1, x2, t = None, lam = 0.5): @@ -1057,7 +1063,8 @@ def p_losses(self, stack, hres, lres, ures, t, multi, flow_mode, topo = None, no elif flow_mode == '2d': - warped = flow_warp(ures_flow, flow) + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') x_start = x_start - warped @@ -1210,21 +1217,23 @@ def __init__( mixed_precision = 'fp16' if fp16 else 'no', log_with = 'wandb', ) - self.accelerator.init_trackers("vsr-orig-autoreg-hres", - init_kwargs={ - "wandb": { - "notes": "Use VSR to improve precipitation forecasting.", - # Change "name" to set the name of the run. - "name": None, - } - }, - ) + # self.accelerator.init_trackers("vsr-orig-autoreg-hres", + # init_kwargs={ + # "wandb": { + # "notes": "Use VSR to improve precipitation forecasting.", + # # Change "name" to set the name of the run. 
+ # "name": None, + # } + # }, + # ) self.config = config self.accelerator.native_amp = amp self.multi = config.data_config["multi"] self.rollout = config.rollout + self.rollout_batch = config.rollout_batch self.flow = config.data_config["flow"] self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] self.model = diffusion_model @@ -1263,7 +1272,7 @@ def __init__( self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) self.train_dl = cycle(train_dl) - self.val_dl = cycle(val_dl) + self.val_dl = val_dl def save(self, milestone): if not self.accelerator.is_local_main_process: @@ -1409,7 +1418,7 @@ def train(self): b_c.append(sm.to_rgba(bias[l,0,:,:].cpu().numpy())) bias_color = np.stack(b_c, axis = 0) - if not self.config.data_config.logscale: + if not self.logscale: target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min @@ -1419,7 +1428,7 @@ def train(self): output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin - if self.config.data_config.logscale: + if self.logscale: target = np.exp(target) + c384_gmin - 1e-14 output = np.exp(output) + c384_gmin - 1e-14 coarse = np.exp(coarse) + c48_gmin - 1e-14 @@ -1474,7 +1483,7 @@ def train(self): flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) - if self.config.data_config.logscale: + if self.logscale: accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) @@ -1522,8 +1531,15 @@ def sample(self): self.ema.ema_model.eval() - c384_norm= torch.from_numpy(np.load("data/only_precip/c384_lgnorm.npy")) - c48_norm = torch.from_numpy(np.load("data/only_precip/c48_lgnorm.npy")) + if self.logscale: + + c384_norm= torch.from_numpy(np.load("data/only_precip/c384_lgnorm.npy")) + c48_norm = torch.from_numpy(np.load("data/only_precip/c48_lgnorm.npy")) + + else: + + c384_norm= torch.from_numpy(np.load("data/only_precip/c384_norm.npy")) + c48_norm = torch.from_numpy(np.load("data/only_precip/c48_norm.npy")) if self.multi: @@ -1553,7 +1569,7 @@ def sample(self): lres = c48_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) hres = c384_norm[tile,st:en,:,:,:].unsqueeze(0).to(device) - videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi) + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) @@ -1566,7 +1582,7 @@ def sample(self): if self.rollout == 'partial': seq_len = self.rollout_batch - indices = get_random_idx_with_difference(0, c48_norm.shape[1] - (seq_len + 2), 250 // seq_len, seq_len + 2) # 250 samples per tile + indices = get_random_idx_with_difference(0, c48_norm.shape[1] - (seq_len + 2), 75 // seq_len, seq_len + 2) # 250 samples per tile for count, st in enumerate(indices): @@ -1575,7 +1591,7 @@ def sample(self): lres = 
c48_norm[tile,st:st+(seq_len+2),:,:,:].unsqueeze(0).to(device) hres = c384_norm[tile,st:st+(seq_len+2),:,:,:].unsqueeze(0).to(device) - videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi) + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) From acec524a7d11410be677c332a2badb00a9cce343 Mon Sep 17 00:00:00 2001 From: Prakhar Srivastava Date: Tue, 29 Aug 2023 20:42:05 +0000 Subject: [PATCH 8/9] xarray dataloaders, ensemble dataloading, schedulers, histogram matching metrics --- projects/super_res/config.py | 12 +- projects/super_res/config_infer.py | 12 +- projects/super_res/data/dataload.sh | 13 + .../data/ensemblec384logtrainstats.py | 63 +++ .../data/ensemblec384topotrainstats.py | 81 ++++ .../super_res/data/ensemblec384trainstats.py | 87 +++++ .../data/ensemblec48logtrainstats.py | 63 +++ .../super_res/data/ensemblec48trainstats.py | 102 +++++ projects/super_res/data/load_data.py | 2 +- projects/super_res/data/load_dataset.py | 4 +- projects/super_res/data/vsrdata_ensemble.py | 142 +++++++ projects/super_res/data/vsrdata_new.py | 107 ++++++ .../super_res/model/autoreg_diffusion_mod.py | 360 +++++++++++------- projects/super_res/trainer.py | 9 +- 14 files changed, 895 insertions(+), 162 deletions(-) create mode 100755 projects/super_res/data/dataload.sh create mode 100644 projects/super_res/data/ensemblec384logtrainstats.py create mode 100644 projects/super_res/data/ensemblec384topotrainstats.py create mode 100644 projects/super_res/data/ensemblec384trainstats.py create mode 100644 projects/super_res/data/ensemblec48logtrainstats.py create mode 100644 projects/super_res/data/ensemblec48trainstats.py create mode 100644 projects/super_res/data/vsrdata_ensemble.py create mode 100644 projects/super_res/data/vsrdata_new.py diff --git a/projects/super_res/config.py b/projects/super_res/config.py index 10768bfdca..f61eb5ea23 100644 --- a/projects/super_res/config.py +++ b/projects/super_res/config.py @@ -11,15 +11,15 @@ config.sampling_steps = 20 config.loss = "l2" config.objective = "pred_v" -config.lr = 8e-5 -config.steps = 5000000 +config.lr = 1e-4 +config.steps = 700000 config.grad_acc = 1 -config.val_num_of_batch = 2 -config.save_and_sample_every = 5000 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "multichannel_minipatch" +config.additional_note = "2d-nomulti-ls-ensemble" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" @@ -35,7 +35,7 @@ "img_channel": 1, "img_size": 384, "logscale": True, - "multi": True, + "multi": False, "flow": "2d", "minipatch": False }) diff --git a/projects/super_res/config_infer.py b/projects/super_res/config_infer.py index e0a19aa5d6..ce6305a97e 100644 --- a/projects/super_res/config_infer.py +++ b/projects/super_res/config_infer.py @@ -14,12 +14,12 @@ config.lr = 8e-5 config.steps = 5000000 config.grad_acc = 1 -config.val_num_of_batch = 1 -config.save_and_sample_every = 5000 +config.val_num_of_batch = 5 +config.save_and_sample_every = 50 config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "" +config.additional_note = "2d_multi_nols" config.eval_folder = "./evaluate" config.results_folder 
= "./results" config.tensorboard_dir = "./tensorboard" @@ -34,9 +34,9 @@ "channels": ["PRATEsfc_coarse"], "img_channel": 1, "img_size": 384, - "logscale": True, - "multi": False, - "flow": "3d", + "logscale": False, + "multi": True, + "flow": "2d", "minipatch": False }) diff --git a/projects/super_res/data/dataload.sh b/projects/super_res/data/dataload.sh new file mode 100755 index 0000000000..55f69df969 --- /dev/null +++ b/projects/super_res/data/dataload.sh @@ -0,0 +1,13 @@ +#! /bin/sh +channel='c48_atmos_ave' +file='atmos_8xdaily_ave_coarse.zarr' +for member in $(seq -f "%04g" 1 11) +do + mkdir -p /data/prakhars/ensemble/$channel/$member + gsutil -m cp -r gs://vcm-ml-raw-flexible-retention/2023-08-14-C384-reference-ensemble/ic_$member/diagnostics/$file /data/prakhars/ensemble/$channel/$member +done +# channel --> file +# c384_precip_ave --> sfc_8xdaily_ave.zarr +# c48_precip_plus_more_ave --> sfc_8xdaily_ave_coarse.zarr +# c384_topo --> atmos_static.zarr +# c48_atmos_ave --> atmos_8xdaily_ave_coarse.zarr \ No newline at end of file diff --git a/projects/super_res/data/ensemblec384logtrainstats.py b/projects/super_res/data/ensemblec384logtrainstats.py new file mode 100644 index 0000000000..2b18ae0a96 --- /dev/null +++ b/projects/super_res/data/ensemblec384logtrainstats.py @@ -0,0 +1,63 @@ +import pickle +import numpy as np +from pathlib import Path + +precip_folder = Path('./ensemble_c384_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +# load the data +with open('ensemble_c384_trainstats/chl.pkl', 'rb') as f: + chl = pickle.load(f) + +precip = chl['PRATEsfc'] +log_chl = {} +log_chl['PRATEsfc'] = {} +log_chl['PRATEsfc']['min'] = np.log(precip['min'] - precip['min'] + 1e-14) +log_chl['PRATEsfc']['max'] = np.log(precip['max'] - precip['min'] + 1e-14) + +# save the chl dictionary as pickle +with open(precip_folder / 'log_chl.pkl', 'wb') as f: + pickle.dump(log_chl, f) + + + +# channels = ["PRATEsfc_coarse"] +# c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +# c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +# np.save('only_precip/c384_gmin.npy', c384_np.min()) +# np.save('only_precip/c48_gmin.npy', c48_np.min()) + +# # calculate split (80/20) +# split = int(c384_np.shape[1] * 0.8) + +# # compute statistics on training set +# c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + +# # normalize +# c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +# c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +# np.save('only_precip/c384_min.npy', c384_min) +# np.save('only_precip/c384_max.npy', c384_max) +# np.save('only_precip/c48_min.npy', c48_min) +# np.save('only_precip/c48_max.npy', c48_max) +# np.save('only_precip/c48_norm.npy', c48_norm) +# np.save('only_precip/c384_norm.npy', c384_norm) + +# c384_lnp = np.log(c384_np - c384_np.min() + 1e-14) +# c48_lnp = np.log(c48_np - c48_np.min() + 1e-14) + +# # compute statistics on training set +# c384_lmin, c384_lmax, c48_lmin, c48_lmax = c384_lnp[:, :split, :, :, :].min(), c384_lnp[:, :split, :, :, :].max(), c48_lnp[:, :split, :, :, :].min(), c48_lnp[:, :split, :, :, :].max() + +# # normalize +# c384_lnorm= (c384_lnp - c384_lmin) / (c384_lmax - c384_lmin) +# c48_lnorm = (c48_lnp - c48_lmin) / (c48_lmax - c48_lmin) + +# np.save('only_precip/c384_lgmin.npy', c384_lmin) +# np.save('only_precip/c384_lgmax.npy', c384_lmax) +# 
np.save('only_precip/c48_lgmin.npy', c48_lmin) +# np.save('only_precip/c48_lgmax.npy', c48_lmax) +# np.save('only_precip/c48_lgnorm.npy', c48_lnorm) +# np.save('only_precip/c384_lgnorm.npy', c384_lnorm) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec384topotrainstats.py b/projects/super_res/data/ensemblec384topotrainstats.py new file mode 100644 index 0000000000..461b376e5f --- /dev/null +++ b/projects/super_res/data/ensemblec384topotrainstats.py @@ -0,0 +1,81 @@ +import pickle +import numpy as np +import xarray as xr +from tqdm import tqdm +from pathlib import Path + +precip_folder = Path('./ensemble_c384_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +ENSEMBLE = 10 + +channels = ["zsurf"] +chl = {} + +for channel in channels: + + chl[channel] = {} + chl[channel]['min'] = np.PINF + chl[channel]['max'] = np.NINF + +for member in tqdm(range(1, ENSEMBLE + 1)): + + topo = xr.open_zarr(f"/data/prakhars/ensemble/c384_topo/{member:04d}/atmos_static.zarr") + + for channel in tqdm(channels): + channel_384 = topo[channel] + channel_384_min = channel_384.min().values + channel_384_max = channel_384.max().values + if channel_384_min < chl[channel]['min']: + chl[channel]['min'] = channel_384_min + if channel_384_max > chl[channel]['max']: + chl[channel]['max'] = channel_384_max + +# save the chl dictionary as pickle +with open(precip_folder / 'topo.pkl', 'wb') as f: + pickle.dump(chl, f) + + + + + +# channels = ["PRATEsfc_coarse"] +# c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +# c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +# np.save('only_precip/c384_gmin.npy', c384_np.min()) +# np.save('only_precip/c48_gmin.npy', c48_np.min()) + +# # calculate split (80/20) +# split = int(c384_np.shape[1] * 0.8) + +# # compute statistics on training set +# c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + +# # normalize +# c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +# c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +# np.save('only_precip/c384_min.npy', c384_min) +# np.save('only_precip/c384_max.npy', c384_max) +# np.save('only_precip/c48_min.npy', c48_min) +# np.save('only_precip/c48_max.npy', c48_max) +# np.save('only_precip/c48_norm.npy', c48_norm) +# np.save('only_precip/c384_norm.npy', c384_norm) + +# c384_lnp = np.log(c384_np - c384_np.min() + 1e-14) +# c48_lnp = np.log(c48_np - c48_np.min() + 1e-14) + +# # compute statistics on training set +# c384_lmin, c384_lmax, c48_lmin, c48_lmax = c384_lnp[:, :split, :, :, :].min(), c384_lnp[:, :split, :, :, :].max(), c48_lnp[:, :split, :, :, :].min(), c48_lnp[:, :split, :, :, :].max() + +# # normalize +# c384_lnorm= (c384_lnp - c384_lmin) / (c384_lmax - c384_lmin) +# c48_lnorm = (c48_lnp - c48_lmin) / (c48_lmax - c48_lmin) + +# np.save('only_precip/c384_lgmin.npy', c384_lmin) +# np.save('only_precip/c384_lgmax.npy', c384_lmax) +# np.save('only_precip/c48_lgmin.npy', c48_lmin) +# np.save('only_precip/c48_lgmax.npy', c48_lmax) +# np.save('only_precip/c48_lgnorm.npy', c48_lnorm) +# np.save('only_precip/c384_lgnorm.npy', c384_lnorm) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec384trainstats.py b/projects/super_res/data/ensemblec384trainstats.py new file mode 100644 index 0000000000..95d06743fc --- /dev/null +++ b/projects/super_res/data/ensemblec384trainstats.py @@ -0,0 +1,87 @@ +import 
pickle +import numpy as np +import xarray as xr +from tqdm import tqdm +from pathlib import Path + +precip_folder = Path('./ensemble_c384_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +ENSEMBLE = 10 + +channels = ["PRATEsfc"] +chl = {} + +for channel in channels: + + chl[channel] = {} + chl[channel]['min'] = np.PINF + chl[channel]['max'] = np.NINF + +for member in tqdm(range(1, ENSEMBLE + 1)): + + c384 = xr.open_zarr(f"/data/prakhars/ensemble/c384_precip_ave/{member:04d}/sfc_8xdaily_ave.zarr") + + for channel in tqdm(channels): + + channel_384 = c384[channel] + + for idx in tqdm(range(397)): + + channel_384_slice = channel_384.isel(time = slice(idx*8, (idx+1)*8)) + channel_384_max = channel_384_slice.max().values + channel_384_min = channel_384_slice.min().values + + if channel_384_min < chl[channel]['min']: + + chl[channel]['min'] = channel_384_min + + if channel_384_max > chl[channel]['max']: + + chl[channel]['max'] = channel_384_max + +# save the chl dictionary as pickle +with open(precip_folder / 'chl.pkl', 'wb') as f: + pickle.dump(chl, f) + + +# channels = ["PRATEsfc_coarse"] +# c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +# c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +# np.save('only_precip/c384_gmin.npy', c384_np.min()) +# np.save('only_precip/c48_gmin.npy', c48_np.min()) + +# # calculate split (80/20) +# split = int(c384_np.shape[1] * 0.8) + +# # compute statistics on training set +# c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + +# # normalize +# c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +# c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +# np.save('only_precip/c384_min.npy', c384_min) +# np.save('only_precip/c384_max.npy', c384_max) +# np.save('only_precip/c48_min.npy', c48_min) +# np.save('only_precip/c48_max.npy', c48_max) +# np.save('only_precip/c48_norm.npy', c48_norm) +# np.save('only_precip/c384_norm.npy', c384_norm) + +# c384_lnp = np.log(c384_np - c384_np.min() + 1e-14) +# c48_lnp = np.log(c48_np - c48_np.min() + 1e-14) + +# # compute statistics on training set +# c384_lmin, c384_lmax, c48_lmin, c48_lmax = c384_lnp[:, :split, :, :, :].min(), c384_lnp[:, :split, :, :, :].max(), c48_lnp[:, :split, :, :, :].min(), c48_lnp[:, :split, :, :, :].max() + +# # normalize +# c384_lnorm= (c384_lnp - c384_lmin) / (c384_lmax - c384_lmin) +# c48_lnorm = (c48_lnp - c48_lmin) / (c48_lmax - c48_lmin) + +# np.save('only_precip/c384_lgmin.npy', c384_lmin) +# np.save('only_precip/c384_lgmax.npy', c384_lmax) +# np.save('only_precip/c48_lgmin.npy', c48_lmin) +# np.save('only_precip/c48_lgmax.npy', c48_lmax) +# np.save('only_precip/c48_lgnorm.npy', c48_lnorm) +# np.save('only_precip/c384_lgnorm.npy', c384_lnorm) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec48logtrainstats.py b/projects/super_res/data/ensemblec48logtrainstats.py new file mode 100644 index 0000000000..85fb04ad00 --- /dev/null +++ b/projects/super_res/data/ensemblec48logtrainstats.py @@ -0,0 +1,63 @@ +import pickle +import numpy as np +from pathlib import Path + +precip_folder = Path('./ensemble_c48_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +# load the data +with open('ensemble_c48_trainstats/chl.pkl', 'rb') as f: + chl = pickle.load(f) + +precip = chl['PRATEsfc_coarse'] +log_chl = {} +log_chl['PRATEsfc_coarse'] = {} 
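# [Note] The log-statistics scripts in this patch derive log-space extrema directly from the
# linear extrema: after shifting by the global minimum and adding a small epsilon, the
# minimum maps to log(1e-14) and the maximum to log(max - min + 1e-14). A hedged sketch of
# the matching forward/inverse transform applied to precipitation at load time; the function
# names are illustrative only.
import numpy as np

EPS = 1e-14

def to_log_space(precip: np.ndarray, global_min: float) -> np.ndarray:
    # Shift so the smallest training value maps to EPS, then take the log.
    return np.log(precip - global_min + EPS)

def from_log_space(log_precip: np.ndarray, global_min: float) -> np.ndarray:
    # Invert the transform back to physical precipitation rates.
    return np.exp(log_precip) + global_min - EPS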
+log_chl['PRATEsfc_coarse']['min'] = np.log(precip['min'] - precip['min'] + 1e-14) +log_chl['PRATEsfc_coarse']['max'] = np.log(precip['max'] - precip['min'] + 1e-14) + +# save the chl dictionary as pickle +with open(precip_folder / 'log_chl.pkl', 'wb') as f: + pickle.dump(log_chl, f) + + + +# channels = ["PRATEsfc_coarse"] +# c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +# c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +# np.save('only_precip/c384_gmin.npy', c384_np.min()) +# np.save('only_precip/c48_gmin.npy', c48_np.min()) + +# # calculate split (80/20) +# split = int(c384_np.shape[1] * 0.8) + +# # compute statistics on training set +# c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + +# # normalize +# c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +# c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +# np.save('only_precip/c384_min.npy', c384_min) +# np.save('only_precip/c384_max.npy', c384_max) +# np.save('only_precip/c48_min.npy', c48_min) +# np.save('only_precip/c48_max.npy', c48_max) +# np.save('only_precip/c48_norm.npy', c48_norm) +# np.save('only_precip/c384_norm.npy', c384_norm) + +# c384_lnp = np.log(c384_np - c384_np.min() + 1e-14) +# c48_lnp = np.log(c48_np - c48_np.min() + 1e-14) + +# # compute statistics on training set +# c384_lmin, c384_lmax, c48_lmin, c48_lmax = c384_lnp[:, :split, :, :, :].min(), c384_lnp[:, :split, :, :, :].max(), c48_lnp[:, :split, :, :, :].min(), c48_lnp[:, :split, :, :, :].max() + +# # normalize +# c384_lnorm= (c384_lnp - c384_lmin) / (c384_lmax - c384_lmin) +# c48_lnorm = (c48_lnp - c48_lmin) / (c48_lmax - c48_lmin) + +# np.save('only_precip/c384_lgmin.npy', c384_lmin) +# np.save('only_precip/c384_lgmax.npy', c384_lmax) +# np.save('only_precip/c48_lgmin.npy', c48_lmin) +# np.save('only_precip/c48_lgmax.npy', c48_lmax) +# np.save('only_precip/c48_lgnorm.npy', c48_lnorm) +# np.save('only_precip/c384_lgnorm.npy', c384_lnorm) \ No newline at end of file diff --git a/projects/super_res/data/ensemblec48trainstats.py b/projects/super_res/data/ensemblec48trainstats.py new file mode 100644 index 0000000000..3dbb5c430a --- /dev/null +++ b/projects/super_res/data/ensemblec48trainstats.py @@ -0,0 +1,102 @@ +import pickle +import numpy as np +import xarray as xr +from tqdm import tqdm +from pathlib import Path + +precip_folder = Path('./ensemble_c48_trainstats') +precip_folder.mkdir(exist_ok = True, parents = True) + +ENSEMBLE = 10 + +channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] +atm_channels = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + +chl, atm_chl = {}, {} + +for channel in channels: + + chl[channel] = {} + chl[channel]['min'] = np.PINF + chl[channel]['max'] = np.NINF + +for channel in atm_channels: + + atm_chl[channel] = {} + atm_chl[channel]['min'] = np.PINF + atm_chl[channel]['max'] = np.NINF + +for member in tqdm(range(1, ENSEMBLE + 1)): + + c48 = xr.open_zarr(f"/data/prakhars/ensemble/c48_precip_plus_more_ave/{member:04d}/sfc_8xdaily_ave_coarse.zarr") + c48_atm = xr.open_zarr(f"/data/prakhars/ensemble/c48_atmos_ave/{member:04d}/atmos_8xdaily_ave_coarse.zarr") + + for channel in tqdm(channels): + channel_48 = c48[channel] + channel_48_min = channel_48.min().values + channel_48_max = 
channel_48.max().values + if channel_48_min < chl[channel]['min']: + chl[channel]['min'] = channel_48_min + if channel_48_max > chl[channel]['max']: + chl[channel]['max'] = channel_48_max + + for channel in tqdm(atm_channels): + channel_48 = c48_atm[channel] + channel_48_min = channel_48.min().values + channel_48_max = channel_48.max().values + if channel_48_min < atm_chl[channel]['min']: + atm_chl[channel]['min'] = channel_48_min + if channel_48_max > atm_chl[channel]['max']: + atm_chl[channel]['max'] = channel_48_max + +# save the chl dictionary as pickle +with open(precip_folder / 'chl.pkl', 'wb') as f: + pickle.dump(chl, f) + +# save the atm_chl dictionary as pickle +with open(precip_folder / 'atm_chl.pkl', 'wb') as f: + pickle.dump(atm_chl, f) + + + + +# channels = ["PRATEsfc_coarse"] +# c384_np = np.stack([c384[channel].values for channel in channels], axis = 2) +# c48_np = np.stack([c48[channel].values for channel in channels], axis = 2) + +# np.save('only_precip/c384_gmin.npy', c384_np.min()) +# np.save('only_precip/c48_gmin.npy', c48_np.min()) + +# # calculate split (80/20) +# split = int(c384_np.shape[1] * 0.8) + +# # compute statistics on training set +# c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max() + +# # normalize +# c384_norm= (c384_np - c384_min) / (c384_max - c384_min) +# c48_norm = (c48_np - c48_min) / (c48_max - c48_min) + +# np.save('only_precip/c384_min.npy', c384_min) +# np.save('only_precip/c384_max.npy', c384_max) +# np.save('only_precip/c48_min.npy', c48_min) +# np.save('only_precip/c48_max.npy', c48_max) +# np.save('only_precip/c48_norm.npy', c48_norm) +# np.save('only_precip/c384_norm.npy', c384_norm) + +# c384_lnp = np.log(c384_np - c384_np.min() + 1e-14) +# c48_lnp = np.log(c48_np - c48_np.min() + 1e-14) + +# # compute statistics on training set +# c384_lmin, c384_lmax, c48_lmin, c48_lmax = c384_lnp[:, :split, :, :, :].min(), c384_lnp[:, :split, :, :, :].max(), c48_lnp[:, :split, :, :, :].min(), c48_lnp[:, :split, :, :, :].max() + +# # normalize +# c384_lnorm= (c384_lnp - c384_lmin) / (c384_lmax - c384_lmin) +# c48_lnorm = (c48_lnp - c48_lmin) / (c48_lmax - c48_lmin) + +# np.save('only_precip/c384_lgmin.npy', c384_lmin) +# np.save('only_precip/c384_lgmax.npy', c384_lmax) +# np.save('only_precip/c48_lgmin.npy', c48_lmin) +# np.save('only_precip/c48_lgmax.npy', c48_lmax) +# np.save('only_precip/c48_lgnorm.npy', c48_lnorm) +# np.save('only_precip/c384_lgnorm.npy', c384_lnorm) \ No newline at end of file diff --git a/projects/super_res/data/load_data.py b/projects/super_res/data/load_data.py index dc6c6bc21f..cee9edff21 100644 --- a/projects/super_res/data/load_data.py +++ b/projects/super_res/data/load_data.py @@ -19,7 +19,7 @@ def load_data(data_config, batch_size, num_workers = 4, pin_memory = True): val = DataLoader( val, - batch_size = batch_size, + batch_size = 5, shuffle = False, num_workers = num_workers, pin_memory = pin_memory, diff --git a/projects/super_res/data/load_dataset.py b/projects/super_res/data/load_dataset.py index e3246963f9..c4c74cb10f 100644 --- a/projects/super_res/data/load_dataset.py +++ b/projects/super_res/data/load_dataset.py @@ -1,4 +1,6 @@ -from .vsrdata import VSRDataset +# from .vsrdata import VSRDataset +# from .vsrdata_new import VSRDataset +from .vsrdata_ensemble import VSRDataset def load_dataset(data_config): diff --git a/projects/super_res/data/vsrdata_ensemble.py 
b/projects/super_res/data/vsrdata_ensemble.py new file mode 100644 index 0000000000..b596ef86d0 --- /dev/null +++ b/projects/super_res/data/vsrdata_ensemble.py @@ -0,0 +1,142 @@ +import pickle +import numpy as np +import xarray as xr +from torch.utils.data import Dataset + +class VSRDataset(Dataset): + + def __init__(self, mode, length, logscale = False, multi = False): + ''' + Args: + channels (list): list of channels to use + mode (str): train or val + length (int): length of sequence + logscale (bool): whether to logscale the data + multi (bool): whether to use multi-channel data + ''' + + ENSEMBLE = 11 + + # load data + self.X, self.X_, self.y, self.topo = {}, {}, {}, {} + + for member in range(1, ENSEMBLE + 1): + + self.X[member] = xr.open_zarr(f"/data/prakhars/ensemble/c48_precip_plus_more_ave/{member:04d}/sfc_8xdaily_ave_coarse.zarr") + self.X_[member] = xr.open_zarr(f"/data/prakhars/ensemble/c48_atmos_ave/{member:04d}/atmos_8xdaily_ave_coarse.zarr") + self.y[member] = xr.open_zarr(f"/data/prakhars/ensemble/c384_precip_ave/{member:04d}/sfc_8xdaily_ave.zarr") + self.topo[member] = xr.open_zarr(f"/data/prakhars/ensemble/c384_topo/{member:04d}/atmos_static.zarr") + + # expected sequence length + self.length = length + + self.mode = mode + self.logscale = logscale + self.multi = multi + + self.time_steps = self.X[1].time.shape[0] + self.tiles = self.X[1].tile.shape[0] + + # load statistics + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + self.c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + self.c48_atm_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + self.c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + self.c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + self.c384_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + self.c384_topo = pickle.load(f) + + if multi: + + self.c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + self.c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + self.c384_channels = ["PRATEsfc"] + + else: + + self.c48_channels = ["PRATEsfc_coarse"] + self.c384_channels = ["PRATEsfc"] + + self.indices = list(range(self.time_steps - self.length + 1)) + + def __len__(self): + + return len(self.indices) + + def __getitem__(self, idx): + + time_idx = self.indices[idx] + + if self.mode == 'train': + + tile = idx % self.tiles + member = idx % 10 + 1 + + else: + + tile = idx % self.tiles + member = 11 + + X = self.X[member].isel(time = slice(time_idx, time_idx + self.length), tile = tile) + X_ = self.X_[member].isel(time = slice(time_idx, time_idx + self.length), tile = tile) + y = self.y[member].isel(time = slice(time_idx, time_idx + self.length), tile = tile) + + if self.multi: + + X = np.stack([X[channel].values for channel in self.c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in self.c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in self.c384_channels], axis = 1) + topo = self.topo[member].isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), self.length, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in self.c48_channels], axis = 
1) + y = np.stack([y[channel].values for channel in self.c384_channels], axis = 1) + + + if self.logscale: + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - self.c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - self.c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - self.c48_log_chl["PRATEsfc_coarse"]['min']) / (self.c48_log_chl["PRATEsfc_coarse"]['max'] - self.c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - self.c384_log_chl["PRATEsfc"]['min']) / (self.c384_log_chl["PRATEsfc"]['max'] - self.c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - self.c48_chl["PRATEsfc_coarse"]['min']) / (self.c48_chl["PRATEsfc_coarse"]['max'] - self.c48_chl["PRATEsfc_coarse"]['min']) + y = (y - self.c384_chl["PRATEsfc"]['min']) / (self.c384_chl["PRATEsfc"]['max'] - self.c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - self.c48_chl[self.c48_channels[i]]['min']) / (self.c48_chl[self.c48_channels[i]]['max'] - self.c48_chl[self.c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - self.c48_atm_chl[self.c48_channels_atmos[i]]['min']) / (self.c48_atm_chl[self.c48_channels_atmos[i]]['max'] - self.c48_atm_chl[self.c48_channels_atmos[i]]['min']) + + topo = (topo - self.c384_topo["zsurf"]['min']) / (self.c384_topo["zsurf"]['max'] - self.c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + return {'LR' : X, 'HR' : y} \ No newline at end of file diff --git a/projects/super_res/data/vsrdata_new.py b/projects/super_res/data/vsrdata_new.py new file mode 100644 index 0000000000..869b87c39e --- /dev/null +++ b/projects/super_res/data/vsrdata_new.py @@ -0,0 +1,107 @@ +import numpy as np +import xarray as xr +from torch.utils.data import Dataset + +class VSRDataset(Dataset): + + def __init__(self, mode, length, logscale = False, multi = False): + ''' + Args: + channels (list): list of channels to use + mode (str): train or val + length (int): length of sequence + logscale (bool): whether to logscale the data + quick (bool): whether to load data from bucket or from local (local only supports single precipitation channel) + ''' + + # load data + self.y = xr.open_zarr("/data/prakhars/pire_atmos_phys_3h_c384.zarr") + self.X = xr.open_zarr('/data/prakhars/pire_atmos_phys_3h_c48.zarr') + + # expected sequence length + self.length = length + + # mode + self.mode = mode + + self.logscale = logscale + + if logscale: + + self.c384_gmin = np.load('data/only_precip/c384_gmin.npy') + self.c48_gmin = np.load('data/only_precip/c48_gmin.npy') + self.c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + self.c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + self.c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + self.c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + + else: + + self.c384_min = np.load('data/only_precip/c384_min.npy') + self.c384_max = np.load('data/only_precip/c384_max.npy') + self.c48_min = np.load('data/only_precip/c48_min.npy') + self.c48_max = np.load('data/only_precip/c48_max.npy') + + self.time_steps = self.X.time.shape[0] + self.tiles = self.X.tile.shape[0] + + self.multi = multi + + if multi: + + self.channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "tsfc_coarse", "CPRATEsfc_coarse"] + self.topo384 = np.load("data/topography/topo384_norm.npy") + self.c384_multimin = np.load('data/more_channels/c384_min.npy') + self.c384_multimax = 
np.load('data/more_channels/c384_max.npy') + self.c48_multimin = np.load('data/more_channels/c48_min.npy') + self.c48_multimax = np.load('data/more_channels/c48_max.npy') + + else: + + self.channels = ["PRATEsfc_coarse"] + + if mode == 'train': + + self.indices = list(range(int(self.time_steps * 0.8) - self.length + 1)) + + elif mode == 'val': + + self.indices = list(range(int(self.time_steps * 0.8), self.time_steps - self.length + 1)) + + def __len__(self): + + return len(self.indices) + + def __getitem__(self, idx): + + time_idx = self.indices[idx] + if self.mode == 'train': + tile = idx % self.tiles + else: + tile = 0 + + lowres = self.X.isel(time = slice(time_idx, time_idx + self.length), tile = tile) + lowres = np.stack([lowres[channel].values for channel in self.channels], axis = 1) + highres = self.y.isel(time = slice(time_idx, time_idx + self.length), tile = tile) + highres = np.stack([highres[channel].values for channel in self.channels[0:1]], axis = 1) + + if self.logscale: + + lowres[:,0:1,:,:] = np.log(lowres[:,0:1,:,:] - self.c48_gmin + 1e-14) + highres = np.log(highres - self.c384_gmin + 1e-14) + lowres[:,0:1,:,:] = (lowres[:,0:1,:,:] - self.c48_lgmin) / (self.c48_lgmax - self.c48_lgmin) + highres = (highres - self.c384_lgmin) / (self.c384_lgmax - self.c384_lgmin) + + else: + + lowres[:,0:1,:,:] = (lowres[:,0:1,:,:] - self.c48_min) / (self.c48_max - self.c48_min) + highres = (highres - self.c384_min) / (self.c384_max - self.c384_min) + + if self.multi: + + lowres[:,1:,:,:] = (lowres[:,1:,:,:] - self.c48_multimin) / (self.c48_multimax - self.c48_multimin) + topo = self.topo384[tile,:,:] + topo = np.repeat(topo.reshape((1,1,384,384)), self.length, axis = 0) + highres = np.concatenate((highres, topo), axis = 1) + + return {'LR' : lowres, 'HR' : highres} \ No newline at end of file diff --git a/projects/super_res/model/autoreg_diffusion_mod.py b/projects/super_res/model/autoreg_diffusion_mod.py index 886f36dbe8..7004f873e3 100644 --- a/projects/super_res/model/autoreg_diffusion_mod.py +++ b/projects/super_res/model/autoreg_diffusion_mod.py @@ -15,9 +15,13 @@ from torchvision.transforms.functional import crop import piq +import pickle +import cv2 +from scipy.stats import wasserstein_distance from kornia import filters from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR from einops import rearrange, reduce from einops.layers.torch import Rearrange @@ -928,7 +932,7 @@ def sample(self, lres, hres, multi, flow_mode, return_all_timesteps = False): b, f, c, h, w = lres.shape if multi: - + topo = hres[:, :, 1:2, :, :] low_chans = lres[:, :, 1:, :, :] topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) @@ -961,14 +965,6 @@ def sample(self, lres, hres, multi, flow_mode, return_all_timesteps = False): r = torch.roll(l, -1, 1) ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') - - if multi: - - l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) - - else: - - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) m = lres.clone() m1 = rearrange(m, 'b t c h w -> (b t) c h w') @@ -995,6 +991,16 @@ def sample(self, lres, hres, multi, flow_mode, return_all_timesteps = False): flow = self.unnormalize(flow) warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) + + if multi: + + # l_cond = 
torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + + else: + + # l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped res = sample_fn((b * (f - 2), 1, 8 * h, 8 * w), l_cond, context, return_all_timesteps = return_all_timesteps) sres = warped + res @@ -1071,11 +1077,13 @@ def p_losses(self, stack, hres, lres, ures, t, multi, flow_mode, topo = None, no if multi: - l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + # l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) else: - l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + # l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped del f @@ -1217,15 +1225,15 @@ def __init__( mixed_precision = 'fp16' if fp16 else 'no', log_with = 'wandb', ) - # self.accelerator.init_trackers("vsr-orig-autoreg-hres", - # init_kwargs={ - # "wandb": { - # "notes": "Use VSR to improve precipitation forecasting.", - # # Change "name" to set the name of the run. - # "name": None, - # } - # }, - # ) + self.accelerator.init_trackers("vsr-orig-autoreg-hres", + init_kwargs={ + "wandb": { + "notes": "Use VSR to improve precipitation forecasting.", + # Change "name" to set the name of the run. 
+ "name": None, + } + }, + ) self.config = config self.accelerator.native_amp = amp self.multi = config.data_config["multi"] @@ -1250,6 +1258,7 @@ def __init__( # optimizer self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 1e-7) # for logging results in a folder periodically @@ -1270,7 +1279,7 @@ def __init__( # prepare model, dataloader, optimizer with accelerator - self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) + self.model, self.opt, self.sched, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, self.sched, train_dl, val_dl) self.train_dl = cycle(train_dl) self.val_dl = val_dl @@ -1316,19 +1325,44 @@ def train(self): cmap = mpl.colormaps['RdBu_r'] fcmap = mpl.colormaps['gray_r'] - c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') - c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') - c384_gmin = np.load('data/only_precip/c384_gmin.npy') + # c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + # c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + # c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + # c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + # c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + # c48_gmin = np.load('data/only_precip/c48_gmin.npy') - c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') - c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') - c48_gmin = np.load('data/only_precip/c48_gmin.npy') + # c384_min = np.load('data/only_precip/c384_min.npy') + # c384_max = np.load('data/only_precip/c384_max.npy') - c384_min = np.load('data/only_precip/c384_min.npy') - c384_max = np.load('data/only_precip/c384_max.npy') + # c48_min = np.load('data/only_precip/c48_min.npy') + # c48_max = np.load('data/only_precip/c48_max.npy') - c48_min = np.load('data/only_precip/c48_min.npy') - c48_max = np.load('data/only_precip/c48_max.npy') + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) + + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: @@ -1366,6 +1400,7 @@ def train(self): self.opt.step() self.opt.zero_grad() + self.sched.step() accelerator.wait_for_everyone() @@ -1378,6 +1413,15 @@ def train(self): self.ema.ema_model.eval() with torch.no_grad(): + + psnrs = [] + vlosses = [] + vids = [] + hr = [] + lr = [] + bases, ress, flowss = [], [], [] + num_frames = 5 + img_size = 384 for i, batch in enumerate(self.val_dl): @@ -1387,134 +1431,160 @@ def train(self): if i >= self.val_num_of_batch: break - num_samples = 5 - num_videos_per_batch = 1 - num_frames = 5 - img_size = 384 - img_channels = 1 + # num_samples = 5 + # num_videos_per_batch = 1 + # num_frames = 5 + # img_size = 384 
+ # img_channels = 1 - truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') - truth[0,:,:,:,:,:] = (hres[:,2:,0:1,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) + # truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # truth[0,:,:,:,:,:] = (hres[:,2:,0:1,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) - for k in range(num_samples): - videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) - pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) - - lres = lres[:, :, 0:1, :, :] - hres = hres[:, :, 0:1, :, :] - - crps_index = calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size) + # for k in range(num_samples): + # videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + # pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + loss = self.model(lres, hres, self.multi, self.flow) + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + bases.append(base) + ress.append(res) + flowss.append(flows) + + # crps_index = calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size) psnr_index = piq.psnr(hres[:,2:,0:1,:,:], videos.clamp(0.0, 1.0)[:,:,0:1,:,:], data_range=1., reduction='none') + psnrs.append(psnr_index) + + videos = torch.cat(vids, dim = 0) + vloss = torch.stack(vlosses, dim = 0).mean() + psnr_index = torch.cat(psnrs, dim = 0).mean() + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + base = torch.cat(bases, dim = 0) + res = torch.cat(ress, dim = 0) + flows = torch.cat(flowss, dim = 0) + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) 
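# [Note] The validation metrics introduced just below compare predicted and target
# precipitation distributions with OpenCV histogram distances and the 1-D Wasserstein
# (earth mover's) distance. A small, self-contained sketch of those calls on synthetic
# data; the array names are placeholders, not variables from the trainer.
import numpy as np
import cv2
from scipy.stats import wasserstein_distance

pred_flat = np.random.rand(10000).astype('float32')    # stands in for flattened predictions
truth_flat = np.random.rand(10000).astype('float32')   # stands in for flattened targets

edges = np.linspace(min(pred_flat.min(), truth_flat.min()),
                    max(pred_flat.max(), truth_flat.max()), 101)
hist_pred = np.histogram(pred_flat, bins=edges, density=True)[0].astype('float32')
hist_truth = np.histogram(truth_flat, bins=edges, density=True)[0].astype('float32')

chisqr = cv2.compareHist(hist_pred, hist_truth, cv2.HISTCMP_CHISQR)    # 0 for identical histograms
inter = cv2.compareHist(hist_pred, hist_truth, cv2.HISTCMP_INTERSECT)  # larger means more overlap
kl = cv2.compareHist(hist_pred, hist_truth, cv2.HISTCMP_KL_DIV)        # 0 for identical histograms
emd = wasserstein_distance(pred_flat, truth_flat)                      # computed on raw samples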
+ ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + pscore = np.abs(np.percentile(output1, 99.99) - np.percentile(target1, 99.99)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + ) + ax1.hist( + target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) - videos_time_mean = videos.mean(dim = 1) - hres_time_mean = hres[:,2:,:,:,:].mean(dim = 1) - bias = videos_time_mean - hres_time_mean - norm = mpl.colors.Normalize(vmin = bias.min(), vmax = bias.max()) - sm = smap(norm, cmap) - b_c = [] - for l in range(num_videos_per_batch): - b_c.append(sm.to_rgba(bias[l,0,:,:].cpu().numpy())) - bias_color = np.stack(b_c, axis = 0) - - if not self.logscale: - target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min - output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min - coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min - - else: - target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin - output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin - coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin - - if self.logscale: - target = np.exp(target) + c384_gmin - 1e-14 - output = np.exp(output) + c384_gmin - 1e-14 - coarse = np.exp(coarse) + c48_gmin - 1e-14 - - nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) - diff_output = (output - nn_upscale).flatten() - diff_target = (target - nn_upscale).flatten() - vmin = min(diff_output.min(), diff_target.min()) - vmax = max(diff_output.max(), diff_target.max()) - bins = np.linspace(vmin, vmax, 100 + 1) - - fig, ax = plt.subplots(1, 1, figsize=(6, 4)) - ax.hist( - diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True - ) - ax.hist( - diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True - ) - ax.set_xlim(vmin, vmax) - ax.legend() - ax.set_ylabel("Density") - ax.set_yscale("log") - - output1 = output.flatten() - target1 = target.flatten() - vmin1 = min(output1.min(), target1.min()) - vmax1 = max(output1.max(), target1.max()) - bins1 = np.linspace(vmin1, vmax1, 100 + 1) - - fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) - ax1.hist( - output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True - ) - ax1.hist( - target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True - ) - ax1.set_xlim(vmin1, vmax1) - ax1.legend() - ax1.set_ylabel("Density") - ax1.set_yscale("log") - - flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) - - for m in range(num_frames): + for m in range(num_frames): - flow_d[0,m,:,:,:] 
= np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) + flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) + if self.flow == '3d': + flow_s = np.zeros((1, num_frames, 3, img_size, img_size)) sm = smap(None, fcmap) for m in range(num_frames): flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) - - if self.logscale: - - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + + + + if self.logscale: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + if self.flow == '3d': accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) - - else: - - accelerator.log({"true_high": wandb.Video((hres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_low": wandb.Video((lres[:,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) - target = np.log(target - c384_gmin + 1e-14) - output = np.log(output - c384_gmin + 1e-14) - coarse = np.log(coarse - c48_gmin + 1e-14) - target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) - output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) - coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) - accelerator.log({"true_loghigh": wandb.Video((np.repeat(target, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) - accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) - accelerator.log({"logsamples": 
wandb.Video((np.repeat(output, 3, axis=-3)*255).astype(np.uint8))}, step=self.step) - - - accelerator.log({"pattern_bias": wandb.Image((bias_color*255).astype(np.uint8), mode = 'RGBA')}, step=self.step) - accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) - accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) - accelerator.log({"psnr": psnr_index.mean()}, step=self.step) - accelerator.log({"crps": crps_index}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"psnr": psnr_index.mean()}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, step=self.step) + accelerator.log({"distemd": distemd}, step=self.step) + accelerator.log({"vloss": vloss}, step=self.step) milestone = self.step // self.save_and_sample_every diff --git a/projects/super_res/trainer.py b/projects/super_res/trainer.py index 29ad8358b1..6a02be3e35 100755 --- a/projects/super_res/trainer.py +++ b/projects/super_res/trainer.py @@ -8,9 +8,12 @@ def main(): if config.data_config["multi"]: - in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise - in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) - in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + # in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + # in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + # in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + 
multi + topo + in_ch_model = 2 * config.data_config["img_channel"] + 10 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + in_ch_flow = 3 * (config.data_config["img_channel"] + 10 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + in_ch_isr = config.data_config["img_channel"] + 10 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo else: From b0aef29f12fc210df346e083da41c5689cb1c8a2 Mon Sep 17 00:00:00 2001 From: cynicalsavant Date: Fri, 15 Sep 2023 14:20:46 -0700 Subject: [PATCH 9/9] Latest versions fixed for sampling, rvrt sampling still buggy, rvrt-isr-focal added additionally --- projects/super_res/config.py | 4 +- projects/super_res/config_focal.py | 44 + projects/super_res/config_infer.py | 16 +- projects/super_res/config_isr.py | 37 + projects/super_res/config_isr_infer.py | 37 + projects/super_res/config_rvrt_full.py | 50 + projects/super_res/config_rvrt_full_infer.py | 50 + .../data/ensemble_c384_trainstats/chl.pkl | Bin 0 -> 209 bytes .../data/ensemble_c384_trainstats/log_chl.pkl | Bin 0 -> 165 bytes .../data/ensemble_c384_trainstats/topo.pkl | Bin 0 -> 230 bytes .../data/ensemble_c48_trainstats/atm_chl.pkl | Bin 0 -> 679 bytes .../data/ensemble_c48_trainstats/chl.pkl | Bin 0 -> 752 bytes .../data/ensemble_c48_trainstats/log_chl.pkl | Bin 0 -> 188 bytes .../data/ensemblec384logtrainstats.py | 45 +- .../data/ensemblec384topotrainstats.py | 49 +- .../super_res/data/ensemblec384trainstats.py | 46 +- .../data/ensemblec48logtrainstats.py | 45 +- .../super_res/data/ensemblec48trainstats.py | 50 +- projects/super_res/data/load_data.py | 4 +- projects/super_res/data/vsrdata_ensemble.py | 16 +- .../super_res/model/autoreg_diffusion_mod.py | 202 +- .../model/autoreg_diffusion_mod_focal.py | 1799 +++++++++++++++++ .../model/denoising_diffusion_rvrt_full.py | 1611 +++++++++++++++ projects/super_res/model/isr_baseline.py | 568 ++++++ projects/super_res/model/op/deform_attn.py | 191 ++ .../model/op/deform_attn_cuda_kernel.cu | 867 ++++++++ .../model/op/deform_attn_cuda_pt109.cpp | 219 ++ .../model/op/deform_attn_cuda_pt110.cpp | 219 ++ .../super_res/model/op/deform_attn_ext.cpp | 75 + projects/super_res/sampler.py | 9 +- projects/super_res/sampler_isr.py | 38 + projects/super_res/sampler_rvrt_full.py | 103 + projects/super_res/trainer_focal.py | 90 + projects/super_res/trainer_isr.py | 45 + projects/super_res/trainer_rvrt_full.py | 109 + 35 files changed, 6352 insertions(+), 286 deletions(-) create mode 100644 projects/super_res/config_focal.py create mode 100644 projects/super_res/config_isr.py create mode 100644 projects/super_res/config_isr_infer.py create mode 100644 projects/super_res/config_rvrt_full.py create mode 100644 projects/super_res/config_rvrt_full_infer.py create mode 100644 projects/super_res/data/ensemble_c384_trainstats/chl.pkl create mode 100644 projects/super_res/data/ensemble_c384_trainstats/log_chl.pkl create mode 100644 projects/super_res/data/ensemble_c384_trainstats/topo.pkl create mode 100644 projects/super_res/data/ensemble_c48_trainstats/atm_chl.pkl create mode 100644 projects/super_res/data/ensemble_c48_trainstats/chl.pkl create mode 100644 projects/super_res/data/ensemble_c48_trainstats/log_chl.pkl create mode 100644 projects/super_res/model/autoreg_diffusion_mod_focal.py create mode 100644 projects/super_res/model/denoising_diffusion_rvrt_full.py create mode 100644 projects/super_res/model/isr_baseline.py create mode 100644 
projects/super_res/model/op/deform_attn.py create mode 100644 projects/super_res/model/op/deform_attn_cuda_kernel.cu create mode 100644 projects/super_res/model/op/deform_attn_cuda_pt109.cpp create mode 100644 projects/super_res/model/op/deform_attn_cuda_pt110.cpp create mode 100644 projects/super_res/model/op/deform_attn_ext.cpp create mode 100644 projects/super_res/sampler_isr.py create mode 100644 projects/super_res/sampler_rvrt_full.py create mode 100755 projects/super_res/trainer_focal.py create mode 100644 projects/super_res/trainer_isr.py create mode 100644 projects/super_res/trainer_rvrt_full.py diff --git a/projects/super_res/config.py b/projects/super_res/config.py index f61eb5ea23..4d8aadc37c 100644 --- a/projects/super_res/config.py +++ b/projects/super_res/config.py @@ -19,7 +19,7 @@ config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "2d-nomulti-ls-ensemble" +config.additional_note = "2d-nomulti-nols-ensemble" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" @@ -34,7 +34,7 @@ "channels": ["PRATEsfc_coarse"], "img_channel": 1, "img_size": 384, - "logscale": True, + "logscale": False, "multi": False, "flow": "2d", "minipatch": False diff --git a/projects/super_res/config_focal.py b/projects/super_res/config_focal.py new file mode 100644 index 0000000000..fb4a988761 --- /dev/null +++ b/projects/super_res/config_focal.py @@ -0,0 +1,44 @@ +from ml_collections import config_dict + +config = config_dict.ConfigDict() + +config.dim = 64 +config.dim_mults = (1, 1, 2, 2, 3, 4) +config.learned_sinusoidal_cond = True, +config.random_fourier_features = True, +config.learned_sinusoidal_dim = 32 +config.diffusion_steps = 1500 +config.sampling_steps = 20 +config.loss = "focal" +config.objective = "pred_v" +config.lr = 1e-4 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.995 +config.amp = False +config.split_batches = True +config.additional_note = "2d-multi-ls-focal-ensemble" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 1 +config.rollout = None +config.rollout_batch = None + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 7, + "channels": ["PRATEsfc_coarse"], + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "flow": "2d", + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_infer.py b/projects/super_res/config_infer.py index ce6305a97e..2520b3a1ca 100644 --- a/projects/super_res/config_infer.py +++ b/projects/super_res/config_infer.py @@ -8,23 +8,23 @@ config.random_fourier_features = True, config.learned_sinusoidal_dim = 32 config.diffusion_steps = 1500 -config.sampling_steps = 10 +config.sampling_steps = 20 config.loss = "l2" config.objective = "pred_v" -config.lr = 8e-5 -config.steps = 5000000 +config.lr = 1e-4 +config.steps = 700000 config.grad_acc = 1 config.val_num_of_batch = 5 -config.save_and_sample_every = 50 +config.save_and_sample_every 
= 20000 config.ema_decay = 0.995 config.amp = False config.split_batches = True -config.additional_note = "2d_multi_nols" +config.additional_note = "2d-nomulti-nols-ensemble" config.eval_folder = "./evaluate" config.results_folder = "./results" config.tensorboard_dir = "./tensorboard" -config.milestone = 1 -config.rollout = "full" +config.milestone = 2 +config.rollout = "partial" config.rollout_batch = 25 config.batch_size = 2 @@ -35,7 +35,7 @@ "img_channel": 1, "img_size": 384, "logscale": False, - "multi": True, + "multi": False, "flow": "2d", "minipatch": False }) diff --git a/projects/super_res/config_isr.py b/projects/super_res/config_isr.py new file mode 100644 index 0000000000..9100602f51 --- /dev/null +++ b/projects/super_res/config_isr.py @@ -0,0 +1,37 @@ +from ml_collections import config_dict + +config = config_dict.ConfigDict() + +config.lr = 1e-4 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.995 +config.amp = False +config.split_batches = True +config.additional_note = "isr" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 1 +config.rollout = None +config.rollout_batch = None + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 7, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "flow": "2d", + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_isr_infer.py b/projects/super_res/config_isr_infer.py new file mode 100644 index 0000000000..54285da702 --- /dev/null +++ b/projects/super_res/config_isr_infer.py @@ -0,0 +1,37 @@ +from ml_collections import config_dict + +config = config_dict.ConfigDict() + +config.lr = 1e-4 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.995 +config.amp = False +config.split_batches = True +config.additional_note = "isr" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 2 +config.rollout = 'partial' +config.rollout_batch = 25 + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 7, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "flow": "2d", + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_rvrt_full.py b/projects/super_res/config_rvrt_full.py new file mode 100644 index 0000000000..f68fff7776 --- /dev/null +++ b/projects/super_res/config_rvrt_full.py @@ -0,0 +1,50 @@ +from ml_collections import config_dict + +#batch_size = 4 +config = config_dict.ConfigDict() + +config.dim = 120 +config.num_blocks = 6 +config.num_heads = 8 +config.depth = 8 
+config.time_emb_dim = 32 +config.learned_sinusoidal_cond = True +config.diffusion_steps = 1500 +config.sampling_steps = 20 +# config.loss = "l2" +config.loss = "charbonnier" +config.objective = "pred_x0" +# config.lr = 8e-5 +config.lr = 1e-4 +# config.steps = 500000 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.999 +config.amp = False +config.split_batches = True +config.additional_note = "rvrt_full" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 1 +config.rollout = None +config.rollout_batch = None + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 6, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/config_rvrt_full_infer.py b/projects/super_res/config_rvrt_full_infer.py new file mode 100644 index 0000000000..12ab1a68f2 --- /dev/null +++ b/projects/super_res/config_rvrt_full_infer.py @@ -0,0 +1,50 @@ +from ml_collections import config_dict + +#batch_size = 4 +config = config_dict.ConfigDict() + +config.dim = 120 +config.num_blocks = 6 +config.num_heads = 8 +config.depth = 8 +config.time_emb_dim = 32 +config.learned_sinusoidal_cond = True +config.diffusion_steps = 1500 +config.sampling_steps = 20 +# config.loss = "l2" +config.loss = "charbonnier" +config.objective = "pred_x0" +# config.lr = 8e-5 +config.lr = 1e-4 +# config.steps = 500000 +config.steps = 700000 +config.grad_acc = 1 +config.val_num_of_batch = 5 +config.save_and_sample_every = 20000 +config.ema_decay = 0.999 +config.amp = False +config.split_batches = True +config.additional_note = "rvrt_full" +config.eval_folder = "./evaluate" +config.results_folder = "./results" +config.tensorboard_dir = "./tensorboard" +config.milestone = 2 +config.rollout = 'partial' +config.rollout_batch = 22 + +config.batch_size = 1 +config.data_config = config_dict.ConfigDict({ + "dataset_name": "c384", + "length": 6, + #"channels": ["UGRD10m_coarse","VGRD10m_coarse"], + "channels": ["PRATEsfc_coarse"], + #"img_channel": 2, + "img_channel": 1, + "img_size": 384, + "logscale": True, + "multi": True, + "minipatch": False +}) + +config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.additional_note}" +config.model_name = f"c384-{config.data_config['channels']}-{config.additional_note}" \ No newline at end of file diff --git a/projects/super_res/data/ensemble_c384_trainstats/chl.pkl b/projects/super_res/data/ensemble_c384_trainstats/chl.pkl new file mode 100644 index 0000000000000000000000000000000000000000..fa1744dbe2e5c7d06f4765cd1bd2d0d5620fa272 GIT binary patch literal 209 zcmZo*nR<)?0&1u9a0CQ7hPW1|B~PiHqS3>go0&JIM>MZAx1drlIlm}XFSj(OBr~z7 zD6w)%4^MniYI1&FaY<2Wa>^-b0C6xuKAoWaXCQ~{)AX3adj5bsJ{QSKB0|A)uW+<7G)alFua@!Oj6X=@6iYXav Z8Jr+PGq~YK3IP@Mv~t^m6qOe10RWkYP-FlA literal 0 HcmV?d00001 diff --git a/projects/super_res/data/ensemble_c384_trainstats/log_chl.pkl b/projects/super_res/data/ensemble_c384_trainstats/log_chl.pkl new file mode 100644 index 
0000000000000000000000000000000000000000..de65b8ebb40e1f4e23bdacee7332795c5e9d64d1 GIT binary patch literal 165 zcmZo*nL3LB0&1u9a0CQ7hPW1|B~PiHqS3>go0&JIM>MZAx1drlIlm}XFSj(OBr~z7 zD6w)%4_k3^VoqYwl*vj41qs!19E literal 0 HcmV?d00001 diff --git a/projects/super_res/data/ensemble_c48_trainstats/atm_chl.pkl b/projects/super_res/data/ensemble_c48_trainstats/atm_chl.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e570cc31697c057760d14cd19495d6a41df32513 GIT binary patch literal 679 zcmZ{f%~I1)7=}|)P=bhv|DYg(1i@Mntv{gHq8;r_hdHY(oXI3jB7*-S|N=&C?#fghTh5_AB z`7H25s7{y$QMcrDVz(5eZSt)+wz`@byG|!eC{7*HQ0!4!*y_Dvu%UdyP_1C&*ba^} z4C{@;;vv+cZscmvi~;U=y^4{&-5s1@7}Xo9QQ)&RoNP9mFaPm#a5WUdn7vmkMsHu# zaf&c*_x@+S$7!$V9mpm;ZaTwoR+zj|PXAi^QO7yLc{?>mcx_rKm-|az$Z%2j)PDGq z2v0Ih$?(@dTl=_7xMF)!F)gSWhO3gAtT{U@B+~SHkzQwj1lkZsVcGvVna^J)jxh=Ep!+d8^f4y>s+d21& zXuZR*DqH`~uYAE>!o8fkBB=We436HpY_DFXUd1=r$uVq-znIDVJCkz|1 fGx#~Qj;Dlb&Ri4JCc`sH{r>)~isyv7ou;M#@08M} literal 0 HcmV?d00001 diff --git a/projects/super_res/data/ensemble_c48_trainstats/chl.pkl b/projects/super_res/data/ensemble_c48_trainstats/chl.pkl new file mode 100644 index 0000000000000000000000000000000000000000..bb1926631ab59f08b09e27ae857a88ba7bc421b7 GIT binary patch literal 752 zcmY+C-A)rh7>2jqs<_&MiU|0Re~Ul`!5^TAbPFgewGJqI1le6mmc+Ka+iw$UvWX`K z>rHS2#>89Tz&qhObf!O(aSrBSzUTd(`7-lWkD-}Pucm1wwzRoi-}86;nj6;3NW@5* zj?OEU{VH_78dVM&S-%#Rvz6#G?Ym*!td;O^ z?g%{$=S@2?+JSbdaZom3I7#lLuqF zrRwFD=9FY678NB{aup{h<|Gzz6|z9&xl&3h3sSiXStFQ$8q+Ka**ZJg3)zDTIW)YP zBN%~VHievietus6fdEW+Gn5o^B{_2xud-{gmUB2z$lX@R6I93xu{p7#kS~KJgTtBQ Y-SY|ao39HWDCBP|6bLF5EG^ap04= diff for time in times): + times.append(new_time) + return times + +def calculate_crps(truth, pred, num_samples, num_videos_per_batch, num_frames, img_channels, img_size): + truth_cdf = np.zeros((256, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + for i in range(256): + truth_cdf[i, :, :, :, :, :, :] = (truth <= i).astype('uint8') + pred_cdf = np.zeros((256, num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + for j in range(256): + pred_cdf[j, :, :, :, :, :, :, :] = (pred <= j).astype('uint8') + red_pred_cdf = pred_cdf.mean(1) + temp = np.square(red_pred_cdf - truth_cdf) + temp_dz = temp.sum(0) + temp_dz_dd = temp_dz.mean(axis = (3, 4, 5)) + temp_dz_dd_dt = temp_dz_dd.mean(2) + return temp_dz_dd_dt.mean() + +def save_image(tensor, path): + im = Image.fromarray((tensor[:,:,:3] * 255).astype(np.uint8)) + im.save(path) + return None + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def identity(t, *args, **kwargs): + return t + +def cycle(dl): + while True: + for data in dl: + yield data + +def has_int_squareroot(num): + return (math.sqrt(num) ** 2) == num + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + +def convert_image_to_fn(img_type, image): + if image.mode != img_type: + return image.convert(img_type) + return image + +# normalization functions + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + +# flow modules + +def gaussian_pyramids(input, base_sigma = 1, m = 5): + + output = [input] + N, C, H, W = input.shape + + kernel = filters.get_gaussian_kernel2d((5, 5), (base_sigma, base_sigma))#.unsqueeze(0) + + for i in range(m): + + 
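# note: each pass blurs the running input with a fixed 5x5 Gaussian and then halves its
+ # resolution, so successive levels carry progressively coarser detail; every level is
+ # upsampled back to the input size before being appended, and the levels are stacked
+ # along a new third axis, giving a scale-space volume of shape (N, C, m + 1, H, W) that
+ # scale_space_warp samples along its blur/scale dimension.
+ 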
input = filters.filter2d(input, kernel) + + if i == 0: + + output.append(input) + + else: + + tmp = input + + for j in range(i): + + tmp = F.interpolate(tmp, scale_factor = 2., mode = 'bilinear', align_corners = True) + + output.append(tmp) + + input = F.interpolate(input, scale_factor = 0.5) + + return torch.stack(output, 2) + +def scale_space_warp(input, flow): + + N, C, H, W = input.shape + + assert flow.shape == (N, 3, H, W) + + flow = flow.unsqueeze(0) + #multi_scale = gaussian_pyramids(input, self.base_scale, self.gaussian_dim) + multi_scale = gaussian_pyramids(input, 1.0, 5) + + h = torch.arange(H, device=input.device, dtype=input.dtype) + w = torch.arange(W, device=input.device, dtype=input.dtype) + d = torch.zeros(1, device=input.device, dtype=input.dtype) + + grid = torch.stack(torch.meshgrid(d, h, w)[::-1], -1).unsqueeze(0) + grid = grid.expand(N, -1, -1, -1, -1) + flow = flow.permute(1, 0, 3, 4, 2) # N, 1, H, W, 3 + + # reparameterization + # var_channel = (flow[..., -1].exp())**2 + # var_space = [0.] + [(2.**i * self.base_scale)**2 for i in range(self.gaussian_dim)] + # d_offset = var_to_position(var_channel, var_space).unsqueeze(-1) + d_offset = flow[..., -1].clamp(min=-1.0, max=1.0).unsqueeze(-1) + + flow = torch.cat((flow[..., :2], d_offset), -1) + flow_grid = flow + grid + flow_grid[..., 0] = 2.0 * flow_grid[..., 0] / max(W - 1.0, 1.0) - 1.0 + flow_grid[..., 1] = 2.0 * flow_grid[..., 1] / max(H - 1.0, 1.0) - 1.0 + + warped = F.grid_sample(multi_scale, flow_grid, padding_mode = "border", align_corners = True).squeeze(2) + + return warped + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='border', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + + Returns: + Tensor: Warped image or feature map. 
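+ 
+ Example (shape sketch): for x of shape (n, c, h, w) and flow of shape (n, h, w, 2)
+ interpreted as per-pixel offsets (added directly to the pixel grid before normalisation
+ to [-1, 1] for grid_sample), the output keeps shape (n, c, h, w), with each output
+ location sampled from x at (grid + flow).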
+ """ + n, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), + torch.arange(0, w, dtype=x.dtype, device=x.device)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + return output + +# small helper modules + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + +def Upsample(dim, dim_out = None): + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) + ) + +def Downsample(dim, dim_out = None): + return nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), + nn.Conv2d(dim * 4, default(dim_out, dim), 1) + ) + +class WeightStandardizedConv2d(nn.Conv2d): + """ + https://arxiv.org/abs/1903.10520 + weight standardization purportedly works synergistically with group normalization + """ + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + + weight = self.weight + mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') + var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False)) + normalized_weight = (weight - mean) * (var + eps).rsqrt() + + return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * (var + eps).rsqrt() * self.g + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + +# sinusoidal positional embeds + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class RandomOrLearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim, is_random = False): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + +# building block modules + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups = 8): 
+ super().__init__() + self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + +class LinearAttention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + + self.to_out = nn.Sequential( + nn.Conv2d(hidden_dim, dim, 1), + LayerNorm(dim) + ) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q.softmax(dim = -2) + k = k.softmax(dim = -1) + + q = q * self.scale + v = v / (h * w) + + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + return self.to_out(out) + +class Attention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q * self.scale + + sim = einsum('b h d i, b h d j -> b h i j', q, k) + attn = sim.softmax(dim = -1) + out = einsum('b h i j, b h d j -> b h i d', attn, v) + + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + return self.to_out(out) + +# model + +class Unet(nn.Module): + def __init__( + self, + dim, + init_dim = None, + out_dim = None, + dim_mults=(1, 2, 4, 8), + channels = 3, + self_condition = False, + resnet_block_groups = 8, + learned_variance = False, + learned_sinusoidal_cond = False, + random_fourier_features = False, + learned_sinusoidal_dim = 16 + ): + super().__init__() + + # determine dimensions + + self.channels = channels + self.self_condition = self_condition + input_channels = channels * (2 if self_condition else 1) + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # time embeddings + + 
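# note: the diffusion timestep is embedded with sinusoidal (optionally random or learned
+ # Fourier) features and projected by a small MLP to time_dim = 4 * dim; each ResnetBlock
+ # below maps this embedding to a per-channel scale and shift applied in its first sub-block.
+ 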
time_dim = dim * 4 + + self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features + + if self.random_or_learned_sinusoidal_cond: + sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) + fourier_dim = learned_sinusoidal_dim + 1 + else: + sinu_pos_emb = SinusoidalPosEmb(dim) + fourier_dim = dim + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), + block_klass(2*dim_in, dim_in, time_emb_dim = time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + default_out_dim = channels * (1 if not learned_variance else 2) + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv = nn.Conv2d(dim, self.out_dim, 1) + + def forward(self, x, time, context, x_self_cond = None): + if self.self_condition: + x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x_self_cond, x), dim = 1) + + x = self.init_conv(x) + r = x.clone() + + t = self.time_mlp(time) + + h = [] + + count = 0 + + for block1, block2, attn, downsample in self.downs: + x = torch.cat((x, context[count]), dim = 1) + count += 1 + x = block1(x, t) + h.append(x) + + x = torch.cat((x, context[count]), dim = 1) + count += 1 + x = block2(x, t) + x = attn(x) + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x, t) + + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x, t) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim = 1) + + x = self.final_res_block(x, t) + return self.final_conv(x) + +class Flow(nn.Module): + def __init__( + self, + dim, + init_dim = None, + out_dim = None, + dim_mults=(1, 2, 4, 8), + channels = 3, + resnet_block_groups = 8, + ): + super().__init__() + + # determine dimensions + + self.channels = channels + input_channels = channels + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, 
(dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(dim_in, dim_in), + block_klass(dim_in, dim_in), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out), + block_klass(dim_out + dim_in, dim_out), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + default_out_dim = channels + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim) + self.final_conv = nn.Conv2d(dim, self.out_dim, 1) + + def forward(self, x): + + x = self.init_conv(x) + r = x.clone() + + h = [] + context = [] + for block1, block2, attn, downsample in self.downs: + x = block1(x) + h.append(x) + context.append(x) + x = block2(x) + x = attn(x) + h.append(x) + context.append(x) + x = downsample(x) + + x = self.mid_block1(x) + x = self.mid_attn(x) + x = self.mid_block2(x) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x) + + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim = 1) + + x = self.final_res_block(x) + return self.final_conv(x), context + +# gaussian diffusion trainer class + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def linear_beta_schedule(timesteps): + """ + linear schedule, proposed in original ddpm paper + """ + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) + +def cosine_beta_schedule(timesteps, s = 0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps + alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + +def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5): + """ + sigmoid schedule + proposed in https://arxiv.org/abs/2212.11972 - Figure 8 + better for images > 64x64, when used during training + """ + steps = timesteps + 1 + t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps + v_start = torch.tensor(start / tau).sigmoid() + v_end = torch.tensor(end / tau).sigmoid() + alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + +class GaussianDiffusion(nn.Module): + def __init__( + self, + model, + flow, + *, + image_size, + in_ch, + timesteps = 1200, + sampling_timesteps = None, + loss_type = 'l1', + objective = 'pred_noise', + beta_schedule = 'sigmoid', + 
schedule_fn_kwargs = dict(), + p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended + p2_loss_weight_k = 1, + ddim_sampling_eta = 0., + auto_normalize = True + ): + super().__init__() + + self.model = model + + self.umodel = context_net(upscale = 8, in_chans = in_ch, out_chans = 1, img_size = 48, window_size = 8, + img_range = 1., depths = [6, 6, 6, 6, 6, 6, 6], embed_dim = 200, + num_heads = [8, 8, 8, 8, 8, 8, 8], + mlp_ratio = 2, upsampler = 'pixelshuffle', resi_connection = '3conv') + self.flow = flow + self.upsample = nn.UpsamplingBilinear2d(scale_factor=8) + + self.channels = self.model.channels + self.self_condition = self.model.self_condition + + self.image_size = image_size + + self.objective = objective + + assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' + + if beta_schedule == 'linear': + beta_schedule_fn = linear_beta_schedule + elif beta_schedule == 'cosine': + beta_schedule_fn = cosine_beta_schedule + elif beta_schedule == 'sigmoid': + beta_schedule_fn = sigmoid_beta_schedule + else: + raise ValueError(f'unknown beta schedule {beta_schedule}') + + betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs) + + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # sampling related parameters + + self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training + + assert self.sampling_timesteps <= timesteps + self.is_ddim_sampling = self.sampling_timesteps < timesteps + self.ddim_sampling_eta = ddim_sampling_eta + + # helper function to register buffer from float64 to float32 + + register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) + register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) + + # calculate p2 reweighting + + register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) + + # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False + + self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity + self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_noise_from_start(self, x_t, t, x0): + return ( + (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + ) + + def predict_v(self, x_start, t, noise): + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start + ) + + def predict_start_from_v(self, x_t, t, v): + return ( + extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def q_posterior(self, x_start, x_t, t): + + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def model_predictions(self, x, t, l_cond, context, x_self_cond = None, clip_x_start = False): + + model_output = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) + + maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity + + if self.objective == 'pred_noise': + pred_noise = model_output + x_start = self.predict_start_from_noise(x, t, pred_noise) + x_start = maybe_clip(x_start) + + elif self.objective == 'pred_x0': + x_start = model_output + x_start = maybe_clip(x_start) + pred_noise = self.predict_noise_from_start(x, t, x_start) + + elif self.objective == 'pred_v': + v = model_output + x_start = self.predict_start_from_v(x, t, v) + x_start = maybe_clip(x_start) + pred_noise = self.predict_noise_from_start(x, t, x_start) + + return ModelPrediction(pred_noise, x_start) + + def p_mean_variance(self, x, t, context, x_self_cond = None, clip_denoised = True): + + preds = self.model_predictions(x, t, context, x_self_cond) + x_start = preds.pred_x_start + + if clip_denoised: + x_start.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t) + return model_mean, posterior_variance, posterior_log_variance, x_start + + @torch.no_grad() + def p_sample(self, x, t: int, context, x_self_cond = None): + + batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) + model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, context = context, x_self_cond = x_self_cond, clip_denoised = True) + noise = torch.randn_like(x) if t > 0 else 0. 
# no noise if t == 0 + pred_img = model_mean + (0.5 * model_log_variance).exp() * noise + return pred_img, x_start + + @torch.no_grad() + def p_sample_loop(self, shape, context, return_all_timesteps = False): + + device = self.betas.device + + img = torch.randn(shape, device = device) + imgs = [img] + + x_start = None + + for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): + self_cond = x_start if self.self_condition else None + img, x_start = self.p_sample(img, t, context, self_cond) + imgs.append(img) + + ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) + + return ret + + @torch.no_grad() + def ddim_sample(self, shape, l_cond, context, return_all_timesteps = False): + + batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective + + times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps + times = list(reversed(times.int().tolist())) + time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] + + img = torch.randn(shape, device = device) + imgs = [img] + + x_start = None + + for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): + time_cond = torch.full((batch,), time, device = device, dtype = torch.long) + self_cond = x_start if self.self_condition else None + pred_noise, x_start, *_ = self.model_predictions(img, time_cond, l_cond, context, self_cond, clip_x_start = True) + + imgs.append(img) + + if time_next < 0: + img = x_start + continue + + alpha = self.alphas_cumprod[time] + alpha_next = self.alphas_cumprod[time_next] + + sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() + c = (1 - alpha_next - sigma ** 2).sqrt() + + noise = torch.randn_like(img) + + img = x_start * alpha_next.sqrt() + \ + c * pred_noise + \ + sigma * noise + + imgs.append(img) + ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1) + + return ret + + @torch.no_grad() + def sample(self, lres, hres, multi, flow_mode, return_all_timesteps = False): + + b, f, c, h, w = lres.shape + + if multi: + + topo = hres[:, :, 1:2, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(8*h, 8*w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) + + lres = self.normalize(lres) + ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) + + sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample + + l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) + + r = torch.roll(l, -1, 1) + ures_flow = rearrange(ures[:, 1:(f-1), :, :, :], 'b t c h w -> (b t) c h w') + + m = lres.clone() + m1 = rearrange(m, 'b t c h w -> (b t) c h w') + m1 = self.upsample(m1) + m1 = rearrange(m1, '(b t) c h w -> b t c h w', t = f) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + + m1 = torch.roll(m1, -2, 1) 
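+ # note: along the time axis, l[t] holds the SwinIR-upsampled frame t (plus the auxiliary
+ # channels when multi is set), r[t] = l[t+1] after the roll by -1, and m1[t] is the
+ # bilinearly upsampled low-res frame t+2 after the roll by -2; concatenating them over
+ # channels and keeping the first f-2 steps gives the flow network its (current, next,
+ # future low-res) context for each target frame.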
+ + stack = torch.cat((l, r, m1), 2) + stack = stack[:, :(f-2), :, :, :] + stack = rearrange(stack, 'b t c h w -> (b t) c h w') + + flow, context = self.flow(stack) + + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) + + elif flow_mode == '2d': + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) + + if multi: + + # l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + + else: + + # l_cond = self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped + + res = sample_fn((b * (f - 2), 1, 8 * h, 8 * w), l_cond, context, return_all_timesteps = return_all_timesteps) + sres = warped + res + sres = rearrange(sres, '(b t) c h w -> b t c h w', b = b) + + warped = rearrange(warped, '(b t) c h w -> b t c h w', b = b) + res = rearrange(res, '(b t) c h w -> b t c h w', b = b) + flow = rearrange(flow, '(b t) c h w -> b t c h w', b = b) + + if flow_mode == '2d': + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), flow + + elif flow_mode == '3d': + + return self.unnormalize(sres), self.unnormalize(warped), self.unnormalize(res), self.unnormalize(flow) + + @torch.no_grad() + def interpolate(self, x1, x2, t = None, lam = 0.5): + + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device = device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + @property + def loss_fn(self): + + if self.loss_type == 'l1': + return F.l1_loss + elif self.loss_type == 'l2': + return F.mse_loss + elif self.loss_type == 'focal': + return focal_mse_loss + else: + raise ValueError(f'invalid loss type {self.loss_type}') + + def p_losses(self, stack, hres, lres, ures, t, multi, flow_mode, topo = None, noise = None): + + f = hres.shape[1] + + stack = rearrange(stack, 'b t c h w -> (b t) c h w') + ures_flow = rearrange(ures[:, 1:(f - 1), :, :, :], 'b t c h w -> (b t) c h w') + + flow, context = self.flow(stack) + + if flow_mode == '3d': + + warped = scale_space_warp(ures_flow, flow) + + elif flow_mode == '2d': + + flow = self.unnormalize(flow) + warped = flow_warp(ures_flow, flow.permute(0, 2, 3, 1)) + + x_start = rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w') + x_start = x_start - warped + + if multi: + + # l_cond = torch.cat((self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + l_cond = torch.cat((warped, self.upsample(rearrange(lres[:, 2:, 1:, :, :], 'b t c h w -> (b t) c h w')), rearrange(topo[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')), dim = 1) + + else: + + # l_cond = 
self.upsample(rearrange(lres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w')) + l_cond = warped + + del f + + noise = default(noise, lambda: torch.randn_like(x_start)) + + # noise sample + + x = self.q_sample(x_start = x_start, t = t, noise = noise) + + # if doing self-conditioning, 50% of the time, predict x_start from current set of times + # and condition with unet with that + # this technique will slow down training by 25%, but seems to lower FID significantly + + x_self_cond = None + if self.self_condition and random() < 0.5: + with torch.no_grad(): + x_self_cond = self.model_predictions(x, t).pred_x_start + x_self_cond.detach_() + + # predict and take gradient step + + model_out = self.model(torch.cat((x, l_cond), 1), t, context, x_self_cond) + + if self.objective == 'pred_noise': + target = noise + elif self.objective == 'pred_x0': + target = x_start + elif self.objective == 'pred_v': + v = self.predict_v(x_start, t, noise) + target = v + else: + raise ValueError(f'unknown objective {self.objective}') + + loss = self.loss_fn(model_out, target, reduction = 'none') + loss = reduce(loss, 'b ... -> b (...)', 'mean') + + loss = loss * extract(self.p2_loss_weight, t, loss.shape) + + loss1 = self.loss_fn(ures, hres, reduction = 'none') + loss1 = reduce(loss1, 'b ... -> b (...)', 'mean') + + loss2 = self.loss_fn(warped, rearrange(hres[:, 2:, :, :, :], 'b t c h w -> (b t) c h w'), reduction = 'none') + loss2 = reduce(loss2, 'b ... -> b (...)', 'mean') + + return loss.mean()*1.7 + loss1.mean()*1.0 + loss2.mean()*0.3 + + def forward(self, lres, hres, multi, flow_mode, *args, **kwargs): + + b, f, c, h, w, device = *hres.shape, hres.device + + t = torch.randint(0, self.num_timesteps, (b*(f-2),), device=device).long() + + if multi: + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + low_chans = lres[:, :, 1:, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h//8, w//8), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + high_chans = rearrange(F.interpolate(rearrange(low_chans, 'b t c h w -> (b t) c h w'), size=(h, w), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + + if multi: + + ures = self.umodel(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + + else: + + ures = self.umodel(rearrange(lres, 'b t c h w -> (b t) c h w')) + + ures = rearrange(ures, '(b t) c h w -> b t c h w', b = b) + + lres = self.normalize(lres) + hres = self.normalize(hres) + ures = self.normalize(ures) + + if multi: + + topo = self.normalize(topo) + + l = ures.clone() + + if multi: + + l = torch.cat((l, high_chans, topo), dim = 2) + + r = torch.roll(l, -1, 1) + + m = lres.clone() + m1 = rearrange(m, 'b t c h w -> (b t) c h w') + m1 = self.upsample(m1) + m1 = rearrange(m1, '(b t) c h w -> b t c h w', b = b) + + if multi: + + m1 = torch.cat((m1, topo), dim = 2) + + m1 = torch.roll(m1, -2, 1) + + stack = torch.cat((l, r, m1), 2) + stack = stack[:, :(f-2), :, :, :] + + if multi: + + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, topo, *args, **kwargs) + + else: + + return self.p_losses(stack, hres, lres, ures, t, multi, flow_mode, None, *args, **kwargs) + +# trainer class + +class Trainer(object): + def __init__( + self, + diffusion_model, + train_dl, + val_dl, + config, + *, + train_batch_size = 16, + gradient_accumulate_every = 1, + train_lr = 1e-4, + train_num_steps = 100000, + ema_update_every = 1, + ema_decay = 0.995, + adam_betas = (0.9, 0.99), + save_and_sample_every = 1, + eval_folder = './evaluate', + 
results_folder = './results', + val_num_of_batch = 2, + amp = False, + fp16 = False, + split_batches = True + ): + super().__init__() + + self.accelerator = Accelerator( + split_batches = split_batches, + mixed_precision = 'fp16' if fp16 else 'no', + log_with = 'wandb', + ) + self.accelerator.init_trackers("climate", + init_kwargs={ + "wandb": { + "name": None, + } + }, + ) + self.config = config + self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.rollout_batch = config.rollout_batch + self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] + + self.model = diffusion_model + + self.save_and_sample_every = save_and_sample_every + + self.batch_size = train_batch_size + self.gradient_accumulate_every = gradient_accumulate_every + + self.train_num_steps = train_num_steps + self.image_size = diffusion_model.image_size + + self.val_num_of_batch = val_num_of_batch + + # optimizer + + self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 5e-7) + #self.sched = ReduceLROnPlateau(self.opt, 'min', factor = 0.5, patience = 5, min_lr = 1e-6, verbose = False) + + # for logging results in a folder periodically + + if self.accelerator.is_main_process: + self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) + + self.results_folder = Path(results_folder) + + self.results_folder.mkdir(exist_ok=True, parents=True) + + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) + + # step counter state + + self.step = 0 + + # prepare model, dataloader, optimizer with accelerator + + self.model, self.opt, self.sched, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, self.sched, train_dl, val_dl) + self.train_dl = cycle(train_dl) + self.val_dl = val_dl + + def save(self, milestone): + if not self.accelerator.is_local_main_process: + return + + data = { + 'step': self.step, + 'model': self.accelerator.get_state_dict(self.model), + 'opt': self.opt.state_dict(), + 'ema': self.ema.state_dict(), + 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, + #'version': __version__ + } + + torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) + + def load(self, milestone): + accelerator = self.accelerator + device = accelerator.device + + data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) + + model = self.accelerator.unwrap_model(self.model) + model.load_state_dict(data['model']) + + self.step = data['step'] + #self.opt.load_state_dict(data['opt']) + self.ema.load_state_dict(data['ema']) + + #if 'version' in data: + # print(f"loading from version {data['version']}") + + if exists(self.accelerator.scaler) and exists(data['scaler']): + self.accelerator.scaler.load_state_dict(data['scaler']) + + def train(self): + + accelerator = self.accelerator + device = accelerator.device + + cmap = mpl.colormaps['RdBu_r'] + fcmap = mpl.colormaps['gray_r'] + + # c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + # c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + # c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + # c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + # c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + # c48_gmin = np.load('data/only_precip/c48_gmin.npy') + + # c384_min = 
np.load('data/only_precip/c384_min.npy') + # c384_max = np.load('data/only_precip/c384_max.npy') + + # c48_min = np.load('data/only_precip/c48_min.npy') + # c48_max = np.load('data/only_precip/c48_max.npy') + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) + + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min + + with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: + + while self.step < self.train_num_steps: + + total_loss = 0. + + for _ in range(self.gradient_accumulate_every): + + data = next(self.train_dl) + lres = data['LR'].to(device) + hres = data['HR'].to(device) + + if self.minipatch: + + x_st = randint(0, 36) + y_st = randint(0, 36) + lres = crop(lres, x_st, y_st, 12, 12) + hres = crop(hres, 8 * x_st, 8 * y_st, 96, 96) + + with self.accelerator.autocast(): + + loss = self.model(lres, hres, self.multi, self.flow) + loss = loss / self.gradient_accumulate_every + total_loss += loss.item() + + self.accelerator.backward(loss) + + accelerator.clip_grad_norm_(self.model.parameters(), 1.0) + pbar.set_description(f'loss: {total_loss:.4f}') + + accelerator.log({"loss": total_loss}, step = self.step) + + accelerator.wait_for_everyone() + + self.opt.step() + self.opt.zero_grad() + self.sched.step() + + accelerator.wait_for_everyone() + + self.step += 1 + if accelerator.is_main_process: + self.ema.to(device) + self.ema.update() + + if self.step != 0 and self.step % self.save_and_sample_every == 0: + self.ema.ema_model.eval() + + with torch.no_grad(): + + vlosses = [] + vids = [] + hr = [] + lr = [] + bases, ress, flowss = [], [], [] + num_frames = 5 + img_size = 384 + + for i, batch in enumerate(self.val_dl): + + lres = batch['LR'].to(device) + hres = batch['HR'].to(device) + + if i >= self.val_num_of_batch: + break + + # num_samples = 5 + # num_videos_per_batch = 1 + # num_frames = 5 + # img_size = 384 + # img_channels = 1 + + # truth = np.zeros((1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # pred = np.zeros((num_samples, 1, num_videos_per_batch, num_frames, img_channels, img_size, img_size), dtype = 'uint8') + # truth[0,:,:,:,:,:] = (hres[:,2:,0:1,:,:].repeat(1,1,1,1,1).cpu().numpy()*255).astype(np.uint8) + + # for k in range(num_samples): + # videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + # pred[k,0,:,:,:,:] = (videos.clamp(0.0, 1.0)[:,:,0:1,:,:].repeat(1,1,1,1,1).detach().cpu().numpy()*255).astype(np.uint8) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + loss = self.model(lres, hres, self.multi, self.flow) + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + bases.append(base) + ress.append(res) + flowss.append(flows) + + videos = torch.cat(vids, dim = 0) + vloss = 
torch.stack(vlosses, dim = 0).mean() + #self.sched.step(vloss) + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + base = torch.cat(bases, dim = 0) + res = torch.cat(ress, dim = 0) + flows = torch.cat(flowss, dim = 0) + del vids, vlosses, hr, lr, bases, ress, flowss + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,2:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,2:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + ssim_index = piq.ssim(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + gmsd_index = piq.gmsd(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + rmse = np.sqrt(np.mean((output1 - target1)**2)) + pscore = np.abs(np.percentile(output1, 99.999) - np.percentile(target1, 99.999)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + #histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + #histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + count_o, bin_o = np.histogram(output1, bins=bins1, density=True) + count_t, bin_t = np.histogram(target1, bins=bins1, density=True) + histo = count_o.ravel().astype('float32') + histt = count_t.ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + #output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + bin_o[:-1], bins=bin_o, weights = count_o, alpha=0.5, label="Output", histtype="step"#, density=True + ) + ax1.hist( + #target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + bin_t[:-1], bins=bin_t, weights = count_t, alpha=0.5, label="Target", histtype="step"#, density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + 
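+                        # Editor's note (illustrative only, not part of the training logic): the four
+                        # distribution metrics above compare pixel-value histograms of the generated
+                        # and target precipitation rather than per-pixel errors. cv2.compareHist takes
+                        # two float32 histograms and a method flag, e.g.
+                        #   cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR)    # 0 = identical
+                        #   cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) # larger = more overlap
+                        #   cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV)    # 0 = identical
+                        # while scipy's wasserstein_distance is computed on the raw flattened samples,
+                        # so it stays sensitive to the extreme-precipitation tail that the 100-bin
+                        # histograms can smooth over.
+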
flow_d = np.zeros((1, num_frames, 3, img_size, img_size)) + + for m in range(num_frames): + + flow_d[0,m,:,:,:] = np.transpose(flow_vis.flow_to_color(flows.clamp(0, 1)[0,m,:2,:,:].permute(1,2,0).cpu().numpy(), convert_to_bgr = True), (2,0,1)) + + if self.flow == '3d': + + flow_s = np.zeros((1, num_frames, 3, img_size, img_size)) + sm = smap(None, fcmap) + + for m in range(num_frames): + + flow_s[0,m,:,:,:] = np.transpose(sm.to_rgba(flows.clamp(0, 1)[0,m,2,:,:].cpu().numpy())[:,:,:3], (2,0,1)) + + + + if self.logscale: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"pred": wandb.Video((base.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"flow_d": wandb.Video((flow_d*255).astype(np.uint8))}, step=self.step) + if self.flow == '3d': + accelerator.log({"flow_s": wandb.Video((flow_s*255).astype(np.uint8))}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"res": wandb.Video((res[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"ssim": ssim_index.mean()}, step=self.step) + accelerator.log({"gmsd": gmsd_index.mean()}, step=self.step) + accelerator.log({"rmse": rmse}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, step=self.step) + accelerator.log({"distemd": distemd}, step=self.step) + accelerator.log({"vloss": vloss}, step=self.step) + accelerator.log({"lr": self.opt.param_groups[0]['lr']}, 
step=self.step) + + milestone = self.step // self.save_and_sample_every + + self.save(milestone) + + pbar.update(1) + + accelerator.print('training complete') + + def sample(self): + + accelerator = self.accelerator + device = accelerator.device + + self.ema.ema_model.eval() + + PATH = "/extra/ucibdl0/shared/data/fv3gfs" + XX = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/0011/sfc_8xdaily_ave_coarse.zarr") + XX_ = xr.open_zarr(f"{PATH}/c48_atmos_ave/0011/atmos_8xdaily_ave_coarse.zarr") + yy = xr.open_zarr(f"{PATH}/c384_precip_ave/0011/sfc_8xdaily_ave.zarr") + topot = xr.open_zarr(f"{PATH}/c384_topo/0011/atmos_static.zarr") + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + c48_atm_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + c384_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + c384_topo = pickle.load(f) + + if self.multi: + + c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + c384_channels = ["PRATEsfc"] + + else: + + c48_channels = ["PRATEsfc_coarse"] + c384_channels = ["PRATEsfc"] + + with torch.no_grad(): + + for tile in range(6): + + if self.rollout == 'full': + + seq_len = self.rollout_batch + st = 0 + en = seq_len + 2 + count = 0 + + while en < 3176: + + print(tile, st) + + X = XX.isel(time = slice(st, en), tile = tile) + X_ = XX_.isel(time = slice(st, en), tile = tile) + y = yy.isel(time = slice(st, en), tile = tile) + + + if self.multi: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + + + if self.logscale: + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min']) / (c48_chl["PRATEsfc_coarse"]['max'] - c48_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_chl["PRATEsfc"]['min']) / (c384_chl["PRATEsfc"]['max'] - c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / 
(c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + count += 1 + + st += seq_len + en += seq_len + + if self.rollout == 'partial': + + seq_len = self.rollout_batch + indices = get_random_idx_with_difference(0, 3176 - (seq_len + 2), 75 // seq_len, seq_len + 2) # 250 samples per tile + + for count, st in enumerate(indices): + + print(tile, count) + + X = XX.isel(time = slice(st, st+(seq_len+2)), tile = tile) + X_ = XX_.isel(time = slice(st, st+(seq_len+2)), tile = tile) + y = yy.isel(time = slice(st, st+(seq_len+2)), tile = tile) + + + if self.multi: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + else: + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + + + if self.logscale: + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + else: + + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min']) / (c48_chl["PRATEsfc_coarse"]['max'] - c48_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_chl["PRATEsfc"]['min']) / (c384_chl["PRATEsfc"]['max'] - c384_chl["PRATEsfc"]['min']) + + if self.multi: + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + + videos, base, res, flows = self.ema.ema_model.sample(lres, hres, self.multi, self.flow) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,2:,0:1,:,:], os.path.join(self.eval_folder) + 
"/truth_lr_{}_{}.pt".format(tile, count)) \ No newline at end of file diff --git a/projects/super_res/model/denoising_diffusion_rvrt_full.py b/projects/super_res/model/denoising_diffusion_rvrt_full.py new file mode 100644 index 0000000000..3df61b513a --- /dev/null +++ b/projects/super_res/model/denoising_diffusion_rvrt_full.py @@ -0,0 +1,1611 @@ +import os +import math +from pathlib import Path +from random import random, randint +from functools import partial, reduce, lru_cache +from collections import namedtuple +from operator import mul + +import numpy as np +import cv2 +from scipy.stats import wasserstein_distance + +import xarray as xr + +import torch +from torch import nn +import torch.nn.functional as F +import wandb + +import piq +import pickle + +from torchvision.transforms.functional import crop + +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.cm import ScalarMappable as smap + +from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR + +from einops import rearrange +import einops +from einops.layers.torch import Rearrange + +from PIL import Image + +from tqdm.auto import tqdm +from ema_pytorch import EMA + +from accelerate import Accelerator +from distutils.version import LooseVersion +from .op.deform_attn import deform_attn, DeformAttnPack + +# constants + +ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) + +# helpers functions + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-9): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y, reduction = None): + diff = x - y + loss = torch.sqrt((diff * diff) + self.eps) + return loss + +def save_image(tensor, path): + im = Image.fromarray((tensor[:,:,:3] * 255).astype(np.uint8)) + im.save(path) + return None + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def identity(t, *args, **kwargs): + return t + +def cycle(dl): + while True: + for data in dl: + yield data + +def has_int_squareroot(num): + return (math.sqrt(num) ** 2) == num + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + +def convert_image_to_fn(img_type, image): + if image.mode != img_type: + return image.convert(img_type) + return image + +def make_layer(block, num_blocks, **kwarg): + """Make layers by stacking the same blocks. + + Args: + block (nn.module): nn.module class for basic block. + num_blocks (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_blocks): + layers.append(block(**kwarg)) + return nn.Sequential(*layers) + +# normalization functions + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + +# model helpers + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear' or 'nearest4'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. 
After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + + Returns: + Tensor: Warped image or feature map. + """ + n, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, dtype=x.dtype, device=x.device), + torch.arange(0, w, dtype=x.dtype, device=x.device)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + return output + + +def make_layer(block, num_blocks, **kwarg): + """Make layers by stacking the same blocks. + + Args: + block (nn.module): nn.module class for basic block. + num_blocks (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_blocks): + layers.append(block(**kwarg)) + return nn.Sequential(*layers) + + +class BasicModule(nn.Module): + """Basic Module for SpyNet. + """ + + def __init__(self): + super(BasicModule, self).__init__() + + self.basic_module = nn.Sequential( + nn.Conv2d(in_channels=26, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), + nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) + + def forward(self, tensor_input): + return self.basic_module(tensor_input) + + +class SpyNet(nn.Module): + """SpyNet architecture. + + Args: + load_path (str): path for pretrained SpyNet. Default: None. + return_levels (list[int]): return flows of different levels. Default: [5]. 
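+            (Level 5 is the finest level: the flow is returned at the full input resolution,
+            and each lower level halves the spatial size, i.e. level L comes out at
+            1 / 2**(5 - L) of the input resolution.)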
+ """ + + def __init__(self, load_path=None, return_levels=[5]): + super(SpyNet, self).__init__() + self.return_levels = return_levels + self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) + + + def process(self, ref, supp, w, h, w_floor, h_floor): + flow_list = [] + + ref = [ref] + supp = [supp] + + for level in range(5): + ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) + supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) + + flow = ref[0].new_zeros( + [ref[0].size(0), 2, + int(math.floor(ref[0].size(2) / 2.0)), + int(math.floor(ref[0].size(3) / 2.0))]) + + for level in range(len(ref)): + upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 + + if upsampled_flow.size(2) != ref[level].size(2): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') + if upsampled_flow.size(3) != ref[level].size(3): + upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') + + flow = self.basic_module[level](torch.cat([ + ref[level], + flow_warp( + supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), + upsampled_flow + ], 1)) + upsampled_flow + + if level in self.return_levels: + scale = 2 ** (5 - level) # level=5 (scale=1), level=4 (scale=2), level=3 (scale=4), level=2 (scale=8) + flow_out = F.interpolate(input=flow, size=(h // scale, w // scale), mode='bilinear', + align_corners=False) + flow_out[:, 0, :, :] *= float(w // scale) / float(w_floor // scale) + flow_out[:, 1, :, :] *= float(h // scale) / float(h_floor // scale) + flow_list.insert(0, flow_out) + + return flow_list + + def forward(self, ref, supp): + assert ref.size() == supp.size() + + h, w = ref.size(2), ref.size(3) + w_floor = math.floor(math.ceil(w / 32.0) * 32.0) + h_floor = math.floor(math.ceil(h / 32.0) * 32.0) + + ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) + + flow_list = self.process(ref, supp, w, h, w_floor, h_floor) + + return flow_list[0] if len(flow_list) == 1 else flow_list + + +class GuidedDeformAttnPack(DeformAttnPack): + """Guided deformable attention module. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + attention_window (int or tuple[int]): Attention window size. Default: [3, 3]. + attention_heads (int): Attention head number. Default: 12. + deformable_groups (int): Deformable offset groups. Default: 12. + clip_size (int): clip size. Default: 2. + max_residue_magnitude (int): The maximum magnitude of the offset residue. Default: 10. 
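+
+        Note: the sampling offsets are predicted as a bounded residual,
+        max_residue_magnitude * tanh(conv_offset(...)), computed from the query features,
+        the flow-warped propagation features and the guiding flows, and then added on top
+        of those flows, so the attention samples stay close to the flow-warped locations.
+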
+ Ref: + Recurrent Video Restoration Transformer with Guided Deformable Attention + + """ + + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(GuidedDeformAttnPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv3d(self.in_channels * (1 + self.clip_size) + self.clip_size * 2, 64, kernel_size=(1, 1, 1), + padding=(0, 0, 0)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv3d(64, self.clip_size * self.deformable_groups * self.attn_size * 2, kernel_size=(1, 1, 1), + padding=(0, 0, 0)), + ) + self.init_offset() + + # proj to a higher dimension can slightly improve the performance + self.proj_channels = int(self.in_channels * 2) + self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.proj_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.proj_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + Mlp(self.in_channels, self.in_channels * 2, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + + def init_offset(self): + if hasattr(self, 'conv_offset'): + self.conv_offset[-1].weight.data.zero_() + self.conv_offset[-1].bias.data.zero_() + + def forward(self, q, k, v, v_prop_warped, flows, return_updateflow): + offset1, offset2 = torch.chunk(self.max_residue_magnitude * torch.tanh( + self.conv_offset(torch.cat([q] + v_prop_warped + flows, 2).transpose(1, 2)).transpose(1, 2)), 2, dim=2) + offset1 = offset1 + flows[0].flip(2).repeat(1, 1, offset1.size(2) // 2, 1, 1) + offset2 = offset2 + flows[1].flip(2).repeat(1, 1, offset2.size(2) // 2, 1, 1) + offset = torch.cat([offset1, offset2], dim=2).flatten(0, 1) + + b, t, c, h, w = offset1.shape + q = self.proj_q(q).view(b * t, 1, self.proj_channels, h, w) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size).view(b, t, self.proj_channels, h, + w) + v = self.proj(v) + v = v + self.mlp(v) + + if return_updateflow: + return v, offset1.view(b, t, c // 2, 2, h, w).mean(2).flip(2), offset2.view(b, t, c // 2, 2, h, w).mean( + 2).flip(2) + else: + return v + +def window_partition(x, window_size): + """ Partition the input into windows. Attention will be conducted within the windows. 
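+    (For example, an input of shape (B, D, H, W, C) with window_size (2, 8, 8) is split into
+    B * (D/2) * (H/8) * (W/8) windows, each flattened to 2*8*8 = 128 tokens.)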
+ + Args: + x: (B, D, H, W, C) + window_size (tuple[int]): window size + + Returns: + windows: (B*num_windows, window_size*window_size, C) + """ + B, D, H, W, C = x.shape + x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], + window_size[2], C) + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) + + return windows + + +def window_reverse(windows, window_size, B, D, H, W): + """ Reverse windows back to the original input. Attention was conducted within the windows. + + Args: + windows: (B*num_windows, window_size, window_size, C) + window_size (tuple[int]): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, D, H, W, C) + """ + x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], + window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) + + return x + + +def get_window_size(x_size, window_size, shift_size=None): + """ Get the window size and the shift size """ + + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + +@lru_cache() +def compute_mask(D, H, W, window_size, shift_size, device): + """ Compute attnetion mask for input of size (D, H, W). @lru_cache caches each stage results. """ + + img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 + cnt = 0 + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 + mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + +class Mlp(nn.Module): + """ Multilayer perceptron. + + Args: + x: (B, D, H, W, C) + + Returns: + x: (B, D, H, W, C) + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + +class WindowAttention(nn.Module): + """ Window based multi-head self attention. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The temporal length, height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None): + super().__init__() + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), + num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH + self.register_buffer("relative_position_index", self.get_position_index(window_size)) + self.qkv_self = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, mask=None): + """ Forward function. + + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, N, N) or None + """ + + # self attention + B_, N, C = x.shape + qkv = self.qkv_self(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C + x_out = self.attention(q, k, v, mask, (B_, N, C)) + + # projection + x = self.proj(x_out) + + return x + + def attention(self, q, k, v, mask, x_shape): + B_, N, C = x_shape + attn = (q * self.scale) @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:N, :N].reshape(-1)].reshape(N, N, -1) # Wd*Wh*Ww, Wd*Wh*Ww,nH + attn = attn + relative_position_bias.permute(2, 0, 1).unsqueeze(0) # B_, nH, N, N + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask[:, :N, :N].unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + attn = F.softmax(attn, -1, dtype=q.dtype) # Don't use attn.dtype after addition! + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + + return x + + def get_position_index(self, window_size): + ''' Get pair-wise relative position index for each token inside the window. ''' + + coords_d = torch.arange(window_size[0]) + coords_h = torch.arange(window_size[1]) + coords_w = torch.arange(window_size[2]) + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 2] += window_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * window_size[1] - 1) * (2 * window_size[2] - 1) + relative_coords[:, :, 1] *= (2 * window_size[2] - 1) + relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww + + return relative_position_index + +class STL(nn.Module): + """ Swin Transformer Layer (STL). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (tuple[int]): Shift size for mutual and self attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True. + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + act_layer (nn.Module, optional): Activation layer. 
Default: nn.GELU. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm. + use_checkpoint_attn (bool): If True, use torch.checkpoint for attention modules. Default: False. + use_checkpoint_ffn (bool): If True, use torch.checkpoint for feed-forward modules. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=(2, 8, 8), + shift_size=(0, 0, 0), + mlp_ratio=2., + qkv_bias=True, + qk_scale=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_checkpoint_attn=False, + use_checkpoint_ffn=False + ): + super().__init__() + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.use_checkpoint_attn = use_checkpoint_attn + self.use_checkpoint_ffn = use_checkpoint_ffn + + assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" + assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention(dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias, + qk_scale=qk_scale) + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) + + def forward_part1(self, x, mask_matrix): + B, D, H, W, C = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + + x = self.norm1(x) + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] + pad_b = (window_size[1] - H % window_size[1]) % window_size[1] + pad_r = (window_size[2] - W % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1), mode='constant') + + _, Dp, Hp, Wp, _ = x.shape + # cyclic shift + if any(i > 0 for i in shift_size): + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C + + # attention / shifted attention + attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C + + # merge windows + attn_windows = attn_windows.view(-1, *(window_size + (C,))) + shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C + + # reverse cyclic shift + if any(i > 0 for i in shift_size): + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :D, :H, :W, :] + + return x + + def forward_part2(self, x): + return self.mlp(self.norm2(x)) + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, D, H, W, C). + mask_matrix: Attention mask for cyclic shift. + """ + + # attention + x = x + self.forward_part1(x, mask_matrix) + + # feed-forward + x = x + self.forward_part2(x) + + return x + + +class STG(nn.Module): + """ Swin Transformer Group (STG). + + Args: + dim (int): Number of feature channels + input_resolution (tuple[int]): Input resolution. + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (tuple[int]): Local window size. Default: (6,8,8). 
+ shift_size (tuple[int]): Shift size for mutual and self attention. Default: None. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + use_checkpoint_attn (bool): If True, use torch.checkpoint for attention modules. Default: False. + use_checkpoint_ffn (bool): If True, use torch.checkpoint for feed-forward modules. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size=[2, 8, 8], + shift_size=None, + mlp_ratio=2., + qkv_bias=False, + qk_scale=None, + norm_layer=nn.LayerNorm, + use_checkpoint_attn=False, + use_checkpoint_ffn=False, + ): + super().__init__() + self.input_resolution = input_resolution + self.window_size = window_size + self.shift_size = list(i // 2 for i in window_size) if shift_size is None else shift_size + + # build blocks + self.blocks = nn.ModuleList([ + STL( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=[0, 0, 0] if i % 2 == 0 else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint_attn=use_checkpoint_attn, + use_checkpoint_ffn=use_checkpoint_ffn + ) + for i in range(depth)]) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, C, D, H, W). + """ + # calculate attention mask for attention + B, C, D, H, W = x.shape + window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) + x = rearrange(x, 'b c d h w -> b d h w c') + Dp = int(np.ceil(D / window_size[0])) * window_size[0] + Hp = int(np.ceil(H / window_size[1])) * window_size[1] + Wp = int(np.ceil(W / window_size[2])) * window_size[2] + attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) + + for blk in self.blocks: + x = blk(x, attn_mask) + + x = x.view(B, D, H, W, -1) + x = rearrange(x, 'b d h w c -> b c d h w') + + return x + +class RSTB(nn.Module): + """ Residual Swin Transformer Block (RSTB). + + Args: + kwargs: Args for RSTB. + """ + + def __init__(self, groups = 8, **kwargs): + super(RSTB, self).__init__() + self.input_resolution = kwargs['input_resolution'] + + self.residual_group = STG(**kwargs) + self.linear = nn.Linear(kwargs['dim'], kwargs['dim']) + self.proj = nn.Conv3d(kwargs['dim'], + kwargs['dim'], + kernel_size=(1,3,3), + padding=(0,1,1), + groups=groups) + self.norm = nn.GroupNorm(groups, kwargs['dim']) + self.act = nn.SiLU() + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + + x = self.act(x) + + return x + self.linear(self.residual_group(x).transpose(1, 4)).transpose(1, 4) + +class RSTBWithInputConv(nn.Module): + """RSTB with a convolution in front. + + Args: + in_channels (int): Number of input channels of the first conv. + kernel_size (int): Size of kernel of the first conv. + stride (int): Stride of the first conv. + group (int): Group of the first conv. + num_blocks (int): Number of residual blocks. Default: 2. + **kwarg: Args for RSTB. 
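+
+        Note: internally this is an initial Conv3d + LayerNorm followed by `num_blocks`
+        RSTB blocks and a final LayerNorm; inputs are expected in (n, t, c, h, w) layout
+        (see `forward` below).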
+ """ + + def __init__(self, in_channels=3, kernel_size=(1, 3, 3), stride=1, groups=1, num_blocks=2, **kwargs): + super(RSTBWithInputConv, self).__init__() + + self.in_channels = in_channels + self.init_conv = nn.Conv3d(in_channels, + kwargs['dim'], + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[2] // 2), + groups=groups) + + self.init_norm = nn.LayerNorm(kwargs['dim']) + + # RSTB blocks + #kwargs['use_checkpoint_attn'] = kwargs.pop('use_checkpoint_attn')[0] + #kwargs['use_checkpoint_ffn'] = kwargs.pop('use_checkpoint_ffn')[0] + + #main.append(make_layer(RSTB, num_blocks, **kwargs)) + self.main1 = [] + for _ in range(num_blocks): + self.main1.append(RSTB(**kwargs).cuda()) + + main2 = [] + main2 += [Rearrange('n c d h w -> n d h w c'), + nn.LayerNorm(kwargs['dim']), + Rearrange('n d h w c -> n d c h w')] + + self.main2 = nn.Sequential(*main2) + + def forward(self, x): + """ + Forward function for RSTBWithInputConv. + + Args: + feat (Tensor): Input feature with shape (n, t, in_channels, h, w) + + Returns: + Tensor: Output feature with shape (n, t, out_channels, h, w) + """ + + + x = rearrange(x, 'n d c h w -> n c d h w') + x = self.init_conv(x) + + x = rearrange(x, 'n c d h w -> n d h w c') + x = self.init_norm(x) + x = rearrange(x, 'n d h w c -> n c d h w') + + for i in range(len(self.main1)): + x = self.main1[i](x) + x = self.main2(x) + + return x + +class Upsample(nn.Module): + '''Upsample module for video SR. + + Args: + scale (int): Scale factor. Supported scales: 4. + num_feat (int): Channel number of intermediate features. + ''' + + def __init__(self, scale, num_feat, **kwargs): + super(Upsample, self).__init__() + + assert LooseVersion(torch.__version__) >= LooseVersion('1.8.1'), \ + 'PyTorch version >= 1.8.1 to support 5D PixelShuffle.' + + self.feat1 = nn.Conv3d(num_feat, 4 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1)) + self.feat2 = nn.Conv3d(num_feat, 4 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1)) + self.feat3 = nn.Conv3d(num_feat, 4 * num_feat, kernel_size=(1, 3, 3), padding=(0, 1, 1)) + + self.upsample1 = nn.PixelShuffle(2) + self.upsample2 = nn.PixelShuffle(2) + self.upsample3 = nn.PixelShuffle(2) + + self.lrelu1 = nn.LeakyReLU(negative_slope=0.1) + self.lrelu2 = nn.LeakyReLU(negative_slope=0.1) + self.lrelu3 = nn.LeakyReLU(negative_slope=0.1) + + self.final = nn.Conv3d(num_feat, 1, kernel_size=(1, 3, 3), padding=(0, 1, 1)) + + def forward(self, x): + x = rearrange(x, 'n d c h w -> n c d h w') + x = self.feat1(x) + x = rearrange(x, 'n c d h w -> n d c h w') + x = self.upsample1(x) + x = rearrange(x, 'n d c h w -> n c d h w') + x = self.lrelu1(x) + x = self.feat2(x) + x = rearrange(x, 'n c d h w -> n d c h w') + x = self.upsample2(x) + x = rearrange(x, 'n d c h w -> n c d h w') + x = self.lrelu2(x) + x = self.feat3(x) + x = rearrange(x, 'n c d h w -> n d c h w') + x = self.upsample3(x) + x = rearrange(x, 'n d c h w -> n c d h w') + x = self.lrelu3(x) + + x = self.final(x) + x = rearrange(x, 'n c d h w -> n d c h w') + + return x + +class GaussianDiffusion(nn.Module): + def __init__( + self, + feat_ext, + feat_up, + backbone, + deform_align, + recon, + spynet, + *, + image_size, + timesteps = 1200, + sampling_timesteps = None, + loss_type = 'l1', + objective = 'pred_noise', + beta_schedule = 'sigmoid', + schedule_fn_kwargs = dict(), + p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. 
is recommended + p2_loss_weight_k = 1, + ddim_sampling_eta = 0., + auto_normalize = True + ): + super(GaussianDiffusion, self).__init__() + self.clip_size = 2 + self.feat_ext = feat_ext + self.feat_up = feat_up + + self.backbone = backbone + + self.deform_align = deform_align + + self.recon = recon + + self.spynet = spynet + + self.channels = self.feat_ext.in_channels + + self.image_size = image_size + + self.loss_type = loss_type + + self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity + self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity + + @property + def loss_fn(self): + if self.loss_type == 'l1': + return F.l1_loss + elif self.loss_type == 'l2': + return F.mse_loss + elif self.loss_type == 'charbonnier': + return CharbonnierLoss() + else: + raise ValueError(f'invalid loss type {self.loss_type}') + + def compute_flow(self, lqs): + """Compute optical flow using SPyNet for feature alignment. + + Note that if the input is an mirror-extended sequence, 'flows_forward' + is not needed, since it is equal to 'flows_backward.flip(1)'. + + Args: + lqs (tensor): Input low quality (LQ) sequence with + shape (n, t, c, h, w). + + Return: + tuple(Tensor): Optical flow. 'flows_forward' corresponds to the + flows used for forward-time propagation (current to previous). + 'flows_backward' corresponds to the flows used for + backward-time propagation (current to next). + """ + + n, t, c, h, w = lqs.size() + lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w) + lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w) + + flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w) + + return flows_forward, flows_backward + + def propagate(self, feats, flows, module_name, updated_flows=None): + """Propagate the latent clip features throughout the sequence. + + Args: + feats dict(list[tensor]): Features from previous branches. Each + component is a list of tensors with shape (n, clip_size, c, h, w). + flows (tensor): Optical flows with shape (n, t - 1, 2, h, w). + module_name (str): The name of the propgation branches. Can either + be 'backward_1', 'forward_1', 'backward_2', 'forward_2'. + updated_flows dict(list[tensor]): Each component is a list of updated + optical flows with shape (n, clip_size, 2, h, w). + + Return: + dict(list[tensor]): A dictionary containing all the propagated + features. Each key in the dictionary corresponds to a + propagation branch, which is represented by a list of tensors. 
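+
+        Note: features are propagated clip by clip (`clip_size` frames at a time); for the
+        second-order connections the per-frame flows are composed with flow_warp, e.g.
+        flow_n02 = flow_n12 + flow_warp(flow_n01, flow_n12.permute(0, 2, 3, 1)).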
+ """ + + n, t, _, h, w = flows.size() + if 'backward' in module_name: + flow_idx = range(0, t + 1)[::-1] + clip_idx = range(0, (t + 1) // self.clip_size)[::-1] + else: + flow_idx = range(-1, t) + clip_idx = range(0, (t + 1) // self.clip_size) + + if '_1' in module_name: + updated_flows[f'{module_name}_n1'] = [] + updated_flows[f'{module_name}_n2'] = [] + + feat_prop = torch.zeros_like(feats['shallow'][0])#.cuda() + + last_key = list(feats)[-2] + + for i in range(0, len(clip_idx)): + idx_c = clip_idx[i] + if i > 0: + if '_1' in module_name: + flow_n01 = flows[:, flow_idx[self.clip_size * i - 1], :, :, :] + flow_n12 = flows[:, flow_idx[self.clip_size * i], :, :, :] + flow_n23 = flows[:, flow_idx[self.clip_size * i + 1], :, :, :] + flow_n02 = flow_n12 + flow_warp(flow_n01, flow_n12.permute(0, 2, 3, 1)) + flow_n13 = flow_n23 + flow_warp(flow_n12, flow_n23.permute(0, 2, 3, 1)) + flow_n03 = flow_n23 + flow_warp(flow_n02, flow_n23.permute(0, 2, 3, 1)) + flow_n1 = torch.stack([flow_n02, flow_n13], 1) + flow_n2 = torch.stack([flow_n12, flow_n03], 1) + else: + module_name_old = module_name.replace('_2', '_1') + flow_n1 = updated_flows[f'{module_name_old}_n1'][i - 1] + flow_n2 = updated_flows[f'{module_name_old}_n2'][i - 1] + + + if 'backward' in module_name: + feat_q = feats[last_key][idx_c].flip(1) + feat_k = feats[last_key][clip_idx[i - 1]].flip(1) + else: + feat_q = feats[last_key][idx_c] + feat_k = feats[last_key][clip_idx[i - 1]] + + feat_prop_warped1 = flow_warp(feat_prop.flatten(0, 1), + flow_n1.permute(0, 1, 3, 4, 2).flatten(0, 1))\ + .view(n, feat_prop.shape[1], feat_prop.shape[2], h, w) + feat_prop_warped2 = flow_warp(feat_prop.flip(1).flatten(0, 1), + flow_n2.permute(0, 1, 3, 4, 2).flatten(0, 1))\ + .view(n, feat_prop.shape[1], feat_prop.shape[2], h, w) + + if '_1' in module_name: + feat_prop, flow_n1, flow_n2 = self.deform_align[module_name](feat_q, feat_k, feat_prop, + [feat_prop_warped1, feat_prop_warped2], + [flow_n1, flow_n2], + True) + updated_flows[f'{module_name}_n1'].append(flow_n1) + updated_flows[f'{module_name}_n2'].append(flow_n2) + else: + feat_prop = self.deform_align[module_name](feat_q, feat_k, feat_prop, + [feat_prop_warped1, feat_prop_warped2], + [flow_n1, flow_n2], + False) + + if 'backward' in module_name: + feat = [feats[k][idx_c].flip(1) for k in feats if k not in [module_name]] + [feat_prop] + else: + feat = [feats[k][idx_c] for k in feats if k not in [module_name]] + [feat_prop] + + #print(len(feat), feat[0].shape, feat[1].shape) + fp = self.backbone[module_name](torch.cat(feat, dim=2)) + #fp = self.backbone[module_name](torch.cat(feat, dim=1)) + feat_prop = feat_prop + fp + + feats[module_name].append(feat_prop) + + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + feats[module_name] = [f.flip(1) for f in feats[module_name]] + + return feats + + def forward(self, lres, hres, *args, **kwargs): + + b, f, c, h, w, device, img_size = *hres.shape, hres.device, self.image_size + + assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(h//8, w//8), mode='bilinear'), '(b t) c h w -> b t c h w', b = b) + lres = torch.cat([lres, topo_low], dim = 2) + + lres = self.normalize(lres) + hres = self.normalize(hres) + + flows_forward, flows_backward = self.compute_flow(lres) + + feats = {} + ff = self.feat_ext(lres) + + feats['shallow'] = list(torch.chunk(ff, f // 
self.clip_size, dim = 1)) + + updated_flows = {} + for iter_ in [1, 2]: + for direction in ['backward', 'forward']: + if direction == 'backward': + flows = flows_backward + else: + flows = flows_forward if flows_forward is not None else flows_backward.flip(1) + + module_name = f'{direction}_{iter_}' + feats[module_name] = [] + + feats = self.propagate(feats, flows, module_name, updated_flows) + + feats['shallow'] = torch.cat(feats['shallow'], 1) + feats['backward_1'] = torch.cat(feats['backward_1'], 1) + feats['forward_1'] = torch.cat(feats['forward_1'], 1) + feats['backward_2'] = torch.cat(feats['backward_2'], 1) + feats['forward_2'] = torch.cat(feats['forward_2'], 1) + upsampled = torch.cat([feats[k] for k in feats], dim=2) + upsampled = self.recon(upsampled) + upsampled = self.feat_up(upsampled) + upsampled = upsampled + F.interpolate(lres[:,:,0:1,:,:], size = (1, h, w), mode = 'trilinear', align_corners = False) + + loss = self.loss_fn(upsampled, hres, reduction = 'none') + loss = einops.reduce(loss, 'b ... -> b (...)', 'mean') + + return loss.mean(), upsampled + +class Trainer(object): + def __init__( + self, + diffusion_model, + train_dl, + val_dl, + config, + *, + train_batch_size = 16, + gradient_accumulate_every = 1, + #augment_horizontal_flip = True, + train_lr = 1e-4, + train_num_steps = 100000, + ema_update_every = 1, + ema_decay = 0.995, + adam_betas = (0.9, 0.99), + save_and_sample_every = 1, + #num_samples = 25, + eval_folder = './evaluate', + results_folder = './results', + #tensorboard_dir = './tensorboard', + val_num_of_batch = 2, + amp = False, + fp16 = False, + #fp16 = True, + split_batches = True, + #split_batches = False, + convert_image_to = None + ): + super().__init__() + + self.accelerator = Accelerator( + split_batches = split_batches, + mixed_precision = 'fp16' if fp16 else 'no', + log_with = 'wandb', + ) + self.accelerator.init_trackers("climate", + init_kwargs={ + "wandb": { + "name": None, + } + }, + ) + self.config = config + self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.rollout_batch = config.rollout_batch + #self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] + + self.model = diffusion_model + + self.save_and_sample_every = save_and_sample_every + + self.batch_size = train_batch_size + self.gradient_accumulate_every = gradient_accumulate_every + + self.train_num_steps = train_num_steps + self.image_size = diffusion_model.image_size + + self.val_num_of_batch = val_num_of_batch + + # optimizer + + self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) + + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 5e-7) + + # for logging results in a folder periodically + + if self.accelerator.is_main_process: + self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) + + self.results_folder = Path(results_folder) + + self.results_folder.mkdir(exist_ok=True, parents=True) + + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) + + # step counter state + + self.step = 0 + + # prepare model, dataloader, optimizer with accelerator + + self.model, self.opt, self.sched, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, self.sched, train_dl, val_dl) + self.train_dl = cycle(train_dl) + self.val_dl = val_dl + + def save(self, milestone): + if not self.accelerator.is_local_main_process: + return + + data = { 
+ 'step': self.step, + 'model': self.accelerator.get_state_dict(self.model), + 'opt': self.opt.state_dict(), + 'ema': self.ema.state_dict(), + 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, + #'version': __version__ + } + + torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) + + def load(self, milestone): + accelerator = self.accelerator + device = accelerator.device + + data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) + + model = self.accelerator.unwrap_model(self.model) + model.load_state_dict(data['model']) + + self.step = data['step'] + #self.opt.load_state_dict(data['opt']) + self.ema.load_state_dict(data['ema']) + print('loaded') + + #if 'version' in data: + # print(f"loading from version {data['version']}") + + if exists(self.accelerator.scaler) and exists(data['scaler']): + self.accelerator.scaler.load_state_dict(data['scaler']) + + def train(self): + + accelerator = self.accelerator + device = accelerator.device + + cmap = mpl.colormaps['RdBu_r'] + fcmap = mpl.colormaps['gray_r'] + + # c384_lgmin = np.load('data/only_precip/c384_lgmin.npy') + # c384_lgmax = np.load('data/only_precip/c384_lgmax.npy') + # c384_gmin = np.load('data/only_precip/c384_gmin.npy') + + # c48_lgmin = np.load('data/only_precip/c48_lgmin.npy') + # c48_lgmax = np.load('data/only_precip/c48_lgmax.npy') + # c48_gmin = np.load('data/only_precip/c48_gmin.npy') + + # c384_min = np.load('data/only_precip/c384_min.npy') + # c384_max = np.load('data/only_precip/c384_max.npy') + + # c48_min = np.load('data/only_precip/c48_min.npy') + # c48_max = np.load('data/only_precip/c48_max.npy') + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) + + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min + + with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: + + while self.step < self.train_num_steps: + + total_loss = 0. 
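+
+                # Editor's note (illustrative only): the loop below performs manual gradient
+                # accumulation -- each micro-batch loss is scaled by 1/gradient_accumulate_every
+                # before accelerator.backward(), and the optimizer steps once per outer iteration,
+                # so a single optimizer step effectively sees about
+                # train_batch_size * gradient_accumulate_every samples:
+                #
+                #   for _ in range(accum_steps):
+                #       loss, _ = model(lres, hres)
+                #       accelerator.backward(loss / accum_steps)  # gradients accumulate in .grad
+                #   opt.step(); opt.zero_grad()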
+ + for _ in range(self.gradient_accumulate_every): + + data = next(self.train_dl) + lres = data['LR'].to(device) + hres = data['HR'].to(device) + + if self.minipatch: + + x_st = randint(0, 36) + y_st = randint(0, 36) + lres = crop(lres, x_st, y_st, 12, 12) + hres = crop(hres, 8 * x_st, 8 * y_st, 96, 96) + + with self.accelerator.autocast(): + + loss, _ = self.model(lres, hres) + loss = loss / self.gradient_accumulate_every + total_loss += loss.item() + + self.accelerator.backward(loss) + + accelerator.clip_grad_norm_(self.model.parameters(), 1.0) + pbar.set_description(f'loss: {total_loss:.4f}') + + accelerator.log({"loss": total_loss}, step = self.step) + + accelerator.wait_for_everyone() + + self.opt.step() + self.opt.zero_grad() + self.sched.step() + + accelerator.wait_for_everyone() + + self.step += 1 + if accelerator.is_main_process: + self.ema.to(device) + self.ema.update() + + if self.step != 0 and self.step % self.save_and_sample_every == 0: + self.ema.ema_model.eval() + + with torch.no_grad(): + + vlosses = [] + vids = [] + hr = [] + lr = [] + bases, ress, flowss = [], [], [] + num_frames = 5 + img_size = 384 + + for i, batch in enumerate(self.val_dl): + + lres = batch['LR'].to(device) + hres = batch['HR'].to(device) + + if i >= self.val_num_of_batch: + break + + loss, videos = self.model(lres, hres) + + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + + videos = torch.cat(vids, dim = 0) + vloss = torch.stack(vlosses, dim = 0).mean() + #self.sched.step(vloss) + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + del vids, vlosses, hr, lr + + + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + ssim_index = piq.ssim(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + gmsd_index = piq.gmsd(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + rmse = np.sqrt(np.mean((output1 - target1)**2)) + pscore = np.abs(np.percentile(output1, 99.999) - np.percentile(target1, 
99.999)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + #histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + #histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + count_o, bin_o = np.histogram(output1, bins=bins1, density=True) + count_t, bin_t = np.histogram(target1, bins=bins1, density=True) + histo = count_o.ravel().astype('float32') + histt = count_t.ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + #output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + bin_o[:-1], bins=bin_o, weights = count_o, alpha=0.5, label="Output", histtype="step"#, density=True + ) + ax1.hist( + #target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + bin_t[:-1], bins=bin_t, weights = count_t, alpha=0.5, label="Target", histtype="step"#, density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + if self.logscale: + + accelerator.log({"true_high": wandb.Video((hres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"ssim": ssim_index.mean()}, step=self.step) + accelerator.log({"gmsd": gmsd_index.mean()}, step=self.step) + accelerator.log({"rmse": rmse}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, step=self.step) + 
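# --- Minimal, self-contained sketch of the distributional metrics logged above ---
# The validation loop compares the flattened output/target histograms with OpenCV's
# chi-square, intersection and KL measures, and the raw samples with the 1-D
# Wasserstein (earth mover's) distance. `pred` and `truth` here are hypothetical
# stand-in arrays, not the trainer's tensors.
import numpy as np
import cv2
from scipy.stats import wasserstein_distance

def histogram_distances(pred, truth, n_bins=100):
    # shared bin edges so the two densities are directly comparable
    lo = min(pred.min(), truth.min())
    hi = max(pred.max(), truth.max())
    bins = np.linspace(lo, hi, n_bins + 1)
    hist_p, _ = np.histogram(pred, bins=bins, density=True)
    hist_t, _ = np.histogram(truth, bins=bins, density=True)
    hist_p = hist_p.astype('float32')   # compareHist expects float32 histograms
    hist_t = hist_t.astype('float32')
    return {
        'chisqr': cv2.compareHist(hist_p, hist_t, cv2.HISTCMP_CHISQR),
        'inter':  cv2.compareHist(hist_p, hist_t, cv2.HISTCMP_INTERSECT),
        'kl':     cv2.compareHist(hist_p, hist_t, cv2.HISTCMP_KL_DIV),
        'emd':    wasserstein_distance(pred, truth),  # computed on raw samples
    }

# usage on random stand-in data
rng = np.random.default_rng(0)
print(histogram_distances(rng.gamma(2.0, size=10_000), rng.gamma(2.2, size=10_000)))
# --- end sketch ---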
accelerator.log({"distemd": distemd}, step=self.step) + accelerator.log({"vloss": vloss}, step=self.step) + accelerator.log({"lr": self.opt.param_groups[0]['lr']}, step=self.step) + + milestone = self.step // self.save_and_sample_every + + self.save(milestone) + + pbar.update(1) + + accelerator.print('training complete') + + def sample(self): + + accelerator = self.accelerator + device = accelerator.device + + self.ema.ema_model.eval() + + PATH = "/extra/ucibdl0/shared/data/fv3gfs" + XX = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/0011/sfc_8xdaily_ave_coarse.zarr") + XX_ = xr.open_zarr(f"{PATH}/c48_atmos_ave/0011/atmos_8xdaily_ave_coarse.zarr") + yy = xr.open_zarr(f"{PATH}/c384_precip_ave/0011/sfc_8xdaily_ave.zarr") + topot = xr.open_zarr(f"{PATH}/c384_topo/0011/atmos_static.zarr") + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + c48_atm_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + c384_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + c384_topo = pickle.load(f) + + if self.multi: + + c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + c384_channels = ["PRATEsfc"] + + else: + + c48_channels = ["PRATEsfc_coarse"] + c384_channels = ["PRATEsfc"] + + with torch.no_grad(): + + for tile in range(6): + + if self.rollout == 'full': + + seq_len = self.rollout_batch + st = 0 + en = seq_len + 2 + count = 0 + + while en < 3176: + + print(tile, st) + + X = XX.isel(time = slice(st, en), tile = tile) + X_ = XX_.isel(time = slice(st, en), tile = tile) + y = yy.isel(time = slice(st, en), tile = tile) + + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = 
torch.from_numpy(y).unsqueeze(0).to(device) + + loss, videos = self.model(lres, hres) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + count += 1 + + st += seq_len + en += seq_len + + if self.rollout == 'partial': + + seq_len = self.rollout_batch + #indices = get_random_idx_with_difference(0, 3176 - (seq_len + 2), 75 // seq_len, seq_len + 2) # 75 samples per tile + indices = list(range(0, 3176 - (seq_len + 2), 250)) # deterministic, 325 samples per tile for seq_len of 25 + + for count, st in enumerate(indices): + + print(tile, count) + + X = XX.isel(time = slice(st, st+(seq_len+2)), tile = tile) + X_ = XX_.isel(time = slice(st, st+(seq_len+2)), tile = tile) + y = yy.isel(time = slice(st, st+(seq_len+2)), tile = tile) + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + + loss, videos = self.model(lres, hres) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) \ No newline at end of file diff --git a/projects/super_res/model/isr_baseline.py b/projects/super_res/model/isr_baseline.py new file mode 100644 index 0000000000..ed59f3c013 --- /dev/null +++ b/projects/super_res/model/isr_baseline.py @@ -0,0 +1,568 @@ +from pathlib import Path +import os + +import numpy as np +import xarray as xr + +import torch +import wandb + +import piq +import pickle +import cv2 +from scipy.stats import wasserstein_distance + +from torch.optim import Adam +import torch.nn.functional as F + +from random import randint +from torch.optim.lr_scheduler import CosineAnnealingLR + +from tqdm.auto import tqdm +from ema_pytorch import EMA +import matplotlib as mpl +import matplotlib.pyplot as plt +from matplotlib.cm import ScalarMappable as smap + +from 
accelerate import Accelerator +from einops import rearrange, reduce + +def get_random_idx_with_difference(min_tx, max_tx, number_tx, diff): + times = [] + while len(times) < number_tx: + new_time = randint(min_tx, max_tx) + if all(abs(new_time - time) >= diff for time in times): + times.append(new_time) + return times + +def cycle(dl): + while True: + for data in dl: + yield data + +def exists(x): + return x is not None + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + +# trainer class + +class Trainer(object): + def __init__( + self, + model, + train_dl, + val_dl, + config, + *, + train_batch_size = 16, + gradient_accumulate_every = 1, + #augment_horizontal_flip = True, + train_lr = 1e-4, + train_num_steps = 100000, + ema_update_every = 1, + ema_decay = 0.995, + adam_betas = (0.9, 0.99), + save_and_sample_every = 10, + #num_samples = 25, + eval_folder = './evaluate', + results_folder = './results', + #tensorboard_dir = './tensorboard', + val_num_of_batch = 2, + amp = False, + fp16 = False, + #fp16 = True, + split_batches = True, + #split_batches = False, + convert_image_to = None + ): + super().__init__() + + self.accelerator = Accelerator( + split_batches = split_batches, + mixed_precision = 'fp16' if fp16 else 'no', + log_with = 'wandb', + ) + self.accelerator.init_trackers("climate", + init_kwargs={ + "wandb": { + "name": None, + } + }, + ) + self.config = config + self.accelerator.native_amp = amp + self.multi = config.data_config["multi"] + self.rollout = config.rollout + self.rollout_batch = config.rollout_batch + self.flow = config.data_config["flow"] + self.minipatch = config.data_config["minipatch"] + self.logscale = config.data_config["logscale"] + + self.model = model + + self.save_and_sample_every = save_and_sample_every + + self.batch_size = train_batch_size + self.gradient_accumulate_every = gradient_accumulate_every + + self.train_num_steps = train_num_steps + + self.val_num_of_batch = val_num_of_batch + + # optimizer + + self.opt = Adam(model.parameters(), lr = train_lr, betas = adam_betas) + self.sched = CosineAnnealingLR(self.opt, train_num_steps, 5e-7) + + # for logging results in a folder periodically + + if self.accelerator.is_main_process: + self.ema = EMA(model, beta = ema_decay, update_every = ema_update_every) + + self.results_folder = Path(results_folder) + + self.results_folder.mkdir(exist_ok=True, parents=True) + + self.eval_folder = Path(eval_folder) + + self.eval_folder.mkdir(exist_ok=True, parents=True) + + # step counter state + + self.step = 0 + + # prepare model, dataloader, optimizer with accelerator + + self.model, self.opt, train_dl, val_dl = self.accelerator.prepare(self.model, self.opt, train_dl, val_dl) + self.train_dl = cycle(train_dl) + self.val_dl = val_dl + + def save(self, milestone): + if not self.accelerator.is_local_main_process: + return + + data = { + 'step': self.step, + 'model': self.accelerator.get_state_dict(self.model), + 'opt': self.opt.state_dict(), + 'ema': self.ema.state_dict(), + 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, + #'version': __version__ + } + + torch.save(data, str(self.results_folder / f'qmodel-{milestone%3}.pt')) + + def load(self, milestone): + accelerator = self.accelerator + device = accelerator.device + + data = torch.load(str(self.results_folder / f'qmodel-{milestone}.pt'), map_location=device) + + model = self.accelerator.unwrap_model(self.model) + model.load_state_dict(data['model']) + + 
self.step = data['step'] + #self.opt.load_state_dict(data['opt']) + self.ema.load_state_dict(data['ema']) + + #if 'version' in data: + # print(f"loading from version {data['version']}") + + if exists(self.accelerator.scaler) and exists(data['scaler']): + self.accelerator.scaler.load_state_dict(data['scaler']) + + def train(self): + + accelerator = self.accelerator + device = accelerator.device + + cmap = mpl.colormaps['RdBu_r'] + fcmap = mpl.colormaps['gray_r'] + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + c384_log_chl = pickle.load(f) + + c384_lgmin = c384_log_chl["PRATEsfc"]['min'] + c384_lgmax = c384_log_chl["PRATEsfc"]['max'] + c48_lgmin = c48_log_chl["PRATEsfc_coarse"]['min'] + c48_lgmax = c48_log_chl["PRATEsfc_coarse"]['max'] + + c384_min = c384_chl["PRATEsfc"]['min'] + c384_max = c384_chl["PRATEsfc"]['max'] + c48_min = c48_chl["PRATEsfc_coarse"]['min'] + c48_max = c48_chl["PRATEsfc_coarse"]['max'] + + c384_gmin = c384_min + c48_gmin = c48_min + + with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: + + while self.step < self.train_num_steps: + + total_loss = 0. + + for _ in range(self.gradient_accumulate_every): + + data = next(self.train_dl) + lres = data['LR'].to(device) + hres = data['HR'].to(device) + + with self.accelerator.autocast(): + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = 7) + + ures = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + loss = F.mse_loss(ures, rearrange(hres, 'b t c h w -> (b t) c h w'), reduction = 'none') + loss = reduce(loss, 'b ... 
-> b (...)', 'mean') + loss = loss.mean() + + loss = loss / self.gradient_accumulate_every + total_loss += loss.item() + + self.accelerator.backward(loss) + + accelerator.clip_grad_norm_(self.model.parameters(), 1.0) + pbar.set_description(f'loss: {total_loss:.4f}') + + accelerator.log({"loss": total_loss}, step = self.step) + + accelerator.wait_for_everyone() + + self.opt.step() + self.opt.zero_grad() + + accelerator.wait_for_everyone() + + self.step += 1 + if accelerator.is_main_process: + self.ema.to(device) + self.ema.update() + + if self.step != 0 and self.step % self.save_and_sample_every == 0: + self.ema.ema_model.eval() + + with torch.no_grad(): + + vlosses = [] + vids = [] + hr = [] + lr = [] + num_frames = 5 + img_size = 384 + + for i, batch in enumerate(self.val_dl): + + lres = batch['LR'].to(device) + hres = batch['HR'].to(device) + + if i >= self.val_num_of_batch: + break + + topo = hres[:, :, 1:2, :, :] + hres = hres[:, :, 0:1, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = 7) + + ures = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')) + loss = F.mse_loss(ures, rearrange(hres, 'b t c h w -> (b t) c h w'), reduction = 'none') + + videos = rearrange(ures, '(b t) c h w -> b t c h w', t = 7) + + vids.append(videos) + vlosses.append(loss) + hr.append(hres) + lr.append(lres) + + videos = torch.cat(vids, dim = 0) + vloss = torch.stack(vlosses, dim = 0).mean() + #self.sched.step(vloss) + hres = torch.cat(hr, dim = 0) + lres = torch.cat(lr, dim = 0) + del vids, vlosses, hr, lr + + lres = lres[:, :, 0:1, :, :] + hres = hres[:, :, 0:1, :, :] + + if not self.logscale: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_max - c384_min) + c384_min + output = videos.detach().cpu().numpy() * (c384_max - c384_min) + c384_min + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_max - c48_min) + c48_min + + else: + target = hres[:,:,:,:,:].detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + output = videos.detach().cpu().numpy() * (c384_lgmax - c384_lgmin) + c384_lgmin + coarse = lres[:,:,:,:,:].detach().cpu().numpy() * (c48_lgmax - c48_lgmin) + c48_lgmin + + if self.logscale: + target = np.exp(target) + c384_gmin - 1e-14 + output = np.exp(output) + c384_gmin - 1e-14 + coarse = np.exp(coarse) + c48_gmin - 1e-14 + + ssim_index = piq.ssim(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + gmsd_index = piq.gmsd(torch.from_numpy(target).view(-1, 1, 384, 384), torch.from_numpy(output).view(-1, 1, 384, 384).clamp(0., 1.), data_range=1., reduction='none') + + nn_upscale = np.repeat(np.repeat(coarse, 8, axis = 3), 8, axis = 4) + diff_output = (output - nn_upscale).flatten() + diff_target = (target - nn_upscale).flatten() + vmin = min(diff_output.min(), diff_target.min()) + vmax = max(diff_output.max(), diff_target.max()) + bins = np.linspace(vmin, vmax, 100 + 1) + + fig, ax = plt.subplots(1, 1, figsize=(6, 4)) + ax.hist( + diff_output, bins=bins, alpha=0.5, label="Output", histtype="step", density=True + ) + ax.hist( + diff_target, bins=bins, alpha=0.5, label="Target", histtype="step", density=True + ) + ax.set_xlim(vmin, vmax) + ax.legend() + ax.set_ylabel("Density") + ax.set_yscale("log") + + output1 = output.flatten() + target1 = target.flatten() + rmse = np.sqrt(np.mean((output1 - target1)**2)) + pscore = np.abs(np.percentile(output1, 
99.999) - np.percentile(target1, 99.999)) + vmin1 = min(output1.min(), target1.min()) + vmax1 = max(output1.max(), target1.max()) + bins1 = np.linspace(vmin1, vmax1, 100 + 1) + #histo = np.histogram(output1, bins=bins1, density=True)[0].ravel().astype('float32') + #histt = np.histogram(target1, bins=bins1, density=True)[0].ravel().astype('float32') + count_o, bin_o = np.histogram(output1, bins=bins1, density=True) + count_t, bin_t = np.histogram(target1, bins=bins1, density=True) + histo = count_o.ravel().astype('float32') + histt = count_t.ravel().astype('float32') + distchisqr = cv2.compareHist(histo, histt, cv2.HISTCMP_CHISQR) + distinter = cv2.compareHist(histo, histt, cv2.HISTCMP_INTERSECT) + distkl = cv2.compareHist(histo, histt, cv2.HISTCMP_KL_DIV) + distemd = wasserstein_distance(output1, target1) + + fig1, ax1 = plt.subplots(1, 1, figsize=(6, 4)) + ax1.hist( + #output1, bins=bins1, alpha=0.5, label="Output", histtype="step", density=True + bin_o[:-1], bins=bin_o, weights = count_o, alpha=0.5, label="Output", histtype="step"#, density=True + ) + ax1.hist( + #target1, bins=bins1, alpha=0.5, label="Target", histtype="step", density=True + bin_t[:-1], bins=bin_t, weights = count_t, alpha=0.5, label="Target", histtype="step"#, density=True + ) + ax1.set_xlim(vmin1, vmax1) + ax1.legend() + ax1.set_ylabel("Density") + ax1.set_yscale("log") + + if self.logscale: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos.clamp(0.0, 1.0)[0:1,:,0:1,:,:].repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + + else: + + accelerator.log({"true_high": wandb.Video((hres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_low": wandb.Video((lres[0:1,2:,0:1,:,:].repeat(1,1,3,1,1).cpu().numpy()*255).astype(np.uint8))}, step=self.step) + accelerator.log({"samples": wandb.Video((videos[0:1,:,:,:,:].clamp(0.0, 1.0).repeat(1,1,3,1,1).detach().cpu().numpy()*255).astype(np.uint8))}, step=self.step) + target = np.log(target - c384_gmin + 1e-14) + output = np.log(output - c384_gmin + 1e-14) + coarse = np.log(coarse - c48_gmin + 1e-14) + target = (target - c384_lgmin) / (c384_lgmax - c384_lgmin) + output = (output - c384_lgmin) / (c384_lgmax - c384_lgmin) + coarse = (coarse - c48_lgmin) / (c48_lgmax - c48_lgmin) + accelerator.log({"true_loghigh": wandb.Video((np.repeat(target[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"true_loglow": wandb.Video((np.repeat(coarse[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + accelerator.log({"logsamples": wandb.Video((np.repeat(output[0:1,:,:,:,:], 3, axis=-3)*255).astype(np.uint8))}, step=self.step) + + accelerator.log({"difference_histogram": wandb.Image(fig, mode = 'RGB')}, step=self.step) + accelerator.log({"histogram": wandb.Image(fig1, mode = 'RGB')}, step=self.step) + accelerator.log({"ssim": ssim_index.mean()}, step=self.step) + accelerator.log({"gmsd": gmsd_index.mean()}, step=self.step) + accelerator.log({"rmse": rmse}, step=self.step) + accelerator.log({"pscore": pscore}, step=self.step) + accelerator.log({"distchisqr": distchisqr}, step=self.step) + accelerator.log({"distinter": distinter}, step=self.step) + accelerator.log({"distkl": distkl}, 
step=self.step) + accelerator.log({"distemd": distemd}, step=self.step) + accelerator.log({"vloss": vloss}, step=self.step) + accelerator.log({"lr": self.opt.param_groups[0]['lr']}, step=self.step) + + milestone = self.step // self.save_and_sample_every + + self.save(milestone) + + pbar.update(1) + + accelerator.print('training complete') + + def sample(self): + + accelerator = self.accelerator + device = accelerator.device + + self.ema.ema_model.eval() + + PATH = "/extra/ucibdl0/shared/data/fv3gfs" + XX = xr.open_zarr(f"{PATH}/c48_precip_plus_more_ave/0011/sfc_8xdaily_ave_coarse.zarr") + XX_ = xr.open_zarr(f"{PATH}/c48_atmos_ave/0011/atmos_8xdaily_ave_coarse.zarr") + yy = xr.open_zarr(f"{PATH}/c384_precip_ave/0011/sfc_8xdaily_ave.zarr") + topot = xr.open_zarr(f"{PATH}/c384_topo/0011/atmos_static.zarr") + + with open("data/ensemble_c48_trainstats/chl.pkl", 'rb') as f: + + c48_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/atm_chl.pkl", 'rb') as f: + + c48_atm_chl = pickle.load(f) + + with open("data/ensemble_c48_trainstats/log_chl.pkl", 'rb') as f: + + c48_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/chl.pkl", 'rb') as f: + + c384_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/log_chl.pkl", 'rb') as f: + + c384_log_chl = pickle.load(f) + + with open("data/ensemble_c384_trainstats/topo.pkl", 'rb') as f: + + c384_topo = pickle.load(f) + + if self.multi: + + c48_channels = ["PRATEsfc_coarse", "UGRD10m_coarse", "VGRD10m_coarse", "TMPsfc_coarse", "CPRATsfc_coarse", "DSWRFtoa_coarse"] + c48_channels_atmos = ["ps_coarse", "u700_coarse", "v700_coarse", "vertically_integrated_liq_wat_coarse", "vertically_integrated_sphum_coarse"] + c384_channels = ["PRATEsfc"] + + else: + + c48_channels = ["PRATEsfc_coarse"] + c384_channels = ["PRATEsfc"] + + with torch.no_grad(): + + for tile in range(6): + + if self.rollout == 'full': + + seq_len = self.rollout_batch + st = 0 + en = seq_len + 2 + count = 0 + + while en < 3176: + + print(tile, st) + + X = XX.isel(time = slice(st, en), tile = tile) + X_ = XX_.isel(time = slice(st, en), tile = tile) + y = yy.isel(time = slice(st, en), tile = tile) + + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = 
torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + topo = hres[:, :, 1:2, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = seq_len + 2) + + videos = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')).unsqueeze(0) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) + count += 1 + + st += seq_len + en += seq_len + + if self.rollout == 'partial': + + seq_len = self.rollout_batch + #indices = get_random_idx_with_difference(0, 3176 - (seq_len + 2), 75 // seq_len, seq_len + 2) # 75 samples per tile + indices = list(range(0, 3176 - (seq_len + 2), 250)) # deterministic, 325 samples per tile for seq_len of 25 + + for count, st in enumerate(indices): + + print(tile, count) + + X = XX.isel(time = slice(st, st+(seq_len+2)), tile = tile) + X_ = XX_.isel(time = slice(st, st+(seq_len+2)), tile = tile) + y = yy.isel(time = slice(st, st+(seq_len+2)), tile = tile) + + X = np.stack([X[channel].values for channel in c48_channels], axis = 1) + X_ = np.stack([X_[channel].values for channel in c48_channels_atmos], axis = 1) + y = np.stack([y[channel].values for channel in c384_channels], axis = 1) + topo = topot.isel(tile = tile) + topo = topo['zsurf'].values + topo = np.repeat(topo.reshape((1,1,384,384)), seq_len + 2, axis = 0) + + X[:,0:1,:,:] = np.log(X[:,0:1,:,:] - c48_chl["PRATEsfc_coarse"]['min'] + 1e-14) + y = np.log(y - c384_chl["PRATEsfc"]['min'] + 1e-14) + X[:,0:1,:,:] = (X[:,0:1,:,:] - c48_log_chl["PRATEsfc_coarse"]['min']) / (c48_log_chl["PRATEsfc_coarse"]['max'] - c48_log_chl["PRATEsfc_coarse"]['min']) + y = (y - c384_log_chl["PRATEsfc"]['min']) / (c384_log_chl["PRATEsfc"]['max'] - c384_log_chl["PRATEsfc"]['min']) + + for i in range(1, X.shape[1]): + + X[:,i,:,:] = (X[:,i,:,:] - c48_chl[c48_channels[i]]['min']) / (c48_chl[c48_channels[i]]['max'] - c48_chl[c48_channels[i]]['min']) + + for i in range(X_.shape[1]): + + X_[:,i,:,:] = (X_[:,i,:,:] - c48_atm_chl[c48_channels_atmos[i]]['min']) / (c48_atm_chl[c48_channels_atmos[i]]['max'] - c48_atm_chl[c48_channels_atmos[i]]['min']) + + topo = (topo - c384_topo["zsurf"]['min']) / (c384_topo["zsurf"]['max'] - c384_topo["zsurf"]['min']) + + X = np.concatenate((X, X_), axis = 1) + y = np.concatenate((y, topo), axis = 1) + + lres = torch.from_numpy(X).unsqueeze(0).to(device) + hres = torch.from_numpy(y).unsqueeze(0).to(device) + topo = hres[:, :, 1:2, :, :] + topo_low = rearrange(F.interpolate(rearrange(topo, 'b t c h w -> (b t) c h w'), size=(48, 48), mode='bilinear'), '(b t) c h w -> b t c h w', t = seq_len + 2) + + videos = self.model(rearrange(torch.cat((lres, topo_low), dim = 2), 'b t c h w -> (b t) c h w')).unsqueeze(0) + + torch.save(videos, os.path.join(self.eval_folder) + "/gen_{}_{}.pt".format(tile, count)) + torch.save(hres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_hr_{}_{}.pt".format(tile, count)) + torch.save(lres[:,:,0:1,:,:], os.path.join(self.eval_folder) + "/truth_lr_{}_{}.pt".format(tile, count)) \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn.py b/projects/super_res/model/op/deform_attn.py new file mode 100644 index 0000000000..55da954230 --- /dev/null +++ 
b/projects/super_res/model/op/deform_attn.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +import os +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from einops.layers.torch import Rearrange +from distutils.version import LooseVersion +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +deform_attn_ext = load( + 'deform_attn', + sources=[ + os.path.join(module_path, 'deform_attn_ext.cpp'), + os.path.join(module_path, 'deform_attn_cuda_pt110.cpp' if LooseVersion(torch.__version__) >= LooseVersion( + '1.10.0') else 'deform_attn_cuda_pt109.cpp'), + os.path.join(module_path, 'deform_attn_cuda_kernel.cu'), +], +) + + +class Mlp(nn.Module): + """ Multilayer perceptron. + + Args: + x: (B, D, H, W, C) + + Returns: + x: (B, D, H, W, C) + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class DeformAttnFunction(Function): + + @staticmethod + def forward(ctx, + q, + kv, + offset, + kernel_h, + kernel_w, + stride=1, + padding=0, + dilation=1, + attention_heads=1, + deformable_groups=1, + clip_size=1): + ctx.kernel_h = kernel_h + ctx.kernel_w = kernel_w + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.attention_heads = attention_heads + ctx.deformable_groups = deformable_groups + ctx.clip_size = clip_size + if q.requires_grad or kv.requires_grad or offset.requires_grad: + ctx.save_for_backward(q, kv, offset) + output = q.new_empty(q.shape) + ctx._bufs = [q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0), q.new_empty(0)] + deform_attn_ext.deform_attn_forward(q, kv, offset, output, + ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx.kernel_h, ctx.kernel_w, ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.attention_heads, ctx.deformable_groups, ctx.clip_size) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + q, kv, offset = ctx.saved_tensors + grad_q = torch.zeros_like(q) + grad_kv = torch.zeros_like(kv) + grad_offset = torch.zeros_like(offset) + deform_attn_ext.deform_attn_backward(q, kv, offset, ctx._bufs[0], ctx._bufs[1], ctx._bufs[2], ctx._bufs[3], ctx._bufs[4], + grad_q, grad_kv, grad_offset, + grad_output, ctx.kernel_h, ctx.kernel_w, ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.attention_heads, ctx.deformable_groups, ctx.clip_size) + + return (grad_q, grad_kv, grad_offset, None, None, None, None, None, None, None, None) + + +deform_attn = DeformAttnFunction.apply + + +class DeformAttn(nn.Module): + + def __init__(self, + in_channels, + out_channels, + attention_window=[3, 3], + deformable_groups=12, + attention_heads=12, + clip_size=1): + super(DeformAttn, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_h = attention_window[0] + 
self.kernel_w = attention_window[1] + self.attn_size = self.kernel_h * self.kernel_w + self.deformable_groups = deformable_groups + self.attention_heads = attention_heads + self.clip_size = clip_size + self.stride = 1 + self.padding = self.kernel_h//2 + self.dilation = 1 + + self.proj_q = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_k = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.proj_v = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + nn.Linear(self.in_channels, self.in_channels), + Rearrange('n d h w c -> n d c h w')) + self.mlp = nn.Sequential(Rearrange('n d c h w -> n d h w c'), + Mlp(self.in_channels, self.in_channels * 2), + Rearrange('n d h w c -> n d c h w')) + + def forward(self, q, k, v, offset): + q = self.proj_q(q) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size) + v = v + self.mlp(v) + return v + + +class DeformAttnPack(DeformAttn): + """A Deformable Attention Encapsulation that acts as normal attention layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + attention_window (int or tuple[int]): Attention window size. Default: [3, 3]. + attention_heads (int): Attention head number. Default: 12. + deformable_groups (int): Deformable offset groups. Default: 12. + clip_size (int): clip size. Default: 2. + """ + + def __init__(self, *args, **kwargs): + super(DeformAttnPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels * (1 + self.clip_size), + self.clip_size * self.deformable_groups * self.attn_size * 2, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + dilation=(1, 1), + bias=True) + self.init_weight() + + def init_weight(self): + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, q, k, v): + out = self.conv_offset(torch.cat([q.flatten(1, 2), k.flatten(1, 2)], 1)) + o1, o2 = torch.chunk(out, 2, dim=1) + offset = torch.cat((o1, o2), dim=1) + + q = self.proj_q(q) + kv = torch.cat([self.proj_k(k), self.proj_v(v)], 2) + v = deform_attn(q, kv, offset, self.kernel_h, self.kernel_w, self.stride, self.padding, self.dilation, + self.attention_heads, self.deformable_groups, self.clip_size) + v = v + self.mlp(v) + return v \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn_cuda_kernel.cu b/projects/super_res/model/op/deform_attn_cuda_kernel.cu new file mode 100644 index 0000000000..6f1ccc2c91 --- /dev/null +++ b/projects/super_res/model/op/deform_attn_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. 
If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + 
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, 
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t 
weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % 
kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= 
height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col 
* stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int 
data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + 
h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + 
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn_cuda_pt109.cpp b/projects/super_res/model/op/deform_attn_cuda_pt109.cpp new file mode 100644 index 0000000000..46ef081a8f --- /dev/null +++ b/projects/super_res/model/op/deform_attn_cuda_pt109.cpp @@ -0,0 +1,219 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int 
pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deform_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deform_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ){ + TORCH_CHECK(kv.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + output = output.view({batch, attn_head, attn_dim, height, width}).zero_(); + + // resize temporary columns and attns + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, 
q.options()); // batch x clip_size x (deform_group*attn_size) x (heightxwidth) + + for (int b = 0; b < batch; b++) { // 0->2,1->2, or, 1->3,0->3 // todo: refer to deformable_im2col_cuda and use `im2col_step` to speed up + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + // do attention + output[b] = at::matmul(attns, columns[1].transpose(2, 3)) // (attn_head x (height*width) x 1 x attn_dim) + .transpose(1, 3).view({attn_head, attn_dim, height, width}); // (attn_head x attn_dim x height x width) + + // resize columns back for next batch + columns = columns.view({2, attn_head, area, attn_dim, clip_size , attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 3); // clip_size x (attn_head*attn_dim*attn_size) x (height*width) + } + + output = output.view({batch, channels, height, width}); +} + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ){ + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); +// // for PyTorch 1.10.1 +// const at::ScalarType dtype = kv.scalar_type(); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + grad_q = grad_q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + grad_offset = grad_offset.view({batch, clip_size, grad_offset.size(1) / clip_size, area}); + grad_output = grad_output.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + + // resize temporary columns, attns and grad_attns (we further need grad_attns in backward propagation because attn@V are interdependent. 
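+    // Backward outline (the forward pass computed: output = softmax(q @ columns[0]) @ columns[1]^T):
+    //   grad_attns       = grad_output @ columns[1]              (w.r.t. attention weights)
+    //   grad(columns[1]) = grad_output^T @ attns                 (w.r.t. sampled values)
+    //   grad_attns       = softmax_backward(grad_attns, attns)
+    //   grad_q           = grad_attns @ columns[0]^T * attn_scale
+    //   grad(columns[0]) = q^T @ grad_attns * attn_scale         (w.r.t. sampled keys)
+    // columns and attns are recalculated from the saved inputs below, so the forward
+    // pass does not need to cache them.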
+ columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // (deform_group*attn_size) x (heightxwidth) + grad_attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + grad_mask_ones = at::zeros({deform_group * attn_size, area}, q.options()); // not returned + + + for (int b = 0; b < batch; b++) { + // recalculate columns and attns + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. attns, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + grad_attns = at::matmul(grad_output[b], columns[1]); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. sampled_v, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[1] = at::matmul(grad_output[b].transpose(2, 3), attns); // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + // gradient w.r.t. attns_before_softmax +// for PyTorch 1.9.1 + grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, grad_attns); // todo: it seems pt191 has different interface as pt110 +// // for PyTorch 1.10.1 +// grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, dtype); + + // gradient w.r.t. q, (attn_head x (height*width) x 1 x (clip_size*attn_size)) @ (attn_head x (height*width) x (clip_size*attn_size) x attn_dim) + grad_q[b] = at::matmul(grad_attns, columns[0].transpose(2, 3)) * attn_scale; // (attn_head x (height*width) x 1 x attn_dim) + + // gradient w.r.t. sampled_k, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[0] = at::matmul(q[b].transpose(2, 3), grad_attns) * attn_scale; // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + columns = columns.view({2, attn_head, area, attn_dim, clip_size, attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x 2 x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 4); // clip_size x (2*attn_head*attn_dim*attn_size) x (height*width) + + for (int n = 0; n < clip_size; n++) { + // gradient w.r.t. input coordinate data (grad_offset and grad_mask_ones) + modulated_deformable_col2im_coord_cuda( + columns[n], kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, + height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deform_group, grad_offset[b][n], + grad_mask_ones); + + // gradient w.r.t. 
kv + modulated_deformable_col2im_cuda( + columns[n], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, grad_kv[b/clip_size][(n+b)%clip_size]); // the grad is accumulated + } + } + + // resize gradidents back + grad_q = grad_q.transpose(2, 4).view({batch, channels, height, width}); // batch x (attn_headxattn_dim) x height x width + grad_offset = grad_offset.flatten(1, 2); + grad_output = grad_output.permute({0, 1, 4, 3, 2}).view({batch, channels, height, width}); +} \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn_cuda_pt110.cpp b/projects/super_res/model/op/deform_attn_cuda_pt110.cpp new file mode 100644 index 0000000000..0dd7816d80 --- /dev/null +++ b/projects/super_res/model/op/deform_attn_cuda_pt110.cpp @@ -0,0 +1,219 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deform_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deform_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deform_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + 
const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deform_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ){ + TORCH_CHECK(kv.is_contiguous(), "input tensor has to be contiguous"); + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + output = output.view({batch, attn_head, attn_dim, height, width}).zero_(); + + // resize temporary columns and attns + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // batch x clip_size x (deform_group*attn_size) x (heightxwidth) + + for (int b = 0; b < batch; b++) { // 0->2,1->2, or, 1->3,0->3 // todo: refer to deformable_im2col_cuda and use `im2col_step` to speed up + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + // do attention + output[b] = at::matmul(attns, columns[1].transpose(2, 3)) // (attn_head x (height*width) x 1 x attn_dim) + .transpose(1, 3).view({attn_head, attn_dim, height, width}); // (attn_head x attn_dim x height x width) + + // resize columns back for next batch + columns = columns.view({2, attn_head, area, attn_dim, clip_size , attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 3); // clip_size x (attn_head*attn_dim*attn_size) x (height*width) + } + + output = output.view({batch, channels, height, width}); +} + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int 
kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ){ + at::DeviceGuard guard(kv.device()); + + const int batch = q.size(0); + const int kv_channels = kv.size(2); + const int channels = kv.size(2) / 2; + const int height = kv.size(3); + const int width = kv.size(4); + const int area = height * width; + + const int attn_dim = channels / attn_head; + const int attn_size = kernel_h * kernel_w; + const float attn_scale = pow(attn_dim, -0.5); + // for PyTorch 1.10.1 + const at::ScalarType dtype = kv.scalar_type(); + + // resize inputs + q = q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}) * attn_scale; // batch x attn_head x (height*width) x 1 x attn_dim + offset = offset.view({batch, clip_size, offset.size(1) / clip_size, area}); // batch x clip_size x (deform_groupxattn_sizex2) x (heightxwidht) + + grad_q = grad_q.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + grad_offset = grad_offset.view({batch, clip_size, grad_offset.size(1) / clip_size, area}); + grad_output = grad_output.view({batch, 1, attn_head, attn_dim, area}).permute({0, 2, 4, 1, 3}); + + // resize temporary columns, attns and grad_attns (we further need grad_attns in backward propagation because attn@V are interdependent. + columns = at::zeros({clip_size, kv_channels * attn_size, area}, q.options()); + attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + mask_ones = at::ones({deform_group * attn_size, area}, q.options()); // (deform_group*attn_size) x (heightxwidth) + grad_attns = at::zeros({attn_head, area, 1, clip_size * attn_size}, q.options()); + grad_mask_ones = at::zeros({deform_group * attn_size, area}, q.options()); // not returned + + + for (int b = 0; b < batch; b++) { + // recalculate columns and attns + // grid_sample q and k according to offset + for (int n = 0; n < clip_size; n++) { + modulated_deformable_im2col_cuda( + kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, columns[n]); + } + + columns = columns.view({clip_size, 2, attn_head, attn_dim, attn_size, area}) + .permute({1, 2, 5, 3, 0, 4}).flatten(4); // kv x attn_head x (height*width) x attn_dim x (clip_size*attn_size) + + // calculate attention, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + attns = at::matmul(q[b], columns[0]) + .softmax(-1); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. attns, (attn_head x (height*width) x 1 x attn_dim) @ (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + grad_attns = at::matmul(grad_output[b], columns[1]); // (attn_head x (height*width) x 1 x (clip_size*attn_size)) + + // gradient w.r.t. sampled_v, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[1] = at::matmul(grad_output[b].transpose(2, 3), attns); // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + // gradient w.r.t. attns_before_softmax +// // for PyTorch 1.9.1 +// grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, grad_attns); // todo: it seems pt191 has different interface as pt110 + // for PyTorch 1.10.1 + grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, dtype); + + // gradient w.r.t. 
q, (attn_head x (height*width) x 1 x (clip_size*attn_size)) @ (attn_head x (height*width) x (clip_size*attn_size) x attn_dim) + grad_q[b] = at::matmul(grad_attns, columns[0].transpose(2, 3)) * attn_scale; // (attn_head x (height*width) x 1 x attn_dim) + + // gradient w.r.t. sampled_k, (attn_head x (height*width) x attn_dim x 1) @ (attn_head x (height*width) x 1 x (clip_size*attn_size)) + columns[0] = at::matmul(q[b].transpose(2, 3), grad_attns) * attn_scale; // (attn_head x (height*width) x attn_dim x (clip_size*attn_size)) + + columns = columns.view({2, attn_head, area, attn_dim, clip_size, attn_size}) + .permute({4, 0, 1, 3, 5, 2}) // clip_size x 2 x attn_head x attn_dim x attn_size x (height*width) + .flatten(1, 4); // clip_size x (2*attn_head*attn_dim*attn_size) x (height*width) + + for (int n = 0; n < clip_size; n++) { + // gradient w.r.t. input coordinate data (grad_offset and grad_mask_ones) + modulated_deformable_col2im_coord_cuda( + columns[n], kv[b/clip_size][(n+b)%clip_size], offset[b][n], mask_ones, 1, kv_channels, height, width, + height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deform_group, grad_offset[b][n], + grad_mask_ones); + + // gradient w.r.t. kv + modulated_deformable_col2im_cuda( + columns[n], offset[b][n], mask_ones, 1, kv_channels, height, width, height, + width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deform_group, grad_kv[b/clip_size][(n+b)%clip_size]); // the grad is accumulated + } + } + + // resize gradidents back + grad_q = grad_q.transpose(2, 4).view({batch, channels, height, width}); // batch x (attn_headxattn_dim) x height x width + grad_offset = grad_offset.flatten(1, 2); + grad_output = grad_output.permute({0, 1, 4, 3, 2}).view({batch, channels, height, width}); +} \ No newline at end of file diff --git a/projects/super_res/model/op/deform_attn_ext.cpp b/projects/super_res/model/op/deform_attn_ext.cpp new file mode 100644 index 0000000000..a09d85851a --- /dev/null +++ b/projects/super_res/model/op/deform_attn_ext.cpp @@ -0,0 +1,75 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA + +void deform_attn_cuda_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size + ); + +void deform_attn_cuda_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ); +#endif + +void deform_attn_forward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor attns, at::Tensor mask_ones, int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int attn_head, const int deform_group, const int clip_size 
+ ) { + if (q.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_attn_cuda_forward(q, kv, + offset, output, columns, attns, mask_ones, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, attn_head, deform_group, clip_size); +#else + AT_ERROR("modulated deform attn is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform attn is not implemented on CPU"); +} + +void deform_attn_backward( + at::Tensor q, at::Tensor kv, at::Tensor offset, at::Tensor columns, + at::Tensor attns, at::Tensor mask_ones, at::Tensor grad_attns, at::Tensor grad_mask_ones, at::Tensor grad_q, at::Tensor grad_kv, + at::Tensor grad_offset, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int attn_head, int deform_group, int clip_size + ) { + if (q.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_attn_cuda_backward(q, kv, + offset, columns, attns, mask_ones, grad_attns, grad_mask_ones, grad_q, grad_kv, grad_offset, + grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, attn_head, deform_group, clip_size); +#else + AT_ERROR("modulated deform attn is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform attn is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_attn_forward", + &deform_attn_forward, + "deform attn forward"); + m.def("deform_attn_backward", + &deform_attn_backward, + "deform attn backward"); +} \ No newline at end of file diff --git a/projects/super_res/sampler.py b/projects/super_res/sampler.py index dd5a9eeb0b..69b3185d5a 100644 --- a/projects/super_res/sampler.py +++ b/projects/super_res/sampler.py @@ -7,9 +7,12 @@ def main(): if config.data_config["multi"]: - in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise - in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) - in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + # in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + # in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + # in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + in_ch_model = 2 * config.data_config["img_channel"] + 10 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + in_ch_flow = 3 * (config.data_config["img_channel"] + 10 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + in_ch_isr = config.data_config["img_channel"] + 10 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo else: diff --git a/projects/super_res/sampler_isr.py b/projects/super_res/sampler_isr.py new file mode 100644 index 0000000000..20c2a71992 --- /dev/null +++ b/projects/super_res/sampler_isr.py @@ -0,0 +1,38 @@ +import os + +from model.isr_baseline import Trainer +from model.network_swinir import SwinIR +from config_isr_infer import config + +def main(): + model = SwinIR(upscale=8, in_chans=12, out_chans=1, 
img_size=48, window_size=8, + img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200, + num_heads=[8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv').cuda() + + trainer = Trainer( + model, + None, + None, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config + #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), + ) + + trainer.load(config.milestone) + + trainer.sample() + +if __name__ == "__main__": + print(config) + main() \ No newline at end of file diff --git a/projects/super_res/sampler_rvrt_full.py b/projects/super_res/sampler_rvrt_full.py new file mode 100644 index 0000000000..f5ec3f2a4c --- /dev/null +++ b/projects/super_res/sampler_rvrt_full.py @@ -0,0 +1,103 @@ +import os +from torch import nn +from model.denoising_diffusion_rvrt_full import RSTBWithInputConv, Upsample, GuidedDeformAttnPack, GaussianDiffusion, SpyNet, Trainer +from config_rvrt_full_infer import config + +recon = RSTBWithInputConv( + in_channels = 5 * config.dim, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 1, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (1,8,8) +).cuda() + +feat_ext = RSTBWithInputConv( + in_channels = config.data_config["img_channel"]+11, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 1, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (1,8,8) +).cuda() + +feat_up = Upsample( + scale = 8, + num_feat = config.dim, + in_channels = config.data_config["img_channel"] +).cuda() + +spynet = SpyNet('./spynet').cuda() + +backbone = nn.ModuleDict() +deform_align = nn.ModuleDict()\ + +modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2'] + +for i, module in enumerate(modules): + # deformable attention + deform_align[module] = GuidedDeformAttnPack(config.dim, + config.dim, + attention_window=[3, 3], + attention_heads=6, + deformable_groups=6, + clip_size=2, + max_residue_magnitude=10).cuda() + + # feature propagation + backbone[module] = RSTBWithInputConv( + in_channels = (2 + i) * config.dim, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 2, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (2,8,8) + ).cuda() + +diffusion = GaussianDiffusion( + feat_ext = feat_ext, + feat_up = feat_up, + backbone = backbone, + deform_align = deform_align, + recon = recon, + spynet = spynet, + image_size = config.data_config["img_size"], + timesteps = config.diffusion_steps, + sampling_timesteps = config.sampling_steps, + loss_type = config.loss, + objective = config.objective +).cuda() + +trainer = Trainer( + diffusion, + None, + None, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp 
= config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config +) + +trainer.load(config.milestone) + +trainer.sample() \ No newline at end of file diff --git a/projects/super_res/trainer_focal.py b/projects/super_res/trainer_focal.py new file mode 100755 index 0000000000..82cb803442 --- /dev/null +++ b/projects/super_res/trainer_focal.py @@ -0,0 +1,90 @@ +import os + +from model.autoreg_diffusion_mod_focal import Unet, Flow, GaussianDiffusion, Trainer +from data.load_data import load_data +from config_focal import config + +def main(): + + if config.data_config["multi"]: + + # in_ch_model = 2 * config.data_config["img_channel"] + 4 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + # in_ch_flow = 3 * (config.data_config["img_channel"] + 4 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + # in_ch_isr = config.data_config["img_channel"] + 4 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + in_ch_model = 2 * config.data_config["img_channel"] + 10 + 1 # all channels plus noise : (1 + 4 + 1) + 1 : (precip + multi + topo) + noise + in_ch_flow = 3 * (config.data_config["img_channel"] + 10 + 1) # all channels from current low res and past two high res : 3 * (1 + 4 + 1) : 3 * (precip + multi + topo) + in_ch_isr = config.data_config["img_channel"] + 10 + 1 # all channels from current low res : 1 + 4 + 1 : precip + multi + topo + + else: + + in_ch_model = 2 * config.data_config["img_channel"] + in_ch_flow = 3 * config.data_config["img_channel"] + in_ch_isr = config.data_config["img_channel"] + + if config.data_config["flow"] == "3d": + + out_ch_flow = 3 + + elif config.data_config["flow"] == "2d": + + out_ch_flow = 2 + + model = Unet( + dim = config.dim, + channels = in_ch_model, + out_dim = config.data_config["img_channel"], + dim_mults = config.dim_mults, + learned_sinusoidal_cond = config.learned_sinusoidal_cond, + random_fourier_features = config.random_fourier_features, + learned_sinusoidal_dim = config.learned_sinusoidal_dim + ).cuda() + + flow = Flow( + dim = config.dim, + channels = in_ch_flow, + out_dim = out_ch_flow, + dim_mults = config.dim_mults + ).cuda() + + diffusion = GaussianDiffusion( + model, + flow, + image_size = config.data_config["img_size"], + in_ch = in_ch_isr, + timesteps = config.diffusion_steps, + sampling_timesteps = config.sampling_steps, + loss_type = config.loss, + objective = config.objective + ).cuda() + + train_dl, val_dl = load_data( + config.data_config, + config.batch_size, + pin_memory = True, + num_workers = 4, + ) + + trainer = Trainer( + diffusion, + train_dl, + val_dl, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config + ) + + trainer.train() + + +if __name__ == "__main__": + print(config) + main() \ No newline at end of file diff --git a/projects/super_res/trainer_isr.py b/projects/super_res/trainer_isr.py new file mode 
100644 index 0000000000..18afce64bd --- /dev/null +++ b/projects/super_res/trainer_isr.py @@ -0,0 +1,45 @@ +import os + +from model.isr_baseline import Trainer +from model.network_swinir import SwinIR +from data.load_data import load_data +from config_isr import config + +def main(): + model = SwinIR(upscale=8, in_chans=12, out_chans=1, img_size=48, window_size=8, + img_range=1., depths=[6, 6, 6, 6, 6, 6, 6], embed_dim=200, + num_heads=[8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, upsampler='pixelshuffle', resi_connection='3conv').cuda() + + train_dl, val_dl = load_data( + config.data_config, + config.batch_size, + pin_memory = True, + num_workers = 4, + ) + + trainer = Trainer( + model, + train_dl, + val_dl, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config + #tensorboard_dir = os.path.join(config.tensorboard_dir, f"{config.model_name}/"), + ) + + trainer.train() + + +if __name__ == "__main__": + print(config) + main() diff --git a/projects/super_res/trainer_rvrt_full.py b/projects/super_res/trainer_rvrt_full.py new file mode 100644 index 0000000000..8f626eb9aa --- /dev/null +++ b/projects/super_res/trainer_rvrt_full.py @@ -0,0 +1,109 @@ +import os +from torch import nn +from model.denoising_diffusion_rvrt_full import RSTBWithInputConv, Upsample, GuidedDeformAttnPack, GaussianDiffusion, SpyNet, Trainer +from data.load_data import load_data +from config_rvrt_full import config + +recon = RSTBWithInputConv( + in_channels = 5 * config.dim, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 1, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (1,8,8) +).cuda() + +feat_ext = RSTBWithInputConv( + in_channels = config.data_config["img_channel"]+11, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 1, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (1,8,8) +).cuda() + +feat_up = Upsample( + scale = 8, + num_feat = config.dim, + in_channels = config.data_config["img_channel"] +).cuda() + +spynet = SpyNet('./spynet').cuda() + +backbone = nn.ModuleDict() +deform_align = nn.ModuleDict()\ + +modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2'] + +for i, module in enumerate(modules): + # deformable attention + deform_align[module] = GuidedDeformAttnPack(config.dim, + config.dim, + attention_window=[3, 3], + attention_heads=6, + deformable_groups=6, + clip_size=2, + max_residue_magnitude=10).cuda() + + # feature propagation + backbone[module] = RSTBWithInputConv( + in_channels = (2 + i) * config.dim, + kernel_size = (1, 3, 3), + stride = 1, + groups = 1, + num_blocks = 2, + dim = config.dim, + input_resolution = config.data_config["img_size"], + num_heads = 6, + depth = 2, + window_size = (2,8,8) + ).cuda() + +diffusion = GaussianDiffusion( + feat_ext = feat_ext, + feat_up = feat_up, + backbone = backbone, + deform_align = deform_align, + recon = recon, + spynet = spynet, + image_size = config.data_config["img_size"], + timesteps = config.diffusion_steps, + 
sampling_timesteps = config.sampling_steps, + loss_type = config.loss, + objective = config.objective +).cuda() + +train_dl, val_dl = load_data( + config.data_config, + config.batch_size, + pin_memory = True, + num_workers = 2, + ) + +trainer = Trainer( + diffusion, + train_dl, + val_dl, + train_batch_size = config.batch_size, + train_lr = config.lr, + train_num_steps = config.steps, + gradient_accumulate_every = config.grad_acc, + val_num_of_batch = config.val_num_of_batch, + save_and_sample_every = config.save_and_sample_every, + ema_decay = config.ema_decay, + amp = config.amp, + split_batches = config.split_batches, + eval_folder = os.path.join(config.eval_folder, f"{config.model_name}/"), + results_folder = os.path.join(config.results_folder, f"{config.model_name}/"), + config = config +) + +trainer.train() \ No newline at end of file
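For reference, here is a minimal PyTorch sketch of the attention that deform_attn_cuda_forward implements (shapes follow the comments in deform_attn_cuda_pt110.cpp). It is illustrative only and not part of the patch: the offset-guided deformable sampling is approximated by a plain unfold over a fixed kernel window, and the function name deform_attn_reference is hypothetical.

import torch
import torch.nn.functional as F

def deform_attn_reference(q, kv, attn_head, kernel=3, clip_size=2):
    """Illustrative stand-in for deform_attn_cuda_forward, with the learned offsets
    replaced by a fixed kernel x kernel neighbourhood around each pixel."""
    B, C, H, W = q.shape
    attn_dim = C // attn_head
    scale = attn_dim ** -0.5
    area = H * W

    # queries: (B, heads, H*W, 1, attn_dim), scaled as in the CUDA code
    q = q.view(B, 1, attn_head, attn_dim, area).permute(0, 2, 4, 1, 3) * scale

    # kv: (B, clip_size, 2*C, H, W) -> keys and values, each (B, clip_size, C, H, W)
    k, v = kv.chunk(2, dim=2)

    def sample(x):
        # gather a kernel x kernel window per pixel (the CUDA kernel instead samples
        # at learned offsets via modulated deformable im2col)
        x = F.unfold(x.flatten(0, 1), kernel, padding=kernel // 2)
        x = x.view(B, clip_size, attn_head, attn_dim, kernel * kernel, area)
        # -> (B, heads, H*W, attn_dim, clip_size * kernel * kernel)
        return x.permute(0, 2, 5, 3, 1, 4).flatten(4)

    k, v = sample(k), sample(v)
    attn = torch.matmul(q, k).softmax(-1)           # (B, heads, H*W, 1, clip*k*k)
    out = torch.matmul(attn, v.transpose(-1, -2))   # (B, heads, H*W, 1, attn_dim)
    return out.transpose(2, 4).reshape(B, C, H, W)

# Hypothetical usage with shapes matching the comments above:
# q = torch.randn(1, 120, 48, 48); kv = torch.randn(1, 2, 240, 48, 48)
# out = deform_attn_reference(q, kv, attn_head=6)   # (1, 120, 48, 48)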