diff --git a/cmd/guac/guac.go b/cmd/guac/guac.go index 515b8ef..4065cbd 100644 --- a/cmd/guac/guac.go +++ b/cmd/guac/guac.go @@ -9,46 +9,98 @@ import ( "net/url" "strconv" + "github.com/gorilla/mux" + "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "github.com/wwt/guac" ) +var tunnels map[string]guac.Tunnel + func main() { logrus.SetLevel(logrus.DebugLevel) - servlet := guac.NewServer(DemoDoConnect) + // servlet := guac.NewServer(DemoDoConnect) wsServer := guac.NewWebsocketServer(DemoDoConnect) + wsServerIntercept := guac.NewWebsocketServer(DemoDoConnectWithIntercept) sessions := guac.NewMemorySessionStore() - wsServer.OnConnect = sessions.Add - wsServer.OnDisconnect = sessions.Delete - - mux := http.NewServeMux() - mux.Handle("/tunnel", servlet) - mux.Handle("/tunnel/", servlet) - mux.Handle("/websocket-tunnel", wsServer) - mux.HandleFunc("/sessions/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + wsServerIntercept.OnConnect = sessions.Add + wsServerIntercept.OnDisconnect = sessions.Delete + + tunnels = make(map[string]guac.Tunnel) + + wsServerIntercept.OnConnectWs = func(s string, _ *websocket.Conn, _ *http.Request, t guac.Tunnel) { + tunnels[s] = t + } + + wsServerIntercept.OnDisconnectWs = func(s string, _ *websocket.Conn, _ *http.Request, _ guac.Tunnel) { + delete(tunnels, s) + } + + m := mux.NewRouter() + + // m.Handle("/", servlet) + m.Handle("/websocket-tunnel", wsServer) + m.Handle("/websocket-tunnel-intercept", wsServerIntercept) + + m.HandleFunc("/api/session/tunnels/{tunnel}/streams/{stream}/{file}", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Disposition", "attachment") + t := mux.Vars(r)["tunnel"] + + tunnel, ok := tunnels[t] + if !ok { + w.Write([]byte("KO")) + w.WriteHeader(http.StatusInternalServerError) + return + } + + sit, ok := tunnel.(*guac.UserTunnel) + if !ok { + w.Write([]byte("Not supported")) + w.WriteHeader(http.StatusBadRequest) + return + } + + stream := mux.Vars(r)["stream"] + + if err := sit.InterceptOutputStream(stream, w); err != nil { + w.Write([]byte("KO Intercepting output stream")) + } + }).Methods("GET") + + m.HandleFunc("/api/session/tunnels/{tunnel}/streams/{stream}/{file}", func(w http.ResponseWriter, r *http.Request) { + t := mux.Vars(r)["tunnel"] + tunnel, ok := tunnels[t] + if !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("KO")) + return + } + + sit, ok := tunnel.(*guac.UserTunnel) + if !ok { + w.Write([]byte("Not supported")) + w.WriteHeader(http.StatusBadRequest) + return + } - sessions.RLock() - defer sessions.RUnlock() + stream := mux.Vars(r)["stream"] - type ConnIds struct { - Uuid string `json:"uuid"` - Num int `json:"num"` + if err := sit.InterceptInputStream(stream, r.Body); err != nil { + w.Write([]byte("KO intercepting input stream")) } + }).Methods("POST") - connIds := make([]*ConnIds, len(sessions.ConnIds)) + m.HandleFunc("/api/session/tunnels", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") - i := 0 - for id, num := range sessions.ConnIds { - connIds[i] = &ConnIds{ - Uuid: id, - Num: num, - } + t := []string{} + for tun := range tunnels { + t = append(t, tun) } - if err := json.NewEncoder(w).Encode(connIds); err != nil { + if err := json.NewEncoder(w).Encode(t); err != nil { logrus.Error(err) } }) @@ -57,7 +109,7 @@ func main() { s := &http.Server{ Addr: "0.0.0.0:4567", - Handler: mux, + Handler: m, ReadTimeout: guac.SocketTimeout, WriteTimeout: guac.SocketTimeout, MaxHeaderBytes: 1 << 20, @@ -69,7 +121,7 @@ func main() { } // DemoDoConnect creates the tunnel to the remote machine (via guacd) -func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { +func DemoDoConnect(request *http.Request) (_ guac.Tunnel, err error) { config := guac.NewGuacamoleConfiguration() var query url.Values @@ -93,12 +145,10 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { } config.Protocol = query.Get("scheme") - config.Parameters = map[string]string{} for k, v := range query { config.Parameters[k] = v[0] } - var err error if query.Get("width") != "" { config.OptimalScreenHeight, err = strconv.Atoi(query.Get("width")) if err != nil || config.OptimalScreenHeight == 0 { @@ -117,6 +167,10 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { logrus.Debug("Connecting to guacd") addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:4822") + if err != nil { + logrus.Errorln("error while resolving 127.0.0.1") + return nil, err + } conn, err := net.DialTCP("tcp", nil, addr) if err != nil { @@ -130,11 +184,23 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { if request.URL.Query().Get("uuid") != "" { config.ConnectionID = request.URL.Query().Get("uuid") } + logrus.Debugf("Starting handshake with %#v", config) err = stream.Handshake(config) if err != nil { return nil, err } logrus.Debug("Socket configured") + return guac.NewSimpleTunnel(stream), nil } + +// DemoDoConnectWithIntercept showcases a use for intercepting streams +func DemoDoConnectWithIntercept(r *http.Request) (guac.Tunnel, error) { + t, err := DemoDoConnect(r) + if err != nil { + return nil, err + } + + return guac.NewUserTunnel(t), nil +} diff --git a/config.go b/config.go index c8af4fd..85f0ac7 100644 --- a/config.go +++ b/config.go @@ -5,22 +5,22 @@ type Config struct { // ConnectionID is used to reconnect to an existing session, otherwise leave blank for a new session. ConnectionID string // Protocol is the protocol of the connection from guacd to the remote (rdp, ssh, etc). - Protocol string + Protocol string // Parameters are used to configure protocol specific options like sla for rdp or terminal color schemes. - Parameters map[string]string + Parameters map[string]string // OptimalScreenWidth is the desired width of the screen - OptimalScreenWidth int + OptimalScreenWidth int // OptimalScreenHeight is the desired height of the screen OptimalScreenHeight int // OptimalResolution is the desired resolution of the screen - OptimalResolution int + OptimalResolution int // AudioMimetypes is an array of the supported audio types - AudioMimetypes []string + AudioMimetypes []string // VideoMimetypes is an array of the supported video types - VideoMimetypes []string + VideoMimetypes []string // ImageMimetypes is an array of the supported image types - ImageMimetypes []string + ImageMimetypes []string } // NewGuacamoleConfiguration returns a Config with sane defaults diff --git a/doc.go b/doc.go index f0dfdc0..4b89bff 100644 --- a/doc.go +++ b/doc.go @@ -1,4 +1,4 @@ /* Package guac implements a HTTP client and a WebSocket client that connects to an Apache Guacamole server. - */ +*/ package guac diff --git a/filter.go b/filter.go new file mode 100644 index 0000000..d45e001 --- /dev/null +++ b/filter.go @@ -0,0 +1,5 @@ +package guac + +type Filter interface { + Filter(*Instruction) (*Instruction, error) +} diff --git a/go.mod b/go.mod index 20ab153..f52475b 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ replace github.com/Sirupsen/logrus v1.4.2 => github.com/sirupsen/logrus v1.4.2 require ( github.com/google/uuid v1.1.1 + github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.1 github.com/sirupsen/logrus v1.4.2 ) diff --git a/go.sum b/go.sum index ab4a565..713fc13 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,19 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/input_intercepting_filter.go b/input_intercepting_filter.go new file mode 100644 index 0000000..c7cbc01 --- /dev/null +++ b/input_intercepting_filter.go @@ -0,0 +1,168 @@ +package guac + +import ( + "encoding/base64" + "errors" + "io" + "strconv" + "sync" + + "github.com/sirupsen/logrus" +) + +var ( + _ Filter = (*InputInterceptingFilter)(nil) + _ Filter = (*OutputInterceptingFilter)(nil) +) + +// Whether this OutputInterceptingFilter should respond to received +// blobs with "ack" messages on behalf of the client. If false, blobs will +// still be handled by this filter, but empty blobs will be sent to the +// client, forcing the client to respond on its own. +var acknowledgeBlobs bool = true + +type InputInterceptingFilter struct { + tunnel Tunnel + l sync.Mutex + + streams map[string]*InterceptedInputStream +} + +func NewInputInterceptingFilter(tunnel Tunnel) *InputInterceptingFilter { + streams := make(map[string]*InterceptedInputStream) + return &InputInterceptingFilter{tunnel: tunnel, streams: streams} +} + +func (t *InputInterceptingFilter) sendInstruction(instr *Instruction) (err error) { + w := t.tunnel.AcquireWriter() + defer t.tunnel.ReleaseWriter() + + if _, err = w.Write(instr.Byte()); err != nil { + logrus.WithError(err).Error("failed to write instruction") + return err + } + + return nil +} + +func (t *InputInterceptingFilter) getInterceptedInputStream(index string) *InterceptedInputStream { + t.l.Lock() + defer t.l.Unlock() + + return t.streams[index] +} + +func (t *InputInterceptingFilter) closeInterceptedStream(index string, err error) { + t.l.Lock() + defer t.l.Unlock() + + if t.streams[index] != nil { + t.streams[index].done <- err + } + delete(t.streams, index) +} + +func (t *InputInterceptingFilter) CloseAll() { + for k := range t.streams { + t.closeInterceptedStream(k, nil) + } +} + +func (t *InputInterceptingFilter) InterceptStream(index string, stream io.Reader) <-chan error { + signal := make(chan error, 1) + + interceptedInputStream := NewInterceptedInputStream(index, stream, signal) + + t.l.Lock() + t.streams[index] = interceptedInputStream + t.l.Unlock() + + t.handleInterceptedStream(interceptedInputStream) + + return signal +} + +func (t *InputInterceptingFilter) sendBlob(index string, blob []byte) { + data := base64.StdEncoding.Strict().EncodeToString(blob) + if err := t.sendInstruction(NewInstruction("blob", index, data)); err != nil { + logrus.Errorf("failed to send base64 blob to stream index %s %v", index, err) + + t.sendEnd(index) + t.closeInterceptedStream(index, err) + } +} + +func (t *InputInterceptingFilter) sendEnd(index string) { + if err := t.sendInstruction(NewInstruction("end", index)); err != nil { + logrus.Errorf("failed to send end to stream index %s %v", index, err) + } +} + +func (t *InputInterceptingFilter) readNextBlob(stream *InterceptedInputStream) { + blob := make([]byte, 4096) + + if n, err := io.ReadFull(stream.Stream, blob); err != nil { + if n > 0 { + logrus.Debug("there are still some bytes") + t.sendBlob(stream.Index, blob[:n]) + return + } + + if !errors.Is(err, io.EOF) { + logrus.WithError(err).Errorf("could not read from stream %s", stream.Index) + } else { + err = nil + } + + t.sendEnd(stream.Index) + t.closeInterceptedStream(stream.Index, err) + + return + } + + t.sendBlob(stream.Index, blob) +} + +func (t *InputInterceptingFilter) handleACK(instruction *Instruction) { + if len(instruction.Args) < 3 { + return + } + + index := instruction.Args[0] + + stream := t.getInterceptedInputStream(index) + if stream == nil { + logrus.Warning("empty intercepted input stream on ACK") + return + } + + status := instruction.Args[2] + code := Success + + if status != "0" { + codeInt, err := strconv.Atoi(status) + code = FromGuacamoleStatusCode(codeInt) + + if err != nil { + logrus.Error("failed to translate status code") + code = ServerError + } + + t.closeInterceptedStream(stream.Index, ErrServer.NewError(code.String(), instruction.Args[1])) + return + } + + t.readNextBlob(stream) +} + +func (t *InputInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { + if instruction.Opcode == "ack" { + t.handleACK(instruction) + } + + return instruction, nil +} + +func (t *InputInterceptingFilter) handleInterceptedStream(stream *InterceptedInputStream) { + t.readNextBlob(stream) +} diff --git a/input_intercepting_filter_test.go b/input_intercepting_filter_test.go new file mode 100644 index 0000000..9ab8094 --- /dev/null +++ b/input_intercepting_filter_test.go @@ -0,0 +1,63 @@ +package guac + +import ( + "bytes" + "encoding/base64" + "fmt" + "testing" + "time" +) + +func TestInputInterceptingFilter(t *testing.T) { + t.Run("OK", func(t *testing.T) { + conn := &fakeConn{ + ToRead: []byte(""), + } + + f := NewInputInterceptingFilter(NewUserTunnel( + NewSimpleTunnel( + NewStream(conn, time.Minute), + ), + )) + + firstBlob := bytes.Repeat([]byte("A"), 4096) + secondBlob := bytes.Repeat([]byte("B"), 100) + + toInject := append(firstBlob, secondBlob...) + + // Hijack stream 1 and inject some data that will need to end up on the wire + finished := f.InterceptStream("1", bytes.NewReader([]byte(toInject))) + + encoded := base64.StdEncoding.EncodeToString(firstBlob) + + if got, want := string(conn.ToWrite), fmt.Sprintf("4.blob,1.1,%d.%s;", len(encoded), encoded); got != want { + t.Fatalf("On the wire: %v, want=%v", got, want) + } + + // Simulate an ACK from guacd + f.Filter(NewInstruction("ack", "1", "", "0")) + + encoded = base64.StdEncoding.EncodeToString(secondBlob) + + if got, want := string(conn.ToWrite), fmt.Sprintf("4.blob,1.1,%d.%s;", len(encoded), encoded); got != want { + t.Fatalf("On the wire: %v, want=%v", got, want) + } + + // Simulate another ACK from guacd, the packet should have been + // fragmented in two: one which contains the first 4096 bytes + // base64 encoded, and the second which contains the remaining + // 100 bytes base64 encoded. + f.Filter(NewInstruction("ack", "1", "", "0")) + + // There shouldn't be any pending read, so finished should have + // completed by now, if not that's an error and this test should + // timeout. + if err := <-finished; err != nil { + t.Fatal(err) + } + + if got, want := string(conn.ToWrite), "3.end,1.1;"; got != want { + t.Fatalf("On the wire: %v, want=%v", got, want) + } + }) +} diff --git a/instruction.go b/instruction.go index d8529dd..96a986e 100644 --- a/instruction.go +++ b/instruction.go @@ -111,3 +111,44 @@ func ReadOne(stream *Stream) (instruction *Instruction, err error) { return Parse(instructionBuffer) } + +// FilteredInstructionReader is a struct that provides a filtered +// InstructionReader and handles instructions through a filter. +type FilteredInstructionReader struct { + InstructionReader + + Filter +} + +func NewFilteredInstructionReader(r InstructionReader, filter Filter) InstructionReader { + return &FilteredInstructionReader{r, filter} +} + +func (r *FilteredInstructionReader) ReadSome() ([]byte, error) { + for { + unfilteredInstruction, err := readOne(r.InstructionReader) + if err != nil { + return nil, err + } + + filteredInstruction, err := r.Filter.Filter(unfilteredInstruction) + if err != nil { + return nil, err + } + + // Continue reading and filtering until no instructions are dropped + if filteredInstruction != nil { + return filteredInstruction.Byte(), err + } + } +} + +// readOne takes an instruction from the stream and parses it into an Instruction +func readOne(r InstructionReader) (instruction *Instruction, err error) { + instructionBuffer, err := r.ReadSome() + if err != nil { + return + } + + return Parse(instructionBuffer) +} diff --git a/instruction_test.go b/instruction_test.go index f8629c4..14a0951 100644 --- a/instruction_test.go +++ b/instruction_test.go @@ -68,3 +68,43 @@ func TestReadOne(t *testing.T) { t.Error("Unexpected", ins.String()) } } + +var _ Filter = (*dropFilter)(nil) + +// dropFilter drops all the instructions defined in drop +type dropFilter struct { + Drop []string +} + +func (f *dropFilter) Filter(i *Instruction) (*Instruction, error) { + for _, v := range f.Drop { + if v == i.Opcode { + return nil, nil + } + } + + return i, nil +} + +func TestFilteredInstructionReader(t *testing.T) { + t.Run("OK", func(t *testing.T) { + f := &dropFilter{Drop: []string{"select"}} + + s := NewStream(&fakeConn{ + ToRead: []byte(`6.select,2.hi,5.hello,4.asdf;6.teston,2.hi,5.hello,4.asdf;`), + }, time.Minute) + + fi := NewFilteredInstructionReader(s, f) + + result, err := fi.ReadSome() + if err != nil { + t.Fatal(err) + } + + if got, want := string(result), "6.teston,2.hi,5.hello,4.asdf;"; got != want { + t.Fatalf("Result=%v, want %v", got, want) + } + }) + + // Won't test malformed input, because that's already tested on Stream +} diff --git a/intercepted_stream.go b/intercepted_stream.go new file mode 100644 index 0000000..bb98001 --- /dev/null +++ b/intercepted_stream.go @@ -0,0 +1,25 @@ +package guac + +import "io" + +type InterceptedOutputStream struct { + Index string + Stream io.Writer + + done chan<- error +} + +func NewInterceptedOutputStream(index string, stream io.Writer, signal chan<- error) *InterceptedOutputStream { + return &InterceptedOutputStream{Index: index, Stream: stream, done: signal} +} + +type InterceptedInputStream struct { + Index string + Stream io.Reader + + done chan<- error +} + +func NewInterceptedInputStream(index string, stream io.Reader, signal chan<- error) *InterceptedInputStream { + return &InterceptedInputStream{Index: index, Stream: stream, done: signal} +} diff --git a/output_intercepting_filter.go b/output_intercepting_filter.go new file mode 100644 index 0000000..14b5c38 --- /dev/null +++ b/output_intercepting_filter.go @@ -0,0 +1,171 @@ +package guac + +import ( + "encoding/base64" + "errors" + "io" + "sync" + + "github.com/sirupsen/logrus" +) + +type OutputInterceptingFilter struct { + l sync.Mutex + tunnel Tunnel + streams map[string]*InterceptedOutputStream +} + +func NewOutputInterceptingFilter(tunnel Tunnel) *OutputInterceptingFilter { + streams := make(map[string]*InterceptedOutputStream) + return &OutputInterceptingFilter{tunnel: tunnel, streams: streams} +} + +func (t *OutputInterceptingFilter) sendInstruction(instr *Instruction) error { + w := t.tunnel.AcquireWriter() + if _, err := w.Write(instr.Byte()); err != nil { + logrus.WithError(err).Error("failed to send instruction") + return err + } + + t.tunnel.ReleaseWriter() + return nil +} + +func (t *OutputInterceptingFilter) getInterceptedStream(idx string) *InterceptedOutputStream { + t.l.Lock() + defer t.l.Unlock() + + return t.streams[idx] +} + +func (t *OutputInterceptingFilter) sendACK(index string, message string, status Status) { + if status != Success { + t.closeInterceptedStream(index, ErrServer.NewError(status.String(), message)) + } + + if err := t.sendInstruction(NewInstruction("ack", index, message, status.String())); err != nil { + logrus.Errorf("unable to send ACK for stream %s", index) + } +} + +func (t *OutputInterceptingFilter) InterceptStream(index string, outStream io.Writer) <-chan error { + signal := make(chan error, 1) + + if t.tunnel == nil { + defer func() { + signal <- errors.New("invalid tunnel") + }() + + return signal + } + + interceptedOutputStream := NewInterceptedOutputStream(index, outStream, signal) + + t.l.Lock() + t.streams[index] = interceptedOutputStream + t.l.Unlock() + + t.handleInterceptedStream(interceptedOutputStream) + + return signal +} + +func (t *OutputInterceptingFilter) handleBlob(instruction *Instruction) (*Instruction, error) { + // Verify all required arguments are present + args := instruction.Args + if len(args) < 2 { + return instruction, nil + } + + // Pull associated stream + streamIndex := args[0] + + outputInterceptedStream := t.getInterceptedStream(streamIndex) + if outputInterceptedStream == nil { + return instruction, nil + } + + // Decode blob + data := args[1] + + blob, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, err + } + + if outputInterceptedStream.Stream == nil { + return nil, errors.New("stream in outputInterceptedStream is nil") + } + + if _, err := outputInterceptedStream.Stream.Write(blob); err != nil { + // User closed the connection, no need to panic, + // Just don't track it anymore and close the stream. + t.closeInterceptedStream(streamIndex, nil) + + logrus.WithError(err).Info("failed to write to intercepted stream: maybe user has closed the connection?") + + // Exit cleanly, we don't need to make the server quit listening. + return nil, nil + } + + // Force client to respond with their own "ack" if we need to + // confirm that they are not falling behind with respect to the + // graphical session + if !acknowledgeBlobs { + acknowledgeBlobs = true + return NewInstruction("blob", streamIndex, ""), nil + } + + t.sendACK(streamIndex, "OK", Success) + + // Instruction was handled purely internally + return nil, nil +} + +func (t *OutputInterceptingFilter) handleEnd(instruction *Instruction) { + args := instruction.Args + if len(args) < 1 { + return + } + + t.closeInterceptedStream(args[0], nil) +} + +func (t *OutputInterceptingFilter) handleSync(instruction *Instruction) { + acknowledgeBlobs = false +} + +func (t *OutputInterceptingFilter) Filter(instruction *Instruction) (*Instruction, error) { + switch instruction.Opcode { + case "blob": + return t.handleBlob(instruction) + case "end": + t.handleEnd(instruction) + case "sync": + t.handleSync(instruction) + } + return instruction, nil +} + +func (t *OutputInterceptingFilter) handleInterceptedStream(stream *InterceptedOutputStream) { + t.sendACK(stream.Index, "OK", Success) +} + +func (t *OutputInterceptingFilter) closeInterceptedStream(index string, err error) *InterceptedOutputStream { + interceptedStream := t.streams[index] + if interceptedStream != nil { + interceptedStream.done <- err + } + + t.l.Lock() + delete(t.streams, index) + t.l.Unlock() + + return interceptedStream +} + +func (t *OutputInterceptingFilter) CloseAllInterceptedStreams() { + for k := range t.streams { + t.closeInterceptedStream(k, nil) + } +} diff --git a/server.go b/server.go index 69221a1..0a018cd 100644 --- a/server.go +++ b/server.go @@ -2,10 +2,12 @@ package guac import ( "fmt" - logger "github.com/sirupsen/logrus" "io" "net/http" "strings" + + "github.com/gorilla/mux" + logger "github.com/sirupsen/logrus" ) const ( @@ -60,21 +62,27 @@ func (s *Server) sendError(response http.ResponseWriter, guacStatus Status, mess } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - err := s.handleTunnelRequestCore(w, r) - if err == nil { - return - } - guacErr := err.(*ErrGuac) - switch guacErr.Kind { - case ErrClient: - logger.Warn("HTTP tunnel request rejected: ", err.Error()) - s.sendError(w, guacErr.Status, err.Error()) - default: - logger.Error("HTTP tunnel request failed: ", err.Error()) - logger.Debug("Internal error in HTTP tunnel.", err) - s.sendError(w, guacErr.Status, "Internal server error.") - } - return + m := mux.NewRouter() + + m.HandleFunc("/debug", func(rw http.ResponseWriter, r *http.Request) { + w.Write([]byte("CIAOOOOOOO")) + }) + + m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + err := s.handleTunnelRequestCore(w, r) + guacErr := err.(*ErrGuac) + switch guacErr.Kind { + case ErrClient: + logger.Warn("HTTP tunnel request rejected: ", err.Error()) + s.sendError(w, guacErr.Status, err.Error()) + default: + logger.Error("HTTP tunnel request failed: ", err.Error()) + logger.Debug("Internal error in HTTP tunnel.", err) + s.sendError(w, guacErr.Status, "Internal server error.") + } + }) + + m.ServeHTTP(w, r) } func (s *Server) handleTunnelRequestCore(response http.ResponseWriter, request *http.Request) (err error) { diff --git a/stream.go b/stream.go index 3f31622..592d1a9 100644 --- a/stream.go +++ b/stream.go @@ -9,10 +9,12 @@ import ( ) const ( - SocketTimeout = 15 * time.Second + SocketTimeout = 120 * time.Second MaxGuacMessage = 8192 // TODO is this bytes or runes? ) +var _ InstructionReader = (*Stream)(nil) + // Stream wraps the connection to Guacamole providing timeouts and reading // a single instruction at a time (since returning partial instructions // would be an error) diff --git a/stream_test.go b/stream_test.go index f5f8bc5..ec07033 100644 --- a/stream_test.go +++ b/stream_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net" + "sync" "testing" "time" ) @@ -98,12 +99,16 @@ func TestInstructionReader_Flush(t *testing.T) { } type fakeConn struct { + lock sync.Mutex ToRead []byte + ToWrite []byte HasRead bool Closed bool } func (f *fakeConn) Read(b []byte) (n int, err error) { + f.lock.Lock() + defer f.lock.Unlock() if f.HasRead { return 0, io.EOF } else { @@ -113,7 +118,11 @@ func (f *fakeConn) Read(b []byte) (n int, err error) { } func (f *fakeConn) Write(b []byte) (n int, err error) { - return 0, nil + f.lock.Lock() + defer f.lock.Unlock() + + f.ToWrite = b + return len(b), nil } func (f *fakeConn) Close() error { diff --git a/tunnel.go b/tunnel.go index 5b37ca5..4e16e31 100644 --- a/tunnel.go +++ b/tunnel.go @@ -2,10 +2,17 @@ package guac import ( "fmt" - "github.com/google/uuid" "io" + + "github.com/google/uuid" ) +// Ensure SimpleTunnel implements the Tunnel interface +var _ Tunnel = (*SimpleTunnel)(nil) + +// Ensure InstructionReader implements the InstructionReader interface +var _ InstructionReader = (*FilteredGuacamoleReader)(nil) + // The Guacamole protocol instruction Opcode reserved for arbitrary // internal use by tunnel implementations. The value of this Opcode is // guaranteed to be the empty string (""). Tunnel implementations may use @@ -116,3 +123,41 @@ func (t *SimpleTunnel) Close() (err error) { func (t *SimpleTunnel) GetUUID() string { return t.uuid.String() } + +type FilteredGuacamoleReader struct { + InstructionReader + filter Filter +} + +func NewFilteredGuacamoleReader(reader InstructionReader, filter Filter) *FilteredGuacamoleReader { + return &FilteredGuacamoleReader{reader, filter} +} + +// ReadOne takes an instruction from the stream and parses it into an Instruction +func (r *FilteredGuacamoleReader) ReadOne() (instruction *Instruction, err error) { + instructionBuffer, err := r.InstructionReader.ReadSome() + if err != nil { + return + } + + return Parse(instructionBuffer) +} + +func (r *FilteredGuacamoleReader) ReadSome() ([]byte, error) { + for { + unfilteredInstruction, err := r.ReadOne() + if err != nil { + return nil, err + } + + filteredInstruction, err := r.filter.Filter(unfilteredInstruction) + if err != nil { + return nil, err + } + + // Continue reading and filtering until no instructions are dropped + if filteredInstruction != nil { + return filteredInstruction.Byte(), err + } + } +} diff --git a/tunnel_map.go b/tunnel_map.go index 306a5fd..97df076 100644 --- a/tunnel_map.go +++ b/tunnel_map.go @@ -62,7 +62,7 @@ type TunnelMap struct { tunnelTimeout time.Duration // Map of all tunnels that are using HTTP, indexed by tunnel UUID. - tunnelMap map[string]*LastAccessedTunnel + tunnelMap map[string]*LastAccessedTunnel } // NewTunnelMap creates a new TunnelMap and starts the scheduled job with the default timeout. diff --git a/user_tunnel.go b/user_tunnel.go new file mode 100644 index 0000000..47ab66a --- /dev/null +++ b/user_tunnel.go @@ -0,0 +1,47 @@ +package guac + +import "io" + +// Ensure UserTunnel implements Tunnel +var _ Tunnel = (*UserTunnel)(nil) + +type UserTunnel struct { + Tunnel + + outputFilter *OutputInterceptingFilter + inputFilter *InputInterceptingFilter +} + +func NewUserTunnel(tunnel Tunnel) *UserTunnel { + tun := &UserTunnel{Tunnel: tunnel} + + tun.inputFilter, tun.outputFilter = NewInputInterceptingFilter(tun), NewOutputInterceptingFilter(tun) + + return tun +} + +// InterceptOutputStream intercepts an output stream, i.e. when downloading +// a file you provide a http.ResponseWriter and InterceptOutputStream will +// pipe the stream numbers through it. +func (t *UserTunnel) InterceptOutputStream(id string, stream io.Writer) error { + return <-t.outputFilter.InterceptStream(id, stream) +} + +// InterceptInputStream intercepts an input stream, i.e. when uploading a file. +// For example you can pass a http.Request.Body() to inject a file in a Guacamole stream. +func (t *UserTunnel) InterceptInputStream(id string, stream io.Reader) error { + return <-t.inputFilter.InterceptStream(id, stream) +} + +// AcquireReader of UserTunnel wraps the original AcquireReader +// but it filters the instructions before handing them to the +// caller. +func (t *UserTunnel) AcquireReader() InstructionReader { + reader := t.Tunnel.AcquireReader() + + // Filter both for input and output streams + return NewFilteredInstructionReader( + NewFilteredInstructionReader(reader, t.inputFilter), + t.outputFilter, + ) +} diff --git a/ws_server.go b/ws_server.go index 2321a9d..c96f5f0 100644 --- a/ws_server.go +++ b/ws_server.go @@ -2,7 +2,6 @@ package guac import ( "bytes" - "io" "net/http" "github.com/gorilla/websocket" @@ -22,7 +21,7 @@ type WebsocketServer struct { OnDisconnect func(string, *http.Request, Tunnel) // OnConnectWs is an optional callback called when a websocket connects. - OnConnectWs func(string, *websocket.Conn, *http.Request) + OnConnectWs func(string, *websocket.Conn, *http.Request, Tunnel) // OnDisconnectWs is an optional callback called when the websocket disconnects. OnDisconnectWs func(string, *websocket.Conn, *http.Request, Tunnel) } @@ -92,12 +91,9 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.OnConnect(id, r) } if s.OnConnectWs != nil { - s.OnConnectWs(id, ws, r) + s.OnConnectWs(id, ws, r, tunnel) } - writer := tunnel.AcquireWriter() - reader := tunnel.AcquireReader() - if s.OnDisconnect != nil { defer s.OnDisconnect(id, r, tunnel) } @@ -105,11 +101,8 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer s.OnDisconnectWs(id, ws, r, tunnel) } - defer tunnel.ReleaseWriter() - defer tunnel.ReleaseReader() - - go wsToGuacd(ws, writer) - guacdToWs(ws, reader) + go wsToGuacd(ws, tunnel) + guacdToWs(ws, tunnel) } // MessageReader wraps a websocket connection and only permits Reading @@ -118,7 +111,7 @@ type MessageReader interface { ReadMessage() (int, []byte, error) } -func wsToGuacd(ws MessageReader, guacd io.Writer) { +func wsToGuacd(ws MessageReader, tunnel Tunnel) { for { _, data, err := ws.ReadMessage() if err != nil { @@ -130,11 +123,14 @@ func wsToGuacd(ws MessageReader, guacd io.Writer) { // messages starting with the InternalDataOpcode are never sent to guacd continue } - + guacd := tunnel.AcquireWriter() if _, err = guacd.Write(data); err != nil { logrus.Traceln("Failed writing to guacd", err) + tunnel.ReleaseWriter() + return } + tunnel.ReleaseWriter() } } @@ -144,11 +140,20 @@ type MessageWriter interface { WriteMessage(int, []byte) error } -func guacdToWs(ws MessageWriter, guacd InstructionReader) { +func guacdToWs(ws MessageWriter, tunnel Tunnel) { buf := bytes.NewBuffer(make([]byte, 0, MaxGuacMessage*2)) + uuid := NewInstruction(InternalDataOpcode, tunnel.ConnectionID()) + if err := ws.WriteMessage(1, uuid.Byte()); err != nil { + logrus.Traceln("Failed to send uuid to ws", err) + return + } + for { + guacd := tunnel.AcquireReader() ins, err := guacd.ReadSome() + tunnel.ReleaseReader() + if err != nil { logrus.Traceln("Error reading from guacd", err) return @@ -164,8 +169,12 @@ func guacdToWs(ws MessageWriter, guacd InstructionReader) { return } + guacd = tunnel.AcquireReader() + avail := guacd.Available() + tunnel.ReleaseReader() + // if the buffer has more data in it or we've reached the max buffer size, send the data and reset - if !guacd.Available() || buf.Len() >= MaxGuacMessage { + if !avail || buf.Len() >= MaxGuacMessage { if err = ws.WriteMessage(1, buf.Bytes()); err != nil { if err == websocket.ErrCloseSent { return diff --git a/ws_server_test.go b/ws_server_test.go index f51e067..1ef2b90 100644 --- a/ws_server_test.go +++ b/ws_server_test.go @@ -23,7 +23,7 @@ func TestWebsocketServer_guacdToWs(t *testing.T) { conn := &fakeConn{ ToRead: expected, } - guac := NewStream(conn, time.Minute) + guac := NewSimpleTunnel(NewStream(conn, time.Minute)) guacdToWs(msgWriter, guac)