Make padding compatible with old S2 data #57

Open · wants to merge 25 commits into main
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,9 @@ Changes from previous releases are listed below.
- Fasten up padding to 366 days _(see #54)_
- Radiometric calibration of Sentinel-1 _(see #47)_
- Downloading Sentinel-1 data and make it usable together with Sentinel-2 _(see #43)_
- Make padding compatible with old S2 data _(see #56)_
- Enhancing download process _(see #58)_
- Load S2 split from Zenodo for benchmark _(see #59)_

## 0.3.1 (2024-07-29)
- Remove country_code variable in collector downloader _(see #33)_
4 changes: 3 additions & 1 deletion docs/examples.md
@@ -130,9 +130,11 @@ The $\texttt{EuroCropsML}$ dataset allows users to customize options for various
| `data_dir` | Folder inside the data directory where pre-processed data is stored. |
| `random_seed` | Random seed used for generating training-testing-splits and further random numbers. |
| `num_samples` | Number of samples per class used for the fine-tuning subsets. The default will create the shots currently present on [Zenodo](https://zenodo.org/doi/10.5281/zenodo.10629609) for the training set. It will sample 1000 samples for validation and keep all available data for the test set. |
| `satellites` | List of satellites whose data is to be used. |
| `benchmark` | Whether to download the pre-existing benchmark split from Zenodo. For more information see {any}`split_dataset_by_region<eurocropsml.dataset.splits.split_dataset_by_region>` |
| `meadow_class` | Class that represents the ${\texttt{pasture_meadow_grassland_grass}}$ class. If provided, then this class will be downsampled to the median frequency of all other classes for the pre-training dataset since it represents an imbalanced majority class. |
| `pretrain_classes` | Classes that make up the pre-train dataset. |
| `finetune_classes` | Classes that make up the pre-train dataset. |
| `finetune_classes` | Classes that make up the fine-tune dataset. |
| `pretrain_regions` | Regions that make up the pre-train dataset. |
| `finetune_regions` | Regions that make up the fine-tune dataset. |

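To make the table concrete, here is a hedged sketch of a split configuration in Python. Only the fields visible in this PR's diff of `EuroCropsSplit` are used; the class IDs and dictionary keys are placeholders, not real EuroCrops codes.

```python
from eurocropsml.dataset.config import EuroCropsSplit

# Sketch only: placeholder class IDs and keys, not real EuroCrops codes.
split = EuroCropsSplit(
    satellite=["S1", "S2"],  # use both Sentinel-1 and Sentinel-2
    benchmark=True,          # opt in to the pre-built Zenodo split (default is now False)
    pretrain_classes={"example_group": [1100, 1110]},
    finetune_classes={"example_group": [2100]},
)
```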
4 changes: 3 additions & 1 deletion eurocropsml/dataset/cli.py
@@ -76,7 +76,9 @@ def build_splits(
overrides: OverridesT = typer.Argument(None, help="Overrides to split config"),
) -> None:
config = build_config(overrides, config_path)
create_splits(config.split, config.preprocess.raw_data_dir.parent)
create_splits(
config.split, config.preprocess.raw_data_dir.parent, config.preprocess.download_url
)

return app

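The new third argument threads the Zenodo download URL through to the split builder. A hedged sketch of the equivalent direct call; the attribute names come from this diff, while the surrounding config object is assumed to be the one returned by `build_config` in the CLI above.

```python
from eurocropsml.dataset.splits import create_splits

# `config` is the object returned by build_config(overrides, config_path) in
# the CLI above; the attribute names below are taken directly from this diff.
def run_splits(config) -> None:
    create_splits(
        config.split,                           # EuroCropsSplit section
        config.preprocess.raw_data_dir.parent,  # dataset root directory
        config.preprocess.download_url,         # Zenodo URL for the benchmark split
    )
```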
2 changes: 1 addition & 1 deletion eurocropsml/dataset/config.py
@@ -71,7 +71,7 @@ class EuroCropsSplit(BaseModel):

satellite: list[Literal["S1", "S2"]] = ["S2"]

benchmark: bool = True
benchmark: bool = False

pretrain_classes: dict[str, list[int]]
finetune_classes: dict[str, list[int]] = {}
117 changes: 62 additions & 55 deletions eurocropsml/dataset/download.py
@@ -15,6 +15,37 @@
logger = logging.getLogger(__name__)


def _get_zenodo_record(
base_url: str, version_number: int | None = None
) -> tuple[dict, list[str]] | dict:
response: requests.models.Response = requests.get(base_url)
response.raise_for_status()
data = response.json()
versions: list[dict] = data["hits"]["hits"]

if versions:
if version_number is not None:
selected_version = next(
(
v
for v in versions
if v["metadata"]["relations"]["version"][0]["index"] + 1 == version_number
),
None,
)
if selected_version is not None:
return selected_version
else:
logger.error(f"Version {version_number} could not be found on Zenodo.")
sys.exit(1)
else:
selected_version, files_to_download = select_version(versions)
return selected_version, files_to_download
else:
logger.error("No data found on Zenodo. Please download manually.")
sys.exit(1)


def _download_file(
file_name: str, file_url: str, local_path: Path, downloadfile_md5_hash: str
) -> None:
@@ -31,7 +62,7 @@ def _download_file(
" Do you want to delete it and redownload the file?"
)
if download:
logger.info(f"Downloading {local_path}...")
logger.info(f"Downloading to {local_path}...")
try:
response = requests.get(file_url)
response.raise_for_status()
@@ -47,23 +78,16 @@ def _download_file(
logger.info(f"{local_path} will not be downloaded again.")


def get_user_choice() -> list[str]:
def get_user_choice(files_to_download: list[str]) -> list[str]:
"""Get user choice for which files to download."""
choice = input(
"Would you like to download Sentinel-1 and/or Sentinel-2 data? Please enter "
" 'S1', 'S2', or 'both': "
)
if choice not in {"S1", "S2", "both"}:
logger.error("Invalid input. Please enter 'S1', 'S2', or 'both'.")
sys.exit(1)
elif choice == "both":
choice_list = ["S1", "S2"]
logger.info("Downloading both S1 and S2 data.")
else:
logger.info(f"Downloading only {choice} data.")
choice_list = [choice]
logger.info("Choose one or more of the following options by typing their numbers (e.g., 1 3):")
for i, file in enumerate(files_to_download, 1):
logger.info(f"{i}. {file}")
choice = typer.prompt("Enter your choices separated by spaces: ")
selected_indices = [int(choice) - 1 for choice in choice.split()]
selected_options = [files_to_download[i] for i in selected_indices]

return choice_list
return selected_options


def select_version(versions: list[dict]) -> tuple[dict, list[str]]:
@@ -150,46 +174,29 @@ def download_dataset(preprocess_config: EuroCropsDatasetPreprocessConfig) -> None
data_dir: Path = Path(preprocess_config.raw_data_dir.parent)
data_dir.mkdir(exist_ok=True, parents=True)

response: requests.models.Response = requests.get(base_url)

try:
response.raise_for_status()
data = response.json()
versions: list[dict] = data["hits"]["hits"]

if versions:
selected_version, files_to_download = select_version(versions)

# older version do only have S2 data
# if S1 data is available, let user decide
if "S1.zip" in files_to_download:
user_choice = get_user_choice()
if "S1" not in user_choice:
files_to_download.remove("S1.zip")
if "S2" not in user_choice:
files_to_download.remove("S2.zip")

for file_entry in selected_version["files"]:
file_url: str = file_entry["links"]["self"]
zip_file: str = file_entry["key"]
if zip_file in files_to_download:
local_path: Path = data_dir.joinpath(zip_file)
_download_file(zip_file, file_url, local_path, file_entry.get("checksum", ""))
logger.info(f"Unzipping {local_path}...")
_unzip_file(local_path, data_dir)

# move S1 and S2 data
if zip_file in ["S1.zip", "S2.zip"]:
unzipped_path: Path = local_path.with_suffix("")
for folder in unzipped_path.iterdir():
rel_target_folder: Path = folder.relative_to(unzipped_path)
_move_files(
folder, data_dir.joinpath(rel_target_folder, zip_file.split(".")[0])
)
shutil.rmtree(unzipped_path)
else:
logger.error("No data found on Zenodo. Please download manually.")
sys.exit(1)
selected_version, files_to_download = _get_zenodo_record(base_url)
# let user decide what data to download
selected_files = get_user_choice(files_to_download)

for file_entry in selected_version["files"]:
file_url: str = file_entry["links"]["self"]
zip_file: str = file_entry["key"]
if zip_file in selected_files:
local_path: Path = data_dir.parent.joinpath(zip_file)
_download_file(zip_file, file_url, local_path, file_entry.get("checksum", ""))
logger.info(f"Unzipping {local_path}...")
_unzip_file(local_path, data_dir)

# move S1 and S2 data
if zip_file in ["S1.zip", "S2.zip"]:
unzipped_path: Path = local_path.with_suffix("")
for folder in unzipped_path.iterdir():
rel_target_folder: Path = folder.relative_to(unzipped_path)
_move_files(
folder, data_dir.joinpath(rel_target_folder, zip_file.split(".")[0])
)
shutil.rmtree(unzipped_path)

except requests.exceptions.HTTPError as err:
logger.warning(f"There was an error when trying to access the Zenodo record: {err}")
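Taken together, the refactor separates record resolution from file selection. A hedged sketch of both paths through `_get_zenodo_record` (the URL is illustrative; the real one comes from `preprocess_config.download_url`):

```python
# Illustrative Zenodo records query; the real base_url is configured elsewhere.
base_url = "https://zenodo.org/api/records/?q=conceptrecid:10629609"

# Interactive path: choose a version, then choose files at the numbered prompt.
record, files_to_download = _get_zenodo_record(base_url)
selected_files = get_user_choice(files_to_download)  # e.g. typing "1 2" selects both

# Pinned path: request a specific (1-based) version; an unknown version logs an
# error and exits instead of returning.
pinned_record = _get_zenodo_record(base_url, version_number=3)
```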
20 changes: 13 additions & 7 deletions eurocropsml/dataset/preprocess.py
@@ -193,10 +193,11 @@ def preprocess(
) -> None:
"""Run preprocessing."""

raw_data_dir = preprocess_config.raw_data_dir
num_workers = preprocess_config.num_workers
satellite = preprocess_config.satellite
preprocess_dir = preprocess_config.preprocess_dir / satellite
num_workers: int | None = preprocess_config.num_workers
satellite: str = preprocess_config.satellite
raw_data_dir: Path = preprocess_config.raw_data_dir
raw_data_dir_satellite: Path = preprocess_config.raw_data_dir / satellite
preprocess_dir: Path = preprocess_config.preprocess_dir / satellite

if preprocess_config.bands is None:
if satellite == "S2":
Expand All @@ -206,19 +207,18 @@ def preprocess(
else:
bands = preprocess_config.bands

if preprocess_dir.exists() and len(list((preprocess_dir.iterdir()))) > 0:
if preprocess_dir.exists() and any(preprocess_dir.iterdir()):
logger.info(
f"Preprocessing directory {preprocess_dir} already exists and contains data. "
"Nothing to do."
)
sys.exit(0)

if raw_data_dir.exists():
if raw_data_dir_satellite.exists():
logger.info("Raw data directory exists. Skipping download.")

logger.info("Starting preprocessing. Compiling labels and centerpoints of parcels")
preprocess_dir.mkdir(exist_ok=True, parents=True)
raw_data_dir_satellite: Path = raw_data_dir / satellite
for file_path in raw_data_dir_satellite.glob("*.parquet"):
country_file: pd.DataFrame = pd.read_parquet(file_path).set_index("parcel_id")
cols = country_file.columns.tolist()
@@ -251,6 +251,12 @@
lambda y: np.array([-999] * b) if y is None else y
)
)
if satellite == "S2":
region_data = region_data.apply(
lambda x, b=len(bands): x.map(
lambda y: np.array([-999] * b) if y == [0] * b else y
)
)
with Pool(processes=num_workers) as p:
func = partial(
_save_row,
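The block added above maps legacy all-zero S2 observations onto the `-999` padding sentinel so that old exports behave like newly padded data. A minimal self-contained sketch of the same transformation with toy values:

```python
import numpy as np
import pandas as pd

b = 3  # toy band count; in preprocess() this is len(bands)
region_data = pd.DataFrame({"2019-06-01": [[0] * b, [5, 6, 7]]})

# Old S2 data stored missing observations as all-zero vectors; replace them
# with the -999 sentinel used by the padding logic.
region_data = region_data.apply(
    lambda x, b=b: x.map(lambda y: np.array([-999] * b) if y == [0] * b else y)
)
print(region_data.iloc[0, 0])  # -> array([-999, -999, -999])
```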