Add filesystem abstraction using fsspec (pyg-team#8336).
**Changes made:**
- Add `fs_utils.py`, an API that offers a 1:1 replacement for the current
  filesystem usage.
- Update usages of `open`, `glob`, `makedirs`, and `normpath` to use the new API.
- Update `test/datasets/*` to use an in-memory filesystem, simplifying the tests
  and exercising the fsspec API.
Tony Sherbondy committed Nov 14, 2023
1 parent 9f7e824 commit 63bb519
Showing 12 changed files with 161 additions and 46 deletions.
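For orientation (not part of the commit itself), here is a minimal sketch of the fsspec pattern the new helpers build on, using the same `memory://` protocol as the updated test fixture; the path and tensor below are placeholder values:

```python
import io

import fsspec
import torch
from fsspec.core import url_to_fs

# In-memory filesystem, as in the updated `get_dataset` fixture: nothing is
# written to local disk, and the backend is chosen from the URL protocol.
path = 'memory://tmp/pyg_test_datasets/example.pt'

# Replacement pattern for `torch.save(obj, path)`: serialize into a buffer,
# then write the bytes through fsspec (the same idea as `fs_torch_save`).
buffer = io.BytesIO()
torch.save({'x': torch.randn(3)}, buffer)
with fsspec.open(path, 'wb') as f:
    f.write(buffer.getvalue())

# Replacement pattern for `osp.exists(path)` and `torch.load(path)`.
fs = url_to_fs(path)[0]
assert fs.exists(path)
with fsspec.open(path, 'rb') as f:
    data = torch.load(f)
```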
3 changes: 2 additions & 1 deletion test/conftest.py
@@ -46,10 +46,11 @@ def load_dataset(root: str, name: str, *args, **kwargs) -> Dataset:

 @pytest.fixture(scope='session')
 def get_dataset() -> Callable:
-    root = osp.join('/', 'tmp', 'pyg_test_datasets')
+    root = 'memory://' + osp.join('tmp', 'pyg_test_datasets')
     yield functools.partial(load_dataset, root)
     if osp.exists(root):
         shutil.rmtree(root)
+        assert False, "Test leaked to local filesystem."


 @pytest.fixture
8 changes: 5 additions & 3 deletions torch_geometric/data/dataset.py
@@ -13,6 +13,8 @@

 from torch_geometric.data.data import BaseData
 from torch_geometric.data.makedirs import makedirs
+from torch_geometric.data.fs_utils import fs_normpath, fs_torch_save


 IndexType = Union[slice, Tensor, np.ndarray, Sequence]

@@ -91,7 +93,7 @@ def __init__(
         super().__init__()

         if isinstance(root, str):
-            root = osp.expanduser(osp.normpath(root))
+            root = osp.expanduser(fs_normpath(root))

         self.root = root
         self.transform = transform
@@ -244,9 +246,9 @@ def _process(self):
         self.process()

         path = osp.join(self.processed_dir, 'pre_transform.pt')
-        torch.save(_repr(self.pre_transform), path)
+        fs_torch_save(_repr(self.pre_transform), path)
         path = osp.join(self.processed_dir, 'pre_filter.pt')
-        torch.save(_repr(self.pre_filter), path)
+        fs_torch_save(_repr(self.pre_filter), path)

         if self.log and 'pytest' not in sys.modules:
             print('Done!', file=sys.stderr)
6 changes: 4 additions & 2 deletions torch_geometric/data/download.py
@@ -3,8 +3,10 @@
 import sys
 import urllib
 from typing import Optional
+import fsspec

 from torch_geometric.data.makedirs import makedirs
+from torch_geometric.data.fs_utils import fs_exists


 def download_url(
@@ -30,7 +32,7 @@

     path = osp.join(folder, filename)

-    if osp.exists(path):  # pragma: no cover
+    if fs_exists(path):  # pragma: no cover
         if log and 'pytest' not in sys.modules:
             print(f'Using existing file {filename}', file=sys.stderr)
         return path
@@ -43,7 +45,7 @@ def download_url(
     context = ssl._create_unverified_context()
     data = urllib.request.urlopen(url, context=context)

-    with open(path, 'wb') as f:
+    with fsspec.open(path, 'wb') as f:
         # workaround for https://bugs.python.org/issue42853
         while True:
             chunk = data.read(10 * 1024 * 1024)
5 changes: 3 additions & 2 deletions torch_geometric/data/extract.py
@@ -1,3 +1,4 @@
+from typing import Any, Optional
 import bz2
 import gzip
 import os.path as osp
@@ -11,7 +12,7 @@ def maybe_log(path, log=True):
         print(f'Extracting {path}', file=sys.stderr)


-def extract_tar(path: str, folder: str, mode: str = 'r:gz', log: bool = True):
+def extract_tar(path: str, folder: str, mode: str = 'r:gz', log: bool = True,
+                fileobj: Optional[Any] = None):
     r"""Extracts a tar archive to a specific folder.

     Args:
@@ -22,7 +23,7 @@ def extract_tar(path: str, folder: str, mode: str = 'r:gz', log: bool = True):
             console. (default: :obj:`True`)
     """
     maybe_log(path, log)
-    with tarfile.open(path, mode) as f:
+    with tarfile.open(path, mode, fileobj=fileobj) as f:
         f.extractall(folder)


115 changes: 115 additions & 0 deletions torch_geometric/data/fs_utils.py
@@ -0,0 +1,115 @@
+from typing import Any, List, Optional
+import io
+import os.path as osp
+import sys
+import torch
+import fsspec
+from fsspec.core import url_to_fs
+
+DEFAULT_CACHE_PATH = '/tmp/pyg_simplecache'
+
+
+def get_fs(path: str) -> fsspec.AbstractFileSystem:
+    return url_to_fs(path)[0]
+
+
+def fs_normpath(path: str, *args) -> str:
+    if get_fs(path).protocol == 'file':
+        return osp.normpath(path, *args)
+    return path
+
+
+def fs_exists(path: str) -> bool:
+    return get_fs(path).exists(path)
+
+
+def fs_ls(path: str, detail: bool = False) -> List[str]:
+    fs = get_fs(path)
+    results = fs.ls(path, detail=detail)
+    # TODO: Strip common ancestor.
+    return [osp.basename(x) for x in results]
+
+
+def fs_mkdirs(path: str, **kwargs):
+    return get_fs(path).mkdirs(path, **kwargs)
+
+
+def fs_isdir(path: str) -> bool:
+    return get_fs(path).isdir(path)
+
+
+def _copy_file_handle(f_from: Any, path: str, blocksize: int, **kwargs):
+    with fsspec.open(path, 'wb', **kwargs) as f_to:
+        data = True
+        while data:
+            data = f_from.read(blocksize)
+            f_to.write(data)
+
+
+def fs_cp(path1: str, path2: str,
+          kwargs1: Optional[dict] = None,
+          kwargs2: Optional[dict] = None,
+          extract: bool = False,
+          cache_path: Optional[str] = DEFAULT_CACHE_PATH,
+          blocksize: int = 2**22):
+    # Initialize kwargs for source/destination.
+    kwargs1 = kwargs1 or {}
+    kwargs2 = kwargs2 or {}
+
+    # Cache result if the protocol is not local and we have a cache folder.
+    if get_fs(path1).protocol != 'file' and cache_path:
+        kwargs1 = {
+            **kwargs1,
+            'simplecache': {'cache_storage': cache_path},
+        }
+        path1 = f'simplecache::{path1}'
+
+    # Extract = Unarchive + Decompress if applicable.
+    if extract and path1.endswith('.tar.gz'):
+        kwargs1 = {**kwargs1, 'tar': {'compression': 'gzip'}}
+        path1 = f'tar://**::{path1}'
+    elif extract and path1.endswith('.zip'):
+        path1 = f'zip://**::{path1}'
+    else:
+        name = osp.basename(path1)
+        if extract and (name.endswith('.gz') or name.endswith('.bz2')):
+            name = osp.splitext(name)[0]
+        path2 = osp.join(path2, name)
+
+    if '*' in path1:
+        open_files = fsspec.open_files(path1, **kwargs1)
+        with open_files as of:
+            for f_from, open_file in zip(of, open_files):
+                to_path = osp.join(path2, open_file.path)
+                _copy_file_handle(f_from, to_path, blocksize, **kwargs2)
+    else:
+        with fsspec.open(path1, compression='infer', **kwargs1) as f_from:
+            _copy_file_handle(f_from, path2, blocksize, **kwargs2)
+
+
+def fs_rm(path: str, **kwargs):
+    get_fs(path).rm(path, **kwargs)
+
+
+def fs_mv(path1: str, path2: str, **kwargs):
+    fs1 = get_fs(path1)
+    fs2 = get_fs(path2)
+    assert fs1.protocol == fs2.protocol
+    fs1.mv(path1, path2, **kwargs)
+
+
+def fs_glob(path: str, **kwargs):
+    fs = get_fs(path)
+    return [fs.unstrip_protocol(x) for x in fs.glob(path, **kwargs)]
+
+
+def fs_torch_save(data: Any, path: str):
+    buffer = io.BytesIO()
+    torch.save(data, buffer)
+    with fsspec.open(path, 'wb') as f:
+        f.write(buffer.getvalue())
+
+
+def fs_torch_load(path: str, map_location: Any = None) -> Any:
+    with fsspec.open(path, 'rb') as f:
+        return torch.load(f, map_location)
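As a rough usage sketch of the helpers defined above (again, not part of the commit), copying a compressed file into an in-memory destination; the gzipped file stands in for a remote download, and all paths are placeholders:

```python
import gzip
import os.path as osp
import tempfile

from torch_geometric.data.fs_utils import fs_cp, fs_exists, fs_ls

# Build a small .gz file locally to stand in for a downloaded archive.
tmp_dir = tempfile.mkdtemp()
src = osp.join(tmp_dir, 'edges.txt.gz')
with gzip.open(src, 'wt') as f:
    f.write('0 1\n1 2\n')

# Copy and decompress into an in-memory "raw" folder: with `extract=True`
# the `.gz` suffix is stripped and `compression='infer'` decodes the stream.
raw_dir = 'memory://demo/raw'
fs_cp(src, raw_dir, extract=True)

print(fs_ls(raw_dir))                     # ['edges.txt']
print(fs_exists(raw_dir + '/edges.txt'))  # True
```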
5 changes: 3 additions & 2 deletions torch_geometric/data/in_memory_dataset.py
@@ -25,6 +25,7 @@
 from torch_geometric.data.data import BaseData
 from torch_geometric.data.dataset import Dataset, IndexType
 from torch_geometric.data.separate import separate
+from torch_geometric.data.fs_utils import fs_torch_save, fs_torch_load


 class InMemoryDataset(Dataset, ABC):
@@ -121,11 +122,11 @@ def get(self, idx: int) -> BaseData:
     def save(cls, data_list: List[BaseData], path: str):
         r"""Saves a list of data objects to the file path :obj:`path`."""
         data, slices = cls.collate(data_list)
-        torch.save((data.to_dict(), slices), path)
+        fs_torch_save((data.to_dict(), slices), path)

     def load(self, path: str, data_cls: Type[BaseData] = Data):
         r"""Loads the dataset from the file path :obj:`path`."""
-        data, self.slices = torch.load(path)
+        data, self.slices = fs_torch_load(path)
         if isinstance(data, dict):  # Backward compatibility.
             data = data_cls.from_dict(data)
         self.data = data
7 changes: 4 additions & 3 deletions torch_geometric/data/makedirs.py
@@ -1,6 +1,7 @@
 import errno
 import os
 import os.path as osp
+from torch_geometric.data.fs_utils import fs_normpath, fs_mkdirs, fs_isdir


 def makedirs(path: str):
@@ -10,7 +11,7 @@ def makedirs(path: str):
         path (str): The path to create.
     """
     try:
-        os.makedirs(osp.expanduser(osp.normpath(path)))
-    except OSError as e:
-        if e.errno != errno.EEXIST and osp.isdir(path):
+        fs_mkdirs(osp.expanduser(fs_normpath(path)))
+    except FileExistsError as e:
+        if not fs_isdir(path):
             raise e
31 changes: 11 additions & 20 deletions torch_geometric/datasets/snap_dataset.py
@@ -1,18 +1,14 @@
 import os
 import os.path as osp
 from typing import Any, Callable, List, Optional
+import fsspec

 import numpy as np
 import torch

-from torch_geometric.data import (
-    Data,
-    InMemoryDataset,
-    download_url,
-    extract_gz,
-    extract_tar,
-)
+from torch_geometric.data import Data, InMemoryDataset
 from torch_geometric.data.makedirs import makedirs
+from torch_geometric.data.fs_utils import fs_cp, fs_ls, fs_isdir
 from torch_geometric.utils import coalesce


@@ -35,7 +31,7 @@ def read_ego(files: List[str], name: str) -> List[EgoData]:
     ]
     for i in range(4, len(files), 5):
         featnames_file = files[i]
-        with open(featnames_file, 'r') as f:
+        with fsspec.open(featnames_file, 'r') as f:
             featnames = f.read().split('\n')[:-1]
             featnames = [' '.join(x.split(' ')[1:]) for x in featnames]
             all_featnames += featnames
@@ -63,7 +59,7 @@ def read_ego(files: List[str], name: str) -> List[EgoData]:

         # Reorder `x` according to `featnames` ordering.
         x_all = torch.zeros(x.size(0), len(all_featnames))
-        with open(featnames_file, 'r') as f:
+        with fsspec.open(featnames_file, 'r') as f:
             featnames = f.read().split('\n')[:-1]
             featnames = [' '.join(x.split(' ')[1:]) for x in featnames]
             indices = [all_featnames[featname] for featname in featnames]
@@ -79,7 +75,7 @@ def read_ego(files: List[str], name: str) -> List[EgoData]:

         circles = []
         circles_batch = []
-        with open(circles_file, 'r') as f:
+        with fsspec.open(circles_file, 'r') as f:
             for i, circle in enumerate(f.read().split('\n')[:-1]):
                 circle = [idx_assoc[c] for c in circle.split()[1:]]
                 circles += circle
@@ -219,28 +215,23 @@ def processed_file_names(self) -> str:
         return 'data.pt'

     def _download(self):
-        if osp.isdir(self.raw_dir) and len(os.listdir(self.raw_dir)) > 0:
+        if fs_isdir(self.raw_dir) and len(fs_ls(self.raw_dir)) > 0:
             return

         makedirs(self.raw_dir)
         self.download()

     def download(self):
         for name in self.available_datasets[self.name]:
-            path = download_url(f'{self.url}/{name}', self.raw_dir)
-            if name.endswith('.tar.gz'):
-                extract_tar(path, self.raw_dir)
-            elif name.endswith('.gz'):
-                extract_gz(path, self.raw_dir)
-            os.unlink(path)
+            fs_cp(f'{self.url}/{name}', self.raw_dir, extract=True)

     def process(self):
         raw_dir = self.raw_dir
-        filenames = os.listdir(self.raw_dir)
-        if len(filenames) == 1 and osp.isdir(osp.join(raw_dir, filenames[0])):
+        filenames = fs_ls(self.raw_dir)
+        if len(filenames) == 1 and fs_isdir(osp.join(raw_dir, filenames[0])):
             raw_dir = osp.join(raw_dir, filenames[0])

-        raw_files = sorted([osp.join(raw_dir, f) for f in os.listdir(raw_dir)])
+        raw_files = sorted([osp.join(raw_dir, f) for f in fs_ls(raw_dir)])

         if self.name[:4] == 'ego-':
             data_list = read_ego(raw_files, self.name[4:])
17 changes: 8 additions & 9 deletions torch_geometric/datasets/tu_dataset.py
@@ -8,10 +8,9 @@
 from torch_geometric.data import (
     Data,
     InMemoryDataset,
-    download_url,
-    extract_zip,
 )
 from torch_geometric.io import read_tu_data
+from torch_geometric.data.fs_utils import fs_torch_load, fs_cp, fs_mv, fs_torch_save


 class TUDataset(InMemoryDataset):
@@ -138,7 +137,7 @@ def __init__(
         super().__init__(root, transform, pre_transform, pre_filter,
                          force_reload=force_reload)

-        out = torch.load(self.processed_paths[0])
+        out = fs_torch_load(self.processed_paths[0])
         if not isinstance(out, tuple) or len(out) != 3:
             raise RuntimeError(
                 "The 'data' object was created by an older version of PyG. "
@@ -193,11 +192,11 @@ def processed_file_names(self) -> str:
     def download(self):
         url = self.cleaned_url if self.cleaned else self.url
         folder = osp.join(self.root, self.name)
-        path = download_url(f'{url}/{self.name}.zip', folder)
-        extract_zip(path, folder)
-        os.unlink(path)
-        shutil.rmtree(self.raw_dir)
-        os.rename(osp.join(folder, self.name), self.raw_dir)
+        zip_url = f'{url}/{self.name}.zip'
+        fs_cp(zip_url, self.raw_dir, extract=True)
+        # fs_rm(self.raw_dir, recursive=True)
+        # fs_cp(f'zip://{self.name}/**::{zip_url}', self.raw_dir)
+        fs_mv(osp.join(self.raw_dir, self.name), self.raw_dir, recursive=True)

     def process(self):
         self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)
@@ -214,7 +213,7 @@ def process(self):
         self.data, self.slices = self.collate(data_list)
         self._data_list = None  # Reset cache.

-        torch.save((self._data.to_dict(), self.slices, sizes),
+        fs_torch_save((self._data.to_dict(), self.slices, sizes),
                    self.processed_paths[0])

     def __repr__(self) -> str:
3 changes: 2 additions & 1 deletion torch_geometric/io/planetoid.py
@@ -1,6 +1,7 @@
 import os.path as osp
 import warnings
 from itertools import repeat
+import fsspec

 import torch

@@ -91,7 +92,7 @@ def read_file(folder, prefix, name):
     if name == 'test.index':
         return read_txt_array(path, dtype=torch.long)

-    with open(path, 'rb') as f:
+    with fsspec.open(path, 'rb') as f:
         warnings.filterwarnings('ignore', '.*`scipy.sparse.csr` name.*')
         out = pickle.load(f, encoding='latin1')

[Diffs for the remaining changed files were not loaded.]