-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdmc.go
252 lines (221 loc) · 5.1 KB
/
dmc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
package main
import (
"bufio"
"bytes"
"flag"
"fmt"
"io"
"net"
"os"
"os/exec"
"strings"
"sync"
"sync/atomic"
"golang.org/x/crypto/ssh/terminal"
)
// dmc runs a command simultaneously on every host listed via -hosts, resolved from -d, or piped in on stdin.
// Foreground color codes from the "bright" ANSI SGR range (90–97), passed to
// color()/hostStr() for host labels. Listing them from 90 upward keeps every
// value a valid code; white previously came out as 89, which is not one.
const (
	black  = iota + 90 // bright black (gray)
	red                // 91
	green              // 92
	yellow             // 93
	blue               // 94
	purple             // 95 (magenta)
	cyan               // 96
	white              // 97
)
// tty records, once at program start, whether standard output is a terminal.
// color() consults it so ANSI escape codes are only emitted when stdout is a
// terminal.
var (
	tty = terminal.IsTerminal(int(os.Stdout.Fd()))
)
// last counts how many colors cycle has handed out so far.
var last int64

// cycle returns the foreground color for the next interleaved (-i) host
// label. Colors rotate through a fixed palette of bright ANSI codes, so every
// host gets a real color however many hosts there are. The counter advances
// with a single atomic add, which makes concurrent callers each claim a
// distinct slot; a separate load followed by an add could hand two callers
// the same value.
func cycle() int {
	palette := [...]int{black, red, green, yellow, blue, purple}
	n := atomic.AddInt64(&last, 1) - 1
	return palette[n%int64(len(palette))]
}
// color wraps s in an ANSI SGR escape sequence selecting foreground code
// color, adding the bold attribute when bold is set. When stdout is not a
// terminal (tty is false) s is returned unchanged.
func color(s string, color int, bold bool) string {
	if tty {
		attr := ""
		if bold {
			attr = "01;"
		}
		return fmt.Sprintf("\033[%s%dm%s\033[0m", attr, color, s)
	}
	return s
}
// cfg holds the command-line options; init() registers the flags and fills
// it in.
var cfg struct {
	// String options.
	prefix string // -p: text printed before each command echo
	hosts  string // -hosts: comma-separated list of hosts
	dns    string // -d: DNS name whose addresses are used as the hosts
	ssh    string // -ssh: remote shell command line

	// threads is -n: how many hosts are processed in parallel.
	threads int

	// Boolean switches.
	verbose    bool // -v: verbose output
	interleave bool // -i: stream output lines as they arrive
	compress   bool // -C: pass -C to the remote shell
	quiet      bool // -q: omit host prefixes from output
}
// init registers dmc's command-line flags, parses them into cfg and then
// applies the DMC_SSH environment variable.
//
// NOTE(review): parsing in init() means this package's test binary rejects
// go test's own -test.* flags; moving flag.Parse into main would avoid that.
func init() {
	flag.BoolVar(&cfg.verbose, "v", false, "verbose output")
	flag.BoolVar(&cfg.quiet, "q", false, "do not add host prefixes to command output")
	flag.BoolVar(&cfg.compress, "C", false, "enable transparent ssh compression")
	flag.StringVar(&cfg.prefix, "p", "", "prefix for command echo")
	flag.StringVar(&cfg.hosts, "hosts", "", "list of hosts")
	flag.StringVar(&cfg.dns, "d", "", "dns name for multi-hosts")
	flag.IntVar(&cfg.threads, "n", 512, "threads to run in parallel")
	flag.BoolVar(&cfg.interleave, "i", false, "interleave output as it is available")
	// flag already appends `(default "ssh")` to this text in -h output.
	flag.StringVar(&cfg.ssh, "ssh", "ssh", "remote shell command (overrides $DMC_SSH)")
	flag.Parse()

	// $DMC_SSH replaces the default remote shell only when -ssh was not given
	// on the command line, so an explicit flag always wins. flag.Visit walks
	// only the flags that were actually set.
	sshSet := false
	flag.Visit(func(f *flag.Flag) {
		if f.Name == "ssh" {
			sshSet = true
		}
	})
	if env := os.Getenv("DMC_SSH"); env != "" && !sshSet {
		cfg.ssh = env
	}
}
// vprintf is fmt.Printf gated on the -v flag: it writes to stdout only when
// verbose output was requested and does nothing otherwise.
func vprintf(format string, args ...interface{}) {
	if !cfg.verbose {
		return
	}
	fmt.Printf(format, args...)
}
// hostStr renders host as a bracketed label, e.g. "[web1]", with the name
// passed through color() using color code c and the bold flag. In quiet mode
// (-q) it returns "" so output carries no host prefix at all.
func hostStr(host string, c int, bold bool) string {
	if !cfg.quiet {
		return fmt.Sprintf("[%s]", color(host, c, bold))
	}
	return ""
}
// getHosts returns the hosts to run on, taken from the first available
// source:
//   - the comma-separated -hosts flag,
//   - the addresses the -d name resolves to,
//   - one host per line on stdin.
//
// For -hosts and stdin, surrounding whitespace is trimmed and blank entries
// are skipped, so a trailing comma or blank line does not become an empty
// host that can only fail.
//
// A failed DNS lookup exits the program. When stdin is a terminal (nothing
// piped in) a usage hint is printed and no hosts are returned.
func getHosts() []string {
	var hosts []string
	add := func(h string) {
		if h = strings.TrimSpace(h); h != "" {
			hosts = append(hosts, h)
		}
	}

	if len(cfg.hosts) > 0 {
		for _, h := range strings.Split(cfg.hosts, ",") {
			add(h)
		}
		return hosts
	}

	if len(cfg.dns) > 0 {
		addrs, err := net.LookupHost(cfg.dns)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error looking up %s: %s\n", cfg.dns, err)
			os.Exit(-1)
		}
		return addrs
	}

	fi, err := os.Stdin.Stat()
	if err != nil {
		// fi is nil on error; fi.Mode() below would panic.
		fmt.Fprintf(os.Stderr, "Error reading from stdin: %s\n", err)
		return hosts
	}
	if fi.Mode()&os.ModeCharDevice != 0 {
		fmt.Println("usage: you must pipe a list of hosts into dmc or use -hosts.")
		return hosts
	}

	s := bufio.NewScanner(os.Stdin)
	for s.Scan() {
		add(s.Text())
	}
	if err := s.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "Error reading from stdin: %s\n", err)
	}
	return hosts
}
// ssh builds, without starting, the command that runs cmd on host.
//
// The argument list is assembled in this order:
//   - cfg.ssh split on whitespace into the program and its leading
//     arguments (e.g. "ssh -i key"),
//   - "-C" when compression is enabled,
//   - the host.
//
// For the plain "ssh" command, cmd is then appended as one argument. For any
// other remote shell it is split on single spaces into separate arguments.
func ssh(host, cmd string) *exec.Cmd {
	// Fields rather than Split, so repeated, leading or trailing blanks in
	// -ssh / $DMC_SSH don't become empty arguments.
	cc := strings.Fields(cfg.ssh)
	if len(cc) == 0 {
		// With no program given, the host name would otherwise be executed.
		cc = []string{"ssh"}
	}
	plain := len(cc) == 1 && cc[0] == "ssh"
	if cfg.compress {
		cc = append(cc, "-C")
	}
	cc = append(cc, host)
	if plain {
		cc = append(cc, cmd)
	} else {
		cc = append(cc, strings.Split(cmd, " ")...)
	}
	return exec.Command(cc[0], cc[1:]...)
}
// do runs cmd on host, waits for it to finish and returns a report. The
// report is a header line (cfg.prefix, the host label from hostStr, then
// "$ cmd") followed by the command's combined stdout and stderr. The label is
// green on success. On failure it is red, the header also shows the error,
// and the same error is returned alongside the report.
func do(host, cmd string) ([]byte, error) {
	proc := ssh(host, cmd)
	vprintf("cmd: %+v\n", proc.Args)
	combined, runErr := proc.CombinedOutput()

	var report bytes.Buffer
	if runErr == nil {
		fmt.Fprintf(&report, "%s%s$ %s\n%s", cfg.prefix, hostStr(host, green, true), cmd, string(combined))
		return report.Bytes(), nil
	}
	fmt.Fprintf(&report, "%s%s$ %s: Error: %s\n", cfg.prefix, hostStr(host, red, true), cmd, runErr)
	report.Write(combined) // writes nothing when the command printed nothing
	return report.Bytes(), runErr
}
// LineWriter is the sink doi streams command output into, one line per call.
// Each string passed to WriteLine is a complete, newline-terminated line.
type LineWriter interface {
	// WriteLine emits line and reports any error from the underlying output.
	WriteLine(line string) error
}
// doi runs cmd on host and streams its combined stdout/stderr to out one line
// at a time, as lines arrive. Each line is prefixed with the host label,
// colored with the next color from cycle() so concurrently running hosts can
// be told apart.
//
// It returns the command's error if it failed. Otherwise it returns the first
// error reported by out.WriteLine, if any.
func doi(host, cmd string, out LineWriter) error {
	c := ssh(host, cmd)
	rdr, wrt := io.Pipe()
	c.Stdout = wrt
	c.Stderr = wrt

	// Run's result travels over a channel, so it is read only after Run has
	// returned rather than through a variable shared between goroutines.
	done := make(chan error, 1)
	go func() {
		err := c.Run()
		wrt.Close() // makes the read loop below see io.EOF
		done <- err
	}()

	prefix := hostStr(host, cycle(), false)
	var writeErr error
	// bufio.Reader imposes no maximum line length. bufio.Scanner stops on
	// lines over 64 KiB, after which nothing drains the pipe and the command
	// blocks forever writing into it.
	br := bufio.NewReader(rdr)
	for {
		line, rerr := br.ReadString('\n')
		if line != "" {
			// Drop the terminator (and a preceding CR), as bufio.ScanLines does.
			line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
			// Keep reading even if writing fails, so the command can finish.
			if werr := out.WriteLine(prefix + line + "\n"); werr != nil && writeErr == nil {
				writeErr = werr
			}
		}
		if rerr != nil {
			break // io.EOF once the write side has been closed
		}
	}

	if err := <-done; err != nil {
		return err
	}
	return writeErr
}
// main runs the command formed by the positional arguments on every host from
// getHosts, using at most -n concurrent workers. It exits with status 1 if
// any host's command failed.
//
// By default each host's output is collected by do and printed as one unit.
// With -i, doi streams lines as they arrive, prefixed with a per-host label.
func main() {
	args := flag.Args()
	if len(args) == 0 {
		fmt.Println("usage: dmc <command>")
		return
	}
	hosts := getHosts()
	cmd := strings.Join(args, " ")
	vprintf("Running `%s` on %d hosts\n", cmd, len(hosts))

	// At least one worker is needed whenever there are hosts. -n 0 would leave
	// nobody receiving from hostch, so the send loop below would deadlock. A
	// negative -n would panic in make and in wg.Add.
	par := cfg.threads
	if par < 1 {
		par = 1
	}
	if par > len(hosts) {
		par = len(hosts)
	}

	// output carries finished reports (non-interleaved mode); hostch feeds
	// hosts to the workers.
	output := make(chan string, par)
	hostch := make(chan string, par)
	var code int64

	// Interleaved mode shares ONE writer among all workers.
	// NOTE(review): NewSyncLineWriter is defined in another file; presumably
	// it serializes WriteLine calls. With a single shared instance, lines from
	// different workers cannot interleave whether its lock is per-instance or
	// global. One writer per worker would only help in the global case.
	var stdout LineWriter
	if cfg.interleave {
		stdout = NewSyncLineWriter(os.Stdout)
	}

	var wg, outwg sync.WaitGroup
	wg.Add(par)
	for i := 0; i < par; i++ {
		go func() {
			defer wg.Done()
			for host := range hostch {
				var err error
				if cfg.interleave {
					err = doi(host, cmd, stdout)
				} else {
					var out []byte
					out, err = do(host, cmd)
					output <- string(out)
				}
				if err != nil {
					atomic.StoreInt64(&code, 1)
				}
			}
		}()
	}

	// Close output once every worker has finished, ending the printer loop.
	go func() {
		wg.Wait()
		close(output)
	}()

	// Print complete reports as they come in.
	outwg.Add(1)
	go func() {
		defer outwg.Done()
		for o := range output {
			fmt.Print(o)
		}
	}()

	for _, host := range hosts {
		hostch <- host
	}
	close(hostch)
	outwg.Wait()
	os.Exit(int(atomic.LoadInt64(&code)))
}