-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Changes in Temporal Data to support a new Temporal Data Loader (#3985)
* Refactor TemporalData class to inherit from BaseData * Fixes to get TemporalData working * Small fixes in __delitem__ of TemporalData * Add batch, __cat_dim__ and __inc__ to TemporalData * Add Docs to TemporalData * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Lint fixes * Add removed method TemporalData.seq_batches * Update torch_geometric/data/temporal.py Co-authored-by: Matthias Fey <[email protected]> * Update torch_geometric/data/temporal.py Co-authored-by: Matthias Fey <[email protected]> * Changes requested in review * Removing trailing whitespace * fix doc + some inheritance issues * fix iter * Add the new TemporalDataset class and refactor and - refactor TemporalData to work with the default implementation of DataLoader - refactor Jodie Dataset to inherit from TemporalDataset - refactor the TGN example to work with DataLoader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Lint fixes * Lint fixes * Fix tests * Add TemporalDataLoader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes suggested in code review. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Lint fix * Update docs * Update torch_geometric/data/temporal.py Co-authored-by: Matthias Fey <[email protected]> * Update torch_geometric/loader/temporal_dataloader.py Co-authored-by: Matthias Fey <[email protected]> * Add __init_ for TemporalDataLoader and update docs. * update example * update dataloader * update data * update data (part 2) * update data (part 3) * bugfix * temporal dataloader test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matthias Fey <[email protected]>
- Loading branch information
1 parent
e18a897
commit aa99b50
Showing
9 changed files
with
146 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
|
||
from torch_geometric.data import TemporalData | ||
from torch_geometric.loader import TemporalDataLoader | ||
|
||
|
||
def test_temporal_dataloader(): | ||
src = dst = t = torch.arange(10) | ||
msg = torch.randn(10, 16) | ||
|
||
data = TemporalData(src=src, dst=dst, t=t, msg=msg) | ||
|
||
loader = TemporalDataLoader(data, batch_size=2) | ||
assert len(loader) == 5 | ||
|
||
for i, batch in enumerate(loader): | ||
assert len(batch) == 2 | ||
arange = range(len(batch) * i, len(batch) * i + len(batch)) | ||
assert batch.src.tolist() == data.src[arange].tolist() | ||
assert batch.dst.tolist() == data.dst[arange].tolist() | ||
assert batch.t.tolist() == data.t[arange].tolist() | ||
assert batch.msg.tolist() == data.msg[arange].tolist() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import List | ||
|
||
import torch | ||
|
||
from torch_geometric.data import TemporalData | ||
|
||
|
||
class TemporalDataLoader(torch.utils.data.DataLoader): | ||
r"""A data loader which merges succesive events of a | ||
:class:`torch_geometric.data.TemporalData` to a mini-batch. | ||
Args: | ||
data (TemporalData): The :obj:`~torch_geometric.data.TemporalData` | ||
from which to load the data. | ||
batch_size (int, optional): How many samples per batch to load. | ||
(default: :obj:`1`) | ||
**kwargs (optional): Additional arguments of | ||
:class:`torch.utils.data.DataLoader`. | ||
""" | ||
def __init__(self, data: TemporalData, batch_size: int = 1, **kwargs): | ||
if 'collate_fn' in kwargs: | ||
del kwargs['collate_fn'] | ||
if 'shuffle' in kwargs: | ||
del kwargs['shuffle'] | ||
|
||
self.data = data | ||
self.events_per_batch = batch_size | ||
|
||
if kwargs.get('drop_last', False) and len(data) % batch_size != 0: | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong. |
||
arange = range(0, len(data) - batch_size, batch_size) | ||
else: | ||
arange = range(0, len(data), batch_size) | ||
|
||
super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs) | ||
|
||
def __call__(self, arange: List[int]) -> TemporalData: | ||
return self.data[arange[0]:arange[0] + self.events_per_batch] |
Oops, something went wrong.
Is len(data) correct? In the example you gave in tgn:
train_data, val_data, test_data = data.train_val_test_split(val_ratio=0.15, test_ratio=0.15)
train_loader = TemporalDataLoader(train_data, batch_size=200)
But len(train_data)=5,so len(data)=5 and it turns to the "arange" becomes(0,5,batch_size). Is it correct to use something like data.num_events?
Hope for your reply! Thanks!