Skip to content

Commit

Permalink
Resolved issue #74 (#79)
Browse files Browse the repository at this point in the history
* updated the time slice and ran the file run_fulll.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: vishal <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 19, 2023
1 parent 036f70c commit 7464326
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
.idea/
*.nc
*.pt
*.pyc
.DS_Store
*.txt
20 changes: 16 additions & 4 deletions train/run_fulll.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Training script for training the weather forecasting model"""
import json
import os
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import numpy as np
import torch
Expand Down Expand Up @@ -31,8 +36,8 @@ def __len__(self):
return len(self.filepaths)

def __getitem__(self, item):
start_idx = np.random.randint(0, 15)
data = self.data.isel(reftime=item, time=slice(start_idx, start_idx + 1))
start_idx = np.random.randint(0, 14)
data = self.data.isel(reftime=item, time=slice(start_idx, start_idx + 2))

start = data.isel(time=0)
end = data.isel(time=1)
Expand Down Expand Up @@ -83,7 +88,14 @@ def __getitem__(self, item):
print(data)
# print("Done coarsening")
lat_lons = np.array(np.meshgrid(data.latitude.values, data.longitude.values)).T.reshape(-1, 2)
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"

# Get the variance of the variables
feature_variances = []
for var in data.data_vars:
Expand All @@ -93,7 +105,7 @@ def __getitem__(self, item):
lat_lons=lat_lons, feature_variance=feature_variances, device=device
).to(device)
means = []
dataset = DataLoader(XrDataset(), batch_size=1, num_workers=32)
dataset = DataLoader(XrDataset(), batch_size=1)
model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
print("Done Setup")
Expand Down

0 comments on commit 7464326

Please sign in to comment.