2d flow #2301

Open
wants to merge 9 commits into base: feature/diff-code
refactor : 2d, 3d flow in one file, switches for logscale, multichannel, minipatch, partial rollout
Prakhar Srivastava committed Aug 11, 2023

commit 40989d3671bf4435933f7f0a495675bb25fff9a6
17 changes: 9 additions & 8 deletions projects/super_res/config.py
@@ -1,10 +1,9 @@
 from ml_collections import config_dict
 
-#batch_size = 4
 config = config_dict.ConfigDict()
 
-config.dim = 64
-config.dim_mults = (1, 1, 2, 2, 3, 4)
+config.dim = 128
+config.dim_mults = (1, 2, 2, 2, 4, 4)
 config.learned_sinusoidal_cond = True,
 config.random_fourier_features = True,
 config.learned_sinusoidal_dim = 32
@@ -20,23 +19,25 @@
 config.ema_decay = 0.995
 config.amp = False
 config.split_batches = True
-config.additional_note = "no-logscale"
+config.additional_note = "multichannel_minipatch"
 config.eval_folder = "./evaluate"
 config.results_folder = "./results"
 config.tensorboard_dir = "./tensorboard"
 config.milestone = 1
+config.rollout = None
+config.rollout_batch = None
 
 config.batch_size = 1
 config.data_config = config_dict.ConfigDict({
     "dataset_name": "c384",
     "length": 7,
-    #"channels": ["UGRD10m_coarse","VGRD10m_coarse"],
     "channels": ["PRATEsfc_coarse"],
-    #"img_channel": 2,
     "img_channel": 1,
     "img_size": 384,
-    "logscale": False,
-    "quick": True
+    "logscale": True,
+    "multi": True,
+    "flow": "2d",
+    "minipatch": False
 })
 
 config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}"
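
For reviewers trying the new switches: the fields added in this diff (rollout, rollout_batch, and the data_config keys multi, flow, minipatch) are ordinary ConfigDict entries, so a run can be reconfigured after import. The snippet below is an illustration only, not part of the PR, and assumes projects/super_res is the working directory; note that data_name is formatted at import time, so later overrides do not propagate into it automatically.

# Illustrative only -- toggling the switches added in this commit
# (assumes projects/super_res/ is the current working directory).
from config import config

config.data_config["flow"] = "3d"        # "2d" or "3d" flow variant, both handled in one file now
config.data_config["multi"] = False      # single-channel run without the extra coarse inputs
config.data_config["minipatch"] = True   # enable the mini-patch switch
config.additional_note = "singlechannel_minipatch"

# data_name was built at import time; rebuild it so the overrides are reflected
config.data_name = (
    f"{config.data_config['dataset_name']}-{config.data_config['channels']}-"
    f"{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}"
    f"{config.additional_note}"
)
print(config.data_name)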
17 changes: 9 additions & 8 deletions projects/super_res/config_infer.py
@@ -1,15 +1,14 @@
 from ml_collections import config_dict
 
-#batch_size = 4
 config = config_dict.ConfigDict()
 
-config.dim = 64
-config.dim_mults = (1, 1, 2, 2, 3, 4)
+config.dim = 128
+config.dim_mults = (1, 2, 2, 2, 4, 4)
 config.learned_sinusoidal_cond = True,
 config.random_fourier_features = True,
 config.learned_sinusoidal_dim = 32
 config.diffusion_steps = 1500
-config.sampling_steps = 15
+config.sampling_steps = 20
 config.loss = "l2"
 config.objective = "pred_v"
 config.lr = 8e-5
@@ -20,23 +19,25 @@
 config.ema_decay = 0.995
 config.amp = False
 config.split_batches = True
-config.additional_note = ""
+config.additional_note = "multichannel_minipatch"
 config.eval_folder = "./evaluate"
 config.results_folder = "./results"
 config.tensorboard_dir = "./tensorboard"
 config.milestone = 1
+config.rollout = "partial"
+config.rollout_batch = 25
 
 config.batch_size = 1
 config.data_config = config_dict.ConfigDict({
     "dataset_name": "c384",
     "length": 7,
-    #"channels": ["UGRD10m_coarse","VGRD10m_coarse"],
     "channels": ["PRATEsfc_coarse"],
-    #"img_channel": 2,
     "img_channel": 1,
     "img_size": 384,
     "logscale": True,
-    "quick": True
+    "multi": True,
+    "flow": "2d",
+    "minipatch": False
 })
 
 config.data_name = f"{config.data_config['dataset_name']}-{config.data_config['channels']}-{config.objective}-{config.loss}-d{config.dim}-t{config.diffusion_steps}{config.additional_note}"
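
Since the training and inference configs now share most of these switches, a quick way to see what actually changes between them is to import both and compare the handful of fields touched here. A small sketch, not part of the PR, with the same working-directory assumption as above:

# Illustrative comparison of the two configs after this change.
from config import config as train_config
from config_infer import config as infer_config

for name in ("sampling_steps", "additional_note", "rollout", "rollout_batch"):
    print(name,
          "train:", getattr(train_config, name, "n/a"),
          "infer:", getattr(infer_config, name, "n/a"))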
43 changes: 0 additions & 43 deletions projects/super_res/config_mod_flow.py

This file was deleted.

7 changes: 3 additions & 4 deletions projects/super_res/data/load_dataset.py
@@ -2,14 +2,13 @@
 
 def load_dataset(data_config):
 
-    channels = data_config["channels"]
     length = data_config["length"]
     logscale = data_config["logscale"]
-    quick = data_config["quick"]
+    multi = data_config["multi"]
 
     train, val = None, None
 
-    train = VSRDataset(channels, 'train', length, logscale, quick)
-    val = VSRDataset(channels, 'val', length, logscale, quick)
+    train = VSRDataset('train', length, logscale, multi)
+    val = VSRDataset('val', length, logscale, multi)
 
     return train, val
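
With channels and quick dropped from the loader, data_config only needs length, logscale and multi here. A hypothetical usage sketch (it assumes the normalized .npy files referenced in vsrdata.py are present under data/, and that VSRDataset.__getitem__ returns the (lowres, highres) pair it builds):

# Hypothetical usage of the refactored loader; not part of the PR.
from config import config
from data.load_dataset import load_dataset

train, val = load_dataset(config.data_config)
lowres, highres = train[0]   # assuming __getitem__ returns the pair built in vsrdata.py
print(len(train), lowres.shape, highres.shape)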
81 changes: 31 additions & 50 deletions projects/super_res/data/vsrdata.py
@@ -1,10 +1,9 @@
-import xarray as xr
 import numpy as np
 from torch.utils.data import Dataset
 
 class VSRDataset(Dataset):
 
-    def __init__(self, channels, mode, length, logscale = False, quick = True):
+    def __init__(self, mode, length, logscale = False, multi = False):
         '''
         Args:
             channels (list): list of channels to use
@@ -20,61 +19,43 @@ def __init__(self, channels, mode, length, logscale = False, quick = True):
         # mode
         self.mode = mode
 
-        if not quick:
-            # load data from bucket
-            # shape : (tile, time, y, x)
-            c384 = xr.open_zarr("gs://vcm-ml-raw-flexible-retention/2021-07-19-PIRE/C3072-to-C384-res-diagnostics/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt_coarse": "x", "grid_yt_coarse": "y"})
-            c48 = xr.open_zarr("gs://vcm-ml-intermediate/2021-10-12-PIRE-c48-post-spinup-verification/pire_atmos_phys_3h_coarse.zarr").rename({"grid_xt": "x", "grid_yt": "y"})
-
-            # convert to numpy
-            # shape : (tile, time, channel, y, x)
-            c384_np = np.stack([c384[channel].values for channel in channels], axis = 2)
-            c48_np = np.stack([c48[channel].values for channel in channels], axis = 2)
+        # data shape : (num_tiles, num_frames, num_channels, height, width)
+        # num_tiles = 6; num_frames = 2920, num_channels = 1
+        if logscale:
 
-            if logscale:
-                c384_np = np.log(c384_np - c384_np.min() + 1e-14)
-                c48_np = np.log(c48_np - c48_np.min() + 1e-14)
+            c384_norm= np.load("data/only_precip/c384_lgnorm.npy")
+            c48_norm = np.load("data/only_precip/c48_lgnorm.npy")
 
-            # calculate split (80/20)
-            split = int(c384_np.shape[1] * 0.8)
+        else:
 
-            # compute statistics on training set
-            c384_min, c384_max, c48_min, c48_max = c384_np[:, :split, :, :, :].min(), c384_np[:, :split, :, :, :].max(), c48_np[:, :split, :, :, :].min(), c48_np[:, :split, :, :, :].max()
+            c384_norm= np.load("data/only_precip/c384_norm.npy")
+            c48_norm = np.load("data/only_precip/c48_norm.npy")
 
+        t, f, c, h, w = c384_norm.shape
 
-            # normalize
-            c384_norm= (c384_np - c384_min) / (c384_max - c384_min)
-            c48_norm = (c48_np - c48_min) / (c48_max - c48_min)
+        if multi:
 
-            if mode == 'train':
-
-                self.X = c48_norm[:, :split, :, :, :]
-                self.y = c384_norm[:, :split, :, :, :]
-
-            elif mode == 'val':
-
-                self.X = c48_norm[:, split:, :, :, :]
-                self.y = c384_norm[:, split:, :, :, :]
+            # load more channels, order : ("UGRD10m_coarse", "VGRD10m_coarse", "tsfc_coarse", "CPRATEsfc_coarse")
+            c48_norm_more = np.load("data/more_channels/c48_norm.npy")
+            c48_norm = np.concatenate((c48_norm, c48_norm_more), axis = 2)
 
-        else:
-            if logscale:
-                c384_norm= np.load("data/only_precip/c384_lgnorm.npy")
-                c48_norm = np.load("data/only_precip/c48_lgnorm.npy")
-            else:
-                c384_norm= np.load("data/only_precip/c384_norm.npy")
-                c48_norm = np.load("data/only_precip/c48_norm.npy")
+            # load topography, shape : (num_tiles, height, width)
+            # reshaping to match data shape
+            topo384 = np.repeat(np.load("data/topography/topo384_norm.npy").reshape((t, 1, c, 384, 384)), f, axis = 1)
+            c384_norm = np.concatenate((c384_norm, topo384), axis = 2)
 
-            # calculate split (80/20)
-            split = int(c384_norm.shape[1] * 0.8)
+        # calculate split (80/20)
+        split = int(c384_norm.shape[1] * 0.8)
 
-            if mode == 'train':
-                self.X = c48_norm[:, :split, :, :, :]
-                self.y = c384_norm[:, :split, :, :, :]
-            elif mode == 'val':
-                self.X = c48_norm[:, split:, :, :, :]
-                self.y = c384_norm[:, split:, :, :, :]
+        if mode == 'train':
+
+            self.X = c48_norm[:, :split, :, :, :]
+            self.y = c384_norm[:, :split, :, :, :]
+
+        elif mode == 'val':
+
+            self.X = c48_norm[:, split:, :, :, :]
+            self.y = c384_norm[:, split:, :, :, :]
 
     def __len__(self):
 
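One thing worth calling out in the multi branch above: the concatenation is asymmetric. The coarse input picks up the four extra fields (UGRD10m, VGRD10m, tsfc, CPRATEsfc) while the fine target only picks up topography. A standalone sketch of that bookkeeping with dummy arrays (a 48x48 coarse grid and a 4-frame time axis are assumed here purely for illustration; the real arrays have 2920 frames per tile):

# Dummy-array illustration of the channel concatenation in the multi=True branch.
import numpy as np

tiles, frames = 6, 4
c48_norm  = np.zeros((tiles, frames, 1, 48, 48), dtype=np.float32)     # coarse precipitation
c48_more  = np.zeros((tiles, frames, 4, 48, 48), dtype=np.float32)     # UGRD10m, VGRD10m, tsfc, CPRATEsfc
c384_norm = np.zeros((tiles, frames, 1, 384, 384), dtype=np.float32)   # fine precipitation
topo384   = np.zeros((tiles, frames, 1, 384, 384), dtype=np.float32)   # topography repeated along time

print(np.concatenate((c48_norm, c48_more), axis=2).shape)    # (6, 4, 5, 48, 48)   -> 5 input channels
print(np.concatenate((c384_norm, topo384), axis=2).shape)    # (6, 4, 2, 384, 384) -> 2 target channels
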
@@ -83,13 +64,13 @@ def __len__(self):
    def __getitem__(self, idx):

        # load a random tile index

        if self.mode == 'train':
            tile = np.random.randint(0, self.X.shape[0])

        elif self.mode == 'val':
            tile = 0

        # tensor shape : (length, num_channels, height, width)
        lowres = self.X[tile, idx:idx+self.length, :, :, :]
        highres = self.y[tile, idx:idx+self.length, :, :, :]

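And a minimal consumption sketch for the refactored dataset (again not part of the PR; it assumes the .npy files are in place and that __getitem__ returns the (lowres, highres) pair built above):

# Minimal sketch of feeding VSRDataset to a DataLoader.
from torch.utils.data import DataLoader
from data.vsrdata import VSRDataset

ds = VSRDataset('train', length=7, logscale=True, multi=True)
loader = DataLoader(ds, batch_size=1, shuffle=True)
lowres, highres = next(iter(loader))
print(lowres.shape, highres.shape)   # (batch, length, channels, height, width)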