Changes in Temporal Data to support a new Temporal Data Loader (#3985)
* Refactor TemporalData class to inherit from BaseData

* Fixes to get TemporalData working

* Small fixes in __delitem__ of TemporalData

* Add batch, __cat_dim__ and __inc__ to TemporalData

* Add Docs to TemporalData

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Lint fixes

* Add removed method TemporalData.seq_batches

* Update torch_geometric/data/temporal.py

Co-authored-by: Matthias Fey <[email protected]>

* Update torch_geometric/data/temporal.py

Co-authored-by: Matthias Fey <[email protected]>

* Changes requested in review

* Removing trailing whitespace

* fix doc + some inheritance issues

* fix iter

* Add the new TemporalDataset class and refactor accordingly:
- refactor TemporalData to work with the default implementation of DataLoader
- refactor JODIEDataset to inherit from TemporalDataset
- refactor the TGN example to work with DataLoader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Lint fixes

* Lint fixes

* Fix tests

* Add TemporalDataLoader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes suggested in code review.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Lint fix

* Update docs

* Update torch_geometric/data/temporal.py

Co-authored-by: Matthias Fey <[email protected]>

* Update torch_geometric/loader/temporal_dataloader.py

Co-authored-by: Matthias Fey <[email protected]>

* Add __init__ for TemporalDataLoader and update docs.

* update example

* update dataloader

* update data

* update data (part 2)

* update data (part 3)

* bugfix

* temporal dataloader test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <[email protected]>
3 people authored Mar 5, 2022
1 parent e18a897 commit aa99b50
Showing 9 changed files with 146 additions and 66 deletions.
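
In short: the removed TemporalData.seq_batches() method is superseded by a standalone TemporalDataLoader. A minimal sketch of the resulting end-to-end workflow, pieced together from the diffs below (tensor contents are illustrative only, not part of this commit):

import torch

from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader

# A TemporalData holds a chronologically ordered stream of events.
events = TemporalData(
    src=torch.arange(100),     # source node of each event
    dst=torch.arange(100),     # destination node of each event
    t=torch.arange(100),       # timestamp of each event
    msg=torch.randn(100, 16),  # per-event message features
)

# Split by time, then iterate over contiguous mini-batches:
train_data, val_data, test_data = events.train_val_test_split(
    val_ratio=0.15, test_ratio=0.15)
train_loader = TemporalDataLoader(train_data, batch_size=20)

for batch in train_loader:  # each batch is itself a TemporalData
    src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
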
34 changes: 23 additions & 11 deletions examples/tgn.py
@@ -18,6 +18,7 @@
 from torch.nn import Linear
 
 from torch_geometric.datasets import JODIEDataset
+from torch_geometric.loader import TemporalDataLoader
 from torch_geometric.nn import TGNMemory, TransformerConv
 from torch_geometric.nn.models.tgn import (IdentityMessage, LastAggregator,
                                            LastNeighborLoader)
@@ -26,14 +27,21 @@
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE')
 dataset = JODIEDataset(path, name='wikipedia')
-data = dataset[0].to(device)
+data = dataset[0]
+
+# For small datasets, we can put the whole dataset on GPU and thus avoid
+# expensive memory transfer costs for mini-batches:
+data = data.to(device)
 
 # Ensure to only sample actual destination nodes as negatives.
 min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())
 
 train_data, val_data, test_data = data.train_val_test_split(
     val_ratio=0.15, test_ratio=0.15)
 
+train_loader = TemporalDataLoader(train_data, batch_size=200)
+val_loader = TemporalDataLoader(val_data, batch_size=200)
+test_loader = TemporalDataLoader(test_data, batch_size=200)
+
 neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)


@@ -103,7 +111,8 @@ def train():
     neighbor_loader.reset_state()  # Start with an empty graph.
 
     total_loss = 0
-    for batch in train_data.seq_batches(batch_size=200):
+    for batch in train_loader:
+        batch = batch.to(device)
         optimizer.zero_grad()
 
         src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
@@ -118,7 +127,8 @@
 
         # Get updated memory of all nodes involved in the computation.
         z, last_update = memory(n_id)
-        z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])
+        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
+                data.msg[e_id].to(device))
 
         pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
         neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])
@@ -139,15 +149,16 @@
 
 
 @torch.no_grad()
-def test(inference_data):
+def test(loader):
     memory.eval()
     gnn.eval()
     link_pred.eval()
 
     torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.
 
     aps, aucs = [], []
-    for batch in inference_data.seq_batches(batch_size=200):
+    for batch in loader:
+        batch = batch.to(device)
         src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
 
         neg_dst = torch.randint(min_dst_idx, max_dst_idx + 1, (src.size(0), ),
@@ -158,7 +169,8 @@ def test(inference_data):
         assoc[n_id] = torch.arange(n_id.size(0), device=device)
 
         z, last_update = memory(n_id)
-        z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])
+        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
+                data.msg[e_id].to(device))
 
         pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
         neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])
@@ -179,8 +191,8 @@ def test(inference_data):
 
 for epoch in range(1, 51):
     loss = train()
-    print(f' Epoch: {epoch:02d}, Loss: {loss:.4f}')
-    val_ap, val_auc = test(val_data)
-    test_ap, test_auc = test(test_data)
-    print(f' Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')
+    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
+    val_ap, val_auc = test(val_loader)
+    test_ap, test_auc = test(test_loader)
+    print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')
     print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')

22 changes: 22 additions & 0 deletions test/loader/test_temporal_dataloader.py
@@ -0,0 +1,22 @@
+import torch
+
+from torch_geometric.data import TemporalData
+from torch_geometric.loader import TemporalDataLoader
+
+
+def test_temporal_dataloader():
+    src = dst = t = torch.arange(10)
+    msg = torch.randn(10, 16)
+
+    data = TemporalData(src=src, dst=dst, t=t, msg=msg)
+
+    loader = TemporalDataLoader(data, batch_size=2)
+    assert len(loader) == 5
+
+    for i, batch in enumerate(loader):
+        assert len(batch) == 2
+        arange = range(len(batch) * i, len(batch) * i + len(batch))
+        assert batch.src.tolist() == data.src[arange].tolist()
+        assert batch.dst.tolist() == data.dst[arange].tolist()
+        assert batch.t.tolist() == data.t[arange].tolist()
+        assert batch.msg.tolist() == data.msg[arange].tolist()

11 changes: 5 additions & 6 deletions torch_geometric/data/batch.py
@@ -6,8 +6,8 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.data import Data, HeteroData
 from torch_geometric.data.collate import collate
+from torch_geometric.data.data import BaseData, Data
 from torch_geometric.data.dataset import IndexType
 from torch_geometric.data.separate import separate
 
@@ -54,7 +54,7 @@ class Batch(metaclass=DynamicInheritance):
     :obj:`batch`, which maps each node to its respective graph identifier.
     """
     @classmethod
-    def from_data_list(cls, data_list: Union[List[Data], List[HeteroData]],
+    def from_data_list(cls, data_list: List[BaseData],
                        follow_batch: Optional[List[str]] = None,
                        exclude_keys: Optional[List[str]] = None):
         r"""Constructs a :class:`~torch_geometric.data.Batch` object from a
@@ -80,7 +80,7 @@ def from_data_list(cls, data_list: Union[List[Data], List[HeteroData]],
 
         return batch
 
-    def get_example(self, idx: int) -> Union[Data, HeteroData]:
+    def get_example(self, idx: int) -> BaseData:
         r"""Gets the :class:`~torch_geometric.data.Data` or
         :class:`~torch_geometric.data.HeteroData` object at index :obj:`idx`.
         The :class:`~torch_geometric.data.Batch` object must have been created
@@ -103,8 +103,7 @@ def get_example(self, idx: int) -> Union[Data, HeteroData]:
 
         return data
 
-    def index_select(self,
-                     idx: IndexType) -> Union[List[Data], List[HeteroData]]:
+    def index_select(self, idx: IndexType) -> List[BaseData]:
         r"""Creates a subset of :class:`~torch_geometric.data.Data` or
         :class:`~torch_geometric.data.HeteroData` objects from specified
         indices :obj:`idx`.
@@ -152,7 +151,7 @@ def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
         else:
             return self.index_select(idx)
 
-    def to_data_list(self) -> Union[List[Data], List[HeteroData]]:
+    def to_data_list(self) -> List[BaseData]:
         r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or
         :class:`~torch_geometric.data.HeteroData` objects from the
         :class:`~torch_geometric.data.Batch` object.

88 changes: 47 additions & 41 deletions torch_geometric/data/temporal.py
@@ -93,49 +93,26 @@ def __init__(
         for key, value in kwargs.items():
             setattr(self, key, value)
 
-    @staticmethod
-    def __prepare_non_str_idx(idx):
-        if isinstance(idx, int):
-            idx = torch.tensor([idx])
-        if isinstance(idx, (list, tuple)):
-            idx = torch.tensor(idx)
-        elif isinstance(idx, slice):
-            pass
-        elif isinstance(idx, torch.Tensor) and (idx.dtype == torch.long
-                                                or idx.dtype == torch.bool):
-            pass
-        else:
-            raise IndexError(
-                f'Only strings, integers, slices (`:`), list, tuples, and '
-                f'long or bool tensors are valid indices (got '
-                f'{type(idx).__name__}).')
-        return idx
+    def index_select(self, idx: Any) -> 'TemporalData':
+        idx = prepare_idx(idx)
+        data = copy.copy(self)
+        for key, value in data._store.items():
+            if value.size(0) == self.num_events:
+                data[key] = value[idx]
+        return data
 
     def __getitem__(self, idx: Any) -> Any:
         if isinstance(idx, str):
             return self._store[idx]
-
-        prepared_idx = self.__prepare_non_str_idx(idx)
-
-        data = copy.copy(self)
-        for key, item in data:
-            if item.size(0) == self.num_events:
-                data[key] = item[prepared_idx]
-        return data
+        return self.index_select(idx)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any):
         """Sets the attribute :obj:`key` to :obj:`value`."""
         self._store[key] = value
 
-    def __delitem__(self, idx):
-        if isinstance(idx, str) and idx in self._store:
-            del self._store[idx]
-
-        prepared_idx = self.__prepare_non_str_idx(idx)
-
-        for key, item in self:
-            if item.shape[0] == self.num_events:
-                del item[prepared_idx]
+    def __delitem__(self, key: str):
+        if key in self._store:
+            del self._store[key]
 
     def __getattr__(self, key: str) -> Any:
         if '_store' not in self.__dict__:
@@ -153,8 +130,11 @@ def __delattr__(self, key: str):
         delattr(self._store, key)
 
     def __iter__(self) -> Iterable:
-        for key, value in self._store.items():
-            yield key, value
+        for i in range(self.num_events):
+            yield self[i]
+
+    def __len__(self) -> int:
+        return self.num_events
 
     def __call__(self, *args: List[str]) -> Iterable:
         for key, value in self._store.items(*args):
@@ -247,6 +227,16 @@ def __repr__(self) -> str:
 
     def train_val_test_split(self, val_ratio: float = 0.15,
                              test_ratio: float = 0.15):
+        r"""Splits the data in training, validation and test sets according to
+        time.
+
+        Args:
+            val_ratio (float, optional): The proportion (in percents) of the
+                dataset to include in the validation split.
+                (default: :obj:`0.15`)
+            test_ratio (float, optional): The proportion (in percents) of the
+                dataset to include in the test split. (default: :obj:`0.15`)
+        """
         val_time, test_time = np.quantile(
             self.t.cpu().numpy(),
             [1. - val_ratio - test_ratio, 1. - test_ratio])
@@ -256,10 +246,6 @@ def train_val_test_split(self, val_ratio: float = 0.15,
 
         return self[:val_idx], self[val_idx:test_idx], self[test_idx:]
 
-    def seq_batches(self, batch_size: int):
-        for start in range(0, self.num_events, batch_size):
-            yield self[start:start + batch_size]
-
     ###########################################################################
 
     def coalesce(self):
@@ -276,3 +262,23 @@ def is_undirected(self) -> bool:
 
     def is_directed(self) -> bool:
         raise NotImplementedError
+
+
+###############################################################################
+
+
+def prepare_idx(idx):
+    if isinstance(idx, int):
+        return slice(idx, idx + 1)
+    if isinstance(idx, (list, tuple)):
+        return torch.tensor(idx)
+    elif isinstance(idx, slice):
+        return idx
+    elif isinstance(idx, torch.Tensor) and idx.dtype == torch.long:
+        return idx
+    elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool:
+        return idx
+
+    raise IndexError(
+        f"Only strings, integers, slices (`:`), list, tuples, and long or "
+        f"bool tensors are valid indices (got '{type(idx).__name__}')")
2 changes: 2 additions & 0 deletions torch_geometric/loader/__init__.py
@@ -9,6 +9,7 @@
 from .data_list_loader import DataListLoader
 from .dense_data_loader import DenseDataLoader
 from .neighbor_sampler import NeighborSampler
+from .temporal_dataloader import TemporalDataLoader
 
 __all__ = [
     'DataLoader',
@@ -25,6 +26,7 @@
     'DataListLoader',
     'DenseDataLoader',
     'NeighborSampler',
+    'TemporalDataLoader',
 ]
 
 classes = __all__

5 changes: 3 additions & 2 deletions torch_geometric/loader/data_list_loader.py
@@ -2,7 +2,8 @@
 
 import torch
 
-from torch_geometric.data import Data, Dataset, HeteroData
+from torch_geometric.data import Dataset
+from torch_geometric.data.data import BaseData
 
 
 def collate_fn(data_list):
@@ -30,7 +31,8 @@ class DataListLoader(torch.utils.data.DataLoader):
     :class:`torch.utils.data.DataLoader`, such as :obj:`drop_last` or
     :obj:`num_workers`.
     """
-    def __init__(self, dataset: Union[Dataset, List[Data], List[HeteroData]],
+    def __init__(self, dataset: Union[Dataset, List[BaseData]],
                  batch_size: int = 1, shuffle: bool = False, **kwargs):
         if 'collate_fn' in kwargs:
             del kwargs['collate_fn']

7 changes: 4 additions & 3 deletions torch_geometric/loader/dataloader.py
@@ -4,7 +4,8 @@
 
 import torch.utils.data
 from torch.utils.data.dataloader import default_collate
 
-from torch_geometric.data import Batch, Data, Dataset, HeteroData
+from torch_geometric.data import Batch, Dataset
+from torch_geometric.data.data import BaseData
 
 
 class Collater:
@@ -14,7 +15,7 @@ def __init__(self, follow_batch, exclude_keys):
 
     def __call__(self, batch):
         elem = batch[0]
-        if isinstance(elem, (Data, HeteroData)):
+        if isinstance(elem, BaseData):
             return Batch.from_data_list(batch, self.follow_batch,
                                         self.exclude_keys)
         elif isinstance(elem, torch.Tensor):
@@ -59,7 +60,7 @@ class DataLoader(torch.utils.data.DataLoader):
     """
    def __init__(
        self,
-        dataset: Union[Dataset, List[Data], List[HeteroData]],
+        dataset: Union[Dataset, List[BaseData]],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
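
The practical effect of widening these annotations from Data/HeteroData to their common base class BaseData is that the Collater dispatch now covers every BaseData subclass; standard usage is unchanged (a quick check, not part of this commit):

import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

data_list = [
    Data(x=torch.randn(4, 8), edge_index=torch.tensor([[0, 1], [1, 0]]))
    for _ in range(6)
]
loader = DataLoader(data_list, batch_size=3)
batch = next(iter(loader))  # hits the isinstance(elem, BaseData) branch
assert batch.num_graphs == 3
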
37 changes: 37 additions & 0 deletions torch_geometric/loader/temporal_dataloader.py
@@ -0,0 +1,37 @@
+from typing import List
+
+import torch
+
+from torch_geometric.data import TemporalData
+
+
+class TemporalDataLoader(torch.utils.data.DataLoader):
+    r"""A data loader which merges successive events of a
+    :class:`torch_geometric.data.TemporalData` into a mini-batch.
+
+    Args:
+        data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`
+            from which to load the data.
+        batch_size (int, optional): How many samples per batch to load.
+            (default: :obj:`1`)
+        **kwargs (optional): Additional arguments of
+            :class:`torch.utils.data.DataLoader`.
+    """
+    def __init__(self, data: TemporalData, batch_size: int = 1, **kwargs):
+        if 'collate_fn' in kwargs:
+            del kwargs['collate_fn']
+        if 'shuffle' in kwargs:
+            del kwargs['shuffle']
+
+        self.data = data
+        self.events_per_batch = batch_size
+
+        if kwargs.get('drop_last', False) and len(data) % batch_size != 0:
+            arange = range(0, len(data) - batch_size, batch_size)
+        else:
+            arange = range(0, len(data), batch_size)
+
+        super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)
+
+    def __call__(self, arange: List[int]) -> TemporalData:
+        return self.data[arange[0]:arange[0] + self.events_per_batch]

Inline comment on the `drop_last` line above:

@Everyday-seu (Apr 19, 2022): Is `len(data)` correct here? In the tgn example,

    train_data, val_data, test_data = data.train_val_test_split(
        val_ratio=0.15, test_ratio=0.15)
    train_loader = TemporalDataLoader(train_data, batch_size=200)

len(train_data) = 5, so len(data) = 5 and `arange` becomes range(0, 5, batch_size). Would it be correct to use something like data.num_events instead? Hope for your reply, thanks!

@rusty1s (author, Apr 19, 2022): See #4499.

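
A short usage sketch of the loader above, including the drop_last arithmetic discussed in the comment (numbers illustrative):

import torch

from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader

data = TemporalData(src=torch.arange(10), dst=torch.arange(10),
                    t=torch.arange(10), msg=torch.randn(10, 16))

# Each "sample" the parent DataLoader yields is a batch start index;
# __call__ then slices out the next `events_per_batch` events:
loader = TemporalDataLoader(data, batch_size=4)
assert len(loader) == 3  # events [0:4], [4:8], [8:10]

loader = TemporalDataLoader(data, batch_size=4, drop_last=True)
assert len(loader) == 2  # the trailing partial batch [8:10] is dropped
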