Skip to content

Commit

Permalink
update based on comments
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Nov 2, 2023
1 parent 0b7820e commit c205c5a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 10 deletions.
9 changes: 8 additions & 1 deletion monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np

from monai.apps.tcia import (
DCM_FILENAME_REGEX,
download_tcia_series_instance,
get_tcia_metadata,
get_tcia_ref_uid,
Expand Down Expand Up @@ -442,6 +443,10 @@ class TciaDataset(Randomizable, CacheDataset):
specific_tags: tags that will be loaded for "SEG" series. This argument will be used in
`monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010),
(0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)].
fname_regex: a regular expression to match the file names when the input is a folder.
If provided, only the matched files will be included. For example, to include the file name
"image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`.
Defaults to `"^(?!.*LICENSE).*"`, which ignores any file name containing `"LICENSE"`.
val_frac: fraction of the whole dataset to use for validation, default is 0.2.
seed: random seed used to shuffle the datalist before splitting into training and validation, default is 0.
Note: set the same seed for the `training` and `validation` sections.
Expand Down Expand Up @@ -509,6 +514,7 @@ def __init__(
(0x0020, 0x0011), # Series Number
(0x0020, 0x0012), # Acquisition Number
),
fname_regex: str = DCM_FILENAME_REGEX,
seed: int = 0,
val_frac: float = 0.2,
cache_num: int = sys.maxsize,
Expand Down Expand Up @@ -548,12 +554,13 @@ def __init__(

if not os.path.exists(download_dir):
raise RuntimeError(f"Cannot find dataset directory: {download_dir}.")
self.fname_regex = fname_regex

self.indices: np.ndarray = np.array([])
self.datalist = self._generate_data_list(download_dir)

if transform == ():
transform = LoadImaged(reader="PydicomReader", keys=["image"])
transform = LoadImaged(keys=["image"], reader="PydicomReader", fname_regex=self.fname_regex)
CacheDataset.__init__(
self,
data=self.datalist,
Expand Down
9 changes: 8 additions & 1 deletion monai/apps/tcia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,11 @@
from __future__ import annotations

from .label_desc import TCIA_LABEL_DICT
from .utils import download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, match_tcia_ref_uid_in_study
from .utils import (
BASE_URL,
DCM_FILENAME_REGEX,
download_tcia_series_instance,
get_tcia_metadata,
get_tcia_ref_uid,
match_tcia_ref_uid_in_study,
)
12 changes: 10 additions & 2 deletions monai/apps/tcia/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@
requests_get, has_requests = optional_import("requests", name="get")
pd, has_pandas = optional_import("pandas")

__all__ = ["get_tcia_metadata", "download_tcia_series_instance", "get_tcia_ref_uid", "match_tcia_ref_uid_in_study"]

DCM_FILENAME_REGEX = r"^(?!.*LICENSE).*" # excludes any file with "LICENSE" in its name
BASE_URL = "https://services.cancerimagingarchive.net/nbia-api/services/v1/"

__all__ = [
"get_tcia_metadata",
"download_tcia_series_instance",
"get_tcia_ref_uid",
"match_tcia_ref_uid_in_study",
"DCM_FILENAME_REGEX",
"BASE_URL",
]


def get_tcia_metadata(query: str, attribute: str | None = None) -> list:
"""
Expand Down
5 changes: 2 additions & 3 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,7 @@ class PydicomReader(ImageReader):
for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}.
fname_regex: a regular expression to match the file names when the input is a folder.
If provided, only the matched files will be included. For example, to include the file name
"image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`.
Default to `"^(?!.*LICENSE).*"`, ignoring any file name containing `"LICENSE"`.
"image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Defaults to `""`.
kwargs: additional args for `pydicom.dcmread` API. more details about available args:
https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
If the `get_data` function will be called
Expand All @@ -423,7 +422,7 @@ def __init__(
swap_ij: bool = True,
prune_metadata: bool = True,
label_dict: dict | None = None,
fname_regex: str = r"^(?!.*LICENSE).*",
fname_regex: str = "",
**kwargs,
):
super().__init__()
Expand Down
25 changes: 22 additions & 3 deletions tests/test_tciadataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import unittest

from monai.apps import TciaDataset
from monai.apps.tcia import TCIA_LABEL_DICT
from monai.apps.tcia import DCM_FILENAME_REGEX, TCIA_LABEL_DICT
from monai.data import MetaTensor
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd
from tests.utils import skip_if_downloading_fails, skip_if_quick
Expand All @@ -32,7 +32,12 @@ def test_values(self):

transform = Compose(
[
LoadImaged(keys=["image", "seg"], reader="PydicomReader", label_dict=TCIA_LABEL_DICT[collection]),
LoadImaged(
keys=["image", "seg"],
reader="PydicomReader",
fname_regex=DCM_FILENAME_REGEX,
label_dict=TCIA_LABEL_DICT[collection],
),
EnsureChannelFirstd(keys="image", channel_dim="no_channel"),
ScaleIntensityd(keys="image"),
]
Expand Down Expand Up @@ -82,10 +87,24 @@ def _test_dataset(dataset):
self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24))
self.assertEqual(len(data), int(download_len * val_frac))
data = TciaDataset(
root_dir=testing_dir, collection=collection, section="validation", download=False, val_frac=val_frac
root_dir=testing_dir,
collection=collection,
section="validation",
download=False,
fname_regex=DCM_FILENAME_REGEX,
val_frac=val_frac,
)
self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24))
self.assertEqual(len(data), download_len)
with self.assertRaises(RuntimeError):
data = TciaDataset(
root_dir=testing_dir,
collection=collection,
section="validation",
fname_regex=".*", # matching all files, including 'LICENSE', is not a valid input
download=False,
val_frac=val_frac,
)[0]

shutil.rmtree(os.path.join(testing_dir, collection))
try:
Expand Down

0 comments on commit c205c5a

Please sign in to comment.