forked from Lin-Yijie/Graph-Matching-Networks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsegment.py
59 lines (42 loc) · 2.33 KB
/
segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
def segment_sum(data, segment_ids):
    """
    Sum the rows of ``data`` that share the same (sorted) segment id.

    Analogous to tf.segment_sum (https://www.tensorflow.org/api_docs/python/tf/math/segment_sum).

    :param data: A pytorch tensor of the data for segmented summation.
    :param segment_ids: A 1-D tensor of non-decreasing segment indices, one
        entry per slice of ``data`` along dimension 0.
    :return: a tensor of the same type as data containing the results of the
        segmented summation. Dimension 0 has ``segment_ids[-1] + 1`` entries
        (0 for empty input); segment ids that never occur yield zero rows.
    :raises AssertionError: if ``segment_ids`` is not 1-D, does not match
        dimension 0 of ``data``, or is not sorted.
    """
    # Validate rank first so the slicing-based checks below are well defined.
    if len(segment_ids.shape) != 1:
        raise AssertionError("segment_ids have be a 1-D tensor")
    if data.shape[0] != segment_ids.shape[0]:
        raise AssertionError("segment_ids should be the same size as dimension 0 of input.")

    # Nothing to sum: there are no segments at all.
    if segment_ids.shape[0] == 0:
        return data.new_zeros((0,) + tuple(data.shape[1:]))

    # Vectorized sortedness check: a single tensor comparison instead of one
    # Python-level comparison per element.
    if bool((segment_ids[1:] < segment_ids[:-1]).any()):
        raise AssertionError("elements of segment_ids must be sorted")

    # The ids are sorted, so the last one is the largest. Counting unique ids
    # would under-size the output whenever ids have gaps or do not start at 0.
    num_segments = int(segment_ids[-1]) + 1
    return unsorted_segment_sum(data, segment_ids, num_segments)
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.

    The result is allocated on the same device and with the same dtype as
    ``data``, so the function works on CPU as well as CUDA tensors.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The segment indices tensor. Either 1-D with one id per
        slice of ``data`` along dimension 0, or the same shape as ``data``
        (in which case it is used directly as the ``scatter_add`` index).
    :param num_segments: The number of segments, i.e. the size of dimension 0
        of the result.
    :return: A tensor of same data type as the data argument, with shape
        ``[num_segments] + data.shape[1:]``.
    :raises AssertionError: if ``segment_ids.shape`` is not a prefix of
        ``data.shape``, or if it is multi-dimensional but not equal to
        ``data.shape``.
    """
    ids_shape = tuple(segment_ids.shape)
    # A real prefix comparison (not a per-dimension membership test).
    if tuple(data.shape[:len(ids_shape)]) != ids_shape:
        raise AssertionError("segment_ids.shape should be a prefix of data.shape")

    # NOTE: when a deterministic result is needed (reproducibility), prefer
    # summing each segment explicitly, e.g.
    #     torch.sum(data[segment_ids == index], dim=0)
    # for every index in range(num_segments). That is slower than the
    # scatter_add used below.
    if len(ids_shape) == 1:
        # Broadcast the per-row id across the trailing dimensions as a view,
        # instead of materializing a repeated copy of the ids.
        trailing_ones = (1,) * (data.dim() - 1)
        segment_ids = segment_ids.reshape(ids_shape + trailing_ones).expand_as(data)

    if data.shape != segment_ids.shape:
        raise AssertionError("data.shape and segment_ids.shape should be equal")

    out_shape = [num_segments] + list(data.shape[1:])
    # Allocate the accumulator with data's dtype and device: scatter_add
    # requires both to match the source tensor.
    accumulator = torch.zeros(out_shape, dtype=data.dtype, device=data.device)
    return accumulator.scatter_add(0, segment_ids, data)