Skip to content

Commit

Permalink
Merge branch 'main' into 29-Process-clipping-of-Sentinel-tiles-on-monthly-basis
Browse files (browse the repository at this point in the history)
  • Loading branch information
katyagikalo authored Oct 28, 2024
2 parents 337c0aa + 8b95763 commit 328a5a4
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Changes from previous releases are listed below.
- Renaming and moving of Sentinel-2 bands _(see #42)_
- Process clipping of Sentinel tiles on monthly basis _(see #29)_
- Adjust the shapefile CRS transformation for Sentinel-1 _(see #49)_
- Adjusting split generation _(see #45)_

## 0.3.1 (2024-07-29)
- Remove country_code variable in collector downloader _(see #33)_
Expand Down
7 changes: 4 additions & 3 deletions eurocropsml/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,14 @@ def to_tuple(self) -> tuple[DataItem, torch.Tensor]:
return self.data_item, self.label


def custom_collate_fn(batch: Sequence[LabelledData]) -> LabelledData:
def custom_collate_fn(batch: Sequence[LabelledData], padding_value: int = -1) -> LabelledData:
"""Collate function for batch creation within data loader.
Used to create batches from a dataset's DataItem.
Args:
batch: List of DataItem from dataset
padding_value: Value used for padding.
Returns:
New DataItem with batched data.
Expand All @@ -112,7 +113,7 @@ def custom_collate_fn(batch: Sequence[LabelledData]) -> LabelledData:
tensor_name: (
torch.stack(tensors)
if tensor_stackability[tensor_name]
else pad_sequence(tensors, batch_first=True, padding_value=-1)
else pad_sequence(tensors, batch_first=True, padding_value=padding_value)
)
for tensor_name, tensors in batch_tensors.items()
}
Expand All @@ -123,7 +124,7 @@ def custom_collate_fn(batch: Sequence[LabelledData]) -> LabelledData:
):
batched_tensors["label"] = torch.concat(batch_tensors["label"], 0)

if (pad_mask := batched_tensors["data"].eq(-1)).any():
if (pad_mask := batched_tensors["data"].eq(padding_value)).any():
if (aug_mask := batched_tensors.get("mask")) is not None:
# if pad_mask.dim (B, T, C) != aug_mask.dim (B, T, C) | (B, T)
# => aug_mask.dim is (B, T), thus pad_mask needs to be converted
Expand Down
5 changes: 3 additions & 2 deletions eurocropsml/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,14 @@ def __getitem__(
meta_data: dict[str, torch.Tensor] = {
array_name: torch.tensor(np_array) for array_name, np_array in arrays_dict.items()
}
if self.pad_seq_to_366:
np_data = pad_seq_to_366(np_data, meta_data["dates"])

tensor_data = torch.tensor(np_data, dtype=torch.float)
if self.config.normalize:
tensor_data = torch.mul(tensor_data, NORMALIZING_FACTOR)

if self.pad_seq_to_366:
tensor_data = pad_seq_to_366(tensor_data, meta_data["dates"])

y = int(Path(f).stem.split("_")[-1])
# encode class
y = self.encode[y]
Expand Down
39 changes: 26 additions & 13 deletions eurocropsml/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import hashlib
import json
import logging
import random
from collections import defaultdict
from collections.abc import Iterable
from functools import cached_property, partial
Expand Down Expand Up @@ -242,7 +241,7 @@ def _create_finetune_set(
finetune_list: list[str] = [
value for values_list in finetune_dataset.values() for value in values_list
]

finetune_list.sort()
if set(pretrain_list) & set(finetune_list):
raise Exception(
f"There are {len((set(pretrain_list) & set(finetune_list)))} "
Expand All @@ -261,9 +260,19 @@ def _create_finetune_set(
if num_samples["test"] != "all":
num_samples["test"] = int(cast(int, num_samples["test"]))
if isinstance(num_samples["validation"], int) and len(finetune_val) > num_samples["validation"]:
finetune_val = random.sample(finetune_val, num_samples["validation"])
finetune_val = resample(
finetune_val,
replace=False,
n_samples=num_samples["validation"],
random_state=seed,
)
if isinstance(num_samples["test"], int) and len(finetune_test) > num_samples["test"]:
finetune_test = random.sample(finetune_test, num_samples["test"])
finetune_test = resample(
finetune_test,
replace=False,
n_samples=num_samples["test"],
random_state=seed,
)

sample_list: list[str | int]
if isinstance(num_samples["train"], list):
Expand Down Expand Up @@ -343,6 +352,7 @@ def split_dataset_by_class(
seed,
)

pretrain_list.sort()
# save pretraining split
train, val = train_test_split(pretrain_list, test_size=test_size, random_state=seed)

Expand Down Expand Up @@ -431,6 +441,7 @@ def split_dataset_by_region(
if (
finetune_dataset is not None and finetune_regions is not None
): # otherwise EuroCrops is solely used for pretraining

finetune_dataset = _filter_regions(finetune_dataset, finetune_regions)

_create_finetune_set(
Expand All @@ -442,6 +453,7 @@ def split_dataset_by_region(
test_size,
seed,
)
pretrain_list.sort()

# save pretraining split
train, val = train_test_split(pretrain_list, test_size=test_size, random_state=seed)
Expand All @@ -460,30 +472,31 @@ def split_dataset_by_region(
_save_to_json(split_dir.joinpath("meta", f"{split}_split.json"), meta_dict)


def pad_seq_to_366(seq: np.ndarray, dates: torch.Tensor) -> np.ndarray:
def pad_seq_to_366(seq: torch.Tensor, dates: torch.Tensor, padding_value: int = -1) -> torch.Tensor:
"""Pad sequence to 366 days.
Args:
seq: Array containing sequence data to be padded.
dates: Array of matching size specifying the dates
seq: Tensor containing sequence data to be padded.
dates: Tensor of matching size specifying the dates
associated with each of the sequence's data points.
padding_value: Value used for padding.
Returns:
A padded sequence data array with all missing dates
filled in by a `-1` mask value.
A padded sequence tensor in which all missing dates are filled
with `padding_value`.
"""
rg = range(366)

df_data = pd.DataFrame(np.array(seq).T.tolist(), columns=dates.tolist())
df_data = pd.DataFrame(seq.T.tolist(), columns=dates.tolist())
df_dates = pd.DataFrame(columns=rg, dtype=int)
df_dates = pd.concat([df_dates, df_data])

df_dates = df_dates.fillna(-1)
df_dates = df_dates.fillna(padding_value)

pad_seq: list = [df_dates[col].to_numpy() for col in rg]
pad_seq: np.ndarray = np.array([df_dates[col].to_numpy() for col in rg])

return np.array(pad_seq)
return torch.Tensor(pad_seq)


class MMapMetadata:
Expand Down
2 changes: 1 addition & 1 deletion tests/dataset/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_data() -> list[str]:

def test_pad_seq_to_366(test_arrays: tuple[np.ndarray, torch.Tensor]) -> None:
test_data, test_dates = test_arrays
np_data = pad_seq_to_366(test_data, test_dates)
np_data = pad_seq_to_366(torch.tensor(test_data, dtype=torch.float), test_dates)

assert np_data.shape == (366, test_data.shape[1])

Expand Down

0 comments on commit 328a5a4

Please sign in to comment.