
finite_ntk malaria data generation bug + questions #85

Open
tingtang2 opened this issue Apr 14, 2022 · 1 comment

Comments

tingtang2 commented Apr 14, 2022

Bug

I think `test_year` is supposed to be `train_year` on line 55, right?

xfer/finite_ntk/data.py

Lines 40 to 55 in dd4a6a2

```python
def generate_data(
    nsamples=2000, train_year=2012, test_year=2016, grid_size=200, seed=110, hdf_loc=None
):
    r"""
    generates subsampled dataset from the hdf_location given years, grids, etc.
    nsamples (int): dataset size
    train_year (int): year to use from dataset
    test_year (int): year to test on from dataset
    grid_size (int): size of grid
    seed (int): random seed for subsampling
    hdf_loc (str): location of dataset hdf5 file
    """
    df = pd.read_hdf(hdf_loc, "full")
    is_train_year = torch.from_numpy((df["year"] == test_year).values)
```
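To make the suspected bug concrete, here is a tiny self-contained sketch (toy dataframe with made-up years, no connection to the real malaria data) of what the line presumably should compute, with `train_year` in place of `test_year`:

```python
import pandas as pd

# toy stand-in for the malaria dataframe (made-up rows)
df = pd.DataFrame({"year": [2012, 2012, 2016], "value": [0.1, 0.2, 0.3]})
train_year, test_year = 2012, 2016

# the suspected fix: select *training* rows with train_year, not test_year
is_train_year = (df["year"] == train_year).values
print(is_train_year.sum())  # 2 rows from the training year
```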

Questions

I see the variables `inside`, `extent`, and `grid_x` declared but never used in the malaria experiments. I'm looking to replicate the experiments in JAX, so I was wondering what the original purpose of these variables was. In particular, what is the sparse tensor marking Nigeria (used to build `inside`) supposed to be doing?

xfer/finite_ntk/data.py

Lines 80 to 100 in dd4a6a2

```python
    # Generate nxn grid of test points spaced on a grid of size 1/(n-1) in [0,1]x[0,1] for evaluation
    n = grid_size
    L = torch.linspace(0, 1, n)
    X, Y = torch.meshgrid(L, L)
    grid_x = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    # let's start with a small set of samples
    train_inputs, train_targets, train_targets_var = subsample(
        train_inputs, train_targets, train_targets_var, nsamples, seed=seed
    )
    # mark nigeria - not great but works reasonably well
    ng_coords = (n * all_x[:, :2]).round().long()[is_ng]
    sparse_ng = torch.sparse.LongTensor(
        ng_coords.transpose(0, 1),
        torch.ones(ng_coords.size(0)).long(),
        torch.Size([n, n]),
    )
    inside = sparse_ng.to_dense().reshape(-1) > 0
    return train_inputs, train_targets, test_x, test_y, inside, extent
```
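For what it's worth, my reading of the snippet is that the sparse tensor is just a scatter: it builds an n x n occupancy grid that is nonzero wherever a Nigeria data point falls, and `inside` is the flattened boolean mask over the evaluation grid. A rough pure-NumPy sketch of the equivalent dense computation (the coordinates below are made up; only the names mirror the snippet above):

```python
import numpy as np

n = 4  # grid size (grid_size in generate_data); tiny here for illustration
# made-up coordinates already scaled/rounded to grid indices, standing in for
# ng_coords = (n * all_x[:, :2]).round().long()[is_ng]
ng_coords = np.array([[0, 1], [2, 3], [2, 3]])  # duplicate cells are fine

# dense equivalent of torch.sparse.LongTensor(...).to_dense():
# scatter ones into an n x n grid (duplicate coordinates accumulate)
grid = np.zeros((n, n), dtype=np.int64)
np.add.at(grid, (ng_coords[:, 0], ng_coords[:, 1]), 1)

# "inside" is the flattened boolean mask over the full evaluation grid
inside = grid.reshape(-1) > 0
print(inside.sum())  # 2 distinct cells marked
```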

P.S. thank you for the work and for releasing the code!

@wjmaddox
Contributor

Ugh, yeah, that looks like a bug. Let me put up a PR to fix (idk who has write access to this still however).

For the second question, you shouldn't need to copy those over.
