-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathunion.go
143 lines (122 loc) · 2.85 KB
/
union.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package ep
import (
"context"
"fmt"
)
// Register the union type at package initialization. registerGob is defined
// elsewhere in the package; presumably it registers the value's concrete type
// for gob encoding so unions can be serialized — TODO confirm. An empty union
// is used here purely as a sample value of the type.
var _ = registerGob(union([]Runner{}))
// Union returns a new composite Runner that dispatches its inputs to all of
// its internal runners and collects their output into a single unified stream
// of datasets.
//
// All of the individual runners are required to return the same data types;
// an error is returned if they don't, or if no runners are given. A single
// runner is returned as-is, without being wrapped.
func Union(runners ...Runner) (Runner, error) {
	switch len(runners) {
	case 0:
		return nil, fmt.Errorf("at least 1 runner is required for union")
	case 1:
		return runners[0], nil
	}

	// Copy the runners so that later changes to the caller's slice (e.g. when
	// invoked as Union(rs...)) don't silently alter the returned union.
	u := append(union(nil), runners...)
	if _, err := u.returnsErr(); err != nil {
		return nil, err
	}
	return u, nil
}
// union is the composite Runner that Union returns when given two or more
// runners: an ordered list of inner runners that all receive the same input
// and whose outputs are merged into a single stream by Run.
type union []Runner
// Equals reports whether other is also a union of the same length whose
// runners are pairwise Equal to this union's runners, position by position.
func (rs union) Equals(other interface{}) bool {
	them, isUnion := other.(union)
	if !isUnion {
		return false
	}
	if len(them) != len(rs) {
		return false
	}
	for idx := 0; idx < len(rs); idx++ {
		if !rs[idx].Equals(them[idx]) {
			return false
		}
	}
	return true
}
// Returns implements Runner: it reports the column types shared by all of the
// inner runners. Union rejects runners whose types disagree, so an error from
// returnsErr here indicates the union was built without going through Union;
// in that case Returns panics.
func (rs union) Returns() []Type {
	if types, err := rs.returnsErr(); err == nil {
		return types
	}
	panic("Union() should've prevented this error from panicking")
}
// returnsErr determines the union's return types, which are the types of the
// first runner. It verifies that every other runner returns the same number
// of columns with equal types (per AreEqualTypes).
//
// It returns an error if the union is empty or if any runner disagrees.
func (rs union) returnsErr() ([]Type, error) {
	if len(rs) == 0 {
		// Without this guard rs[0] would panic with index-out-of-range, e.g.
		// for a union value constructed directly rather than through Union.
		return nil, fmt.Errorf("at least 1 runner is required for union")
	}

	types := rs[0].Returns()
	// ensure that all return types are compatible; the first runner is the
	// reference, so only the remaining runners need to be compared to it
	for _, r := range rs[1:] {
		have := r.Returns()
		if len(have) != len(types) {
			return nil, fmt.Errorf("mismatch number of columns: %v and %v", types, have)
		}
		if !AreEqualTypes(types, have) {
			return nil, fmt.Errorf("type mismatch %s and %s", types, have)
		}
	}
	return types, nil
}
// Run implements Runner. Every dataset received from inp is forwarded to each
// of the inner runners, and the outputs of all inner runners are merged into
// out. Inner outputs are drained concurrently, so datasets produced by
// different inner runners may be interleaved in any order.
//
// Run returns once every inner output stream has been exhausted and forwarded.
// The returned error is the first non-nil error in runner order, or nil.
func (rs union) Run(ctx context.Context, inp, out chan Dataset) error {
	inputs := make([]chan Dataset, len(rs))
	outputs := make([]chan Dataset, len(rs))
	errs := make([]error, len(rs))

	// start all inner runners via the package-level Run helper, which is
	// expected to close outputs[i] when rs[i] finishes and to store its error
	// in errs[i]
	for i := range rs {
		inputs[i] = make(chan Dataset)
		outputs[i] = make(chan Dataset)
		go Run(ctx, rs[i], inputs[i], outputs[i], nil, &errs[i])
	}

	// fork the input to all inner runners, then close their inputs once the
	// upstream input is exhausted
	go func() {
		for data := range inp {
			for _, s := range inputs {
				s <- data
			}
		}
		for _, s := range inputs {
			close(s)
		}
	}()

	// Drain every inner output concurrently. Draining them one after another
	// deadlocks as soon as there is more than one input dataset (e.g. with
	// runners that emit one output per input):
	//   - a later runner blocks sending to its unread output channel, so it
	//     stops reading its input;
	//   - that blocks the fork goroutine, so the inputs are never closed;
	//   - so the earlier runner's output never closes.
	done := make(chan struct{}, len(outputs))
	for _, s := range outputs {
		go func(s chan Dataset) {
			for data := range s {
				out <- data
			}
			done <- struct{}{}
		}(s)
	}
	// wait for every forwarder so nothing is written to out after we return
	for range outputs {
		<-done
	}

	// NOTE(review): reading errs here relies on the package-level Run storing
	// the error before it closes the output channel — confirm in runner code.
	for _, err := range errs {
		if err != nil {
			return err
		}
	}
	return nil
}
// Scopes merges, via AddAll, the scopes of every inner runner that implements
// ScopesRunner into a fresh StringsSet. Runners that don't implement it
// contribute nothing.
func (rs union) Scopes() StringsSet {
	all := make(StringsSet)
	for i := range rs {
		scoped, hasScopes := rs[i].(ScopesRunner)
		if !hasScopes {
			continue
		}
		all.AddAll(scoped.Scopes())
	}
	return all
}
// ApproxSize returns the sum of the approximate sizes of all inner runners.
// The result is UnknownSize as soon as any runner either doesn't implement
// ApproxSizer or itself reports UnknownSize. An empty union has size 0.
func (rs union) ApproxSize() int {
	sizeOf := func(r Runner) int {
		if sizer, ok := r.(ApproxSizer); ok {
			return sizer.ApproxSize()
		}
		return UnknownSize
	}

	sum := 0
	for _, r := range rs {
		n := sizeOf(r)
		if n == UnknownSize {
			return UnknownSize
		}
		sum += n
	}
	return sum
}