diff --git a/pagerduty/pagerduty.go b/pagerduty/pagerduty.go index 1ac9ba8..453eb08 100644 --- a/pagerduty/pagerduty.go +++ b/pagerduty/pagerduty.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "log" "net/http" "net/url" @@ -240,11 +239,17 @@ func (c *Client) newRequestDoOptionsContext(ctx context.Context, method, url str } func (c *Client) do(req *http.Request, v interface{}) (*Response, error) { + sLogger := newSecureLogger() + sLogger.LogReq(req) + resp, err := c.client.Do(req) if err != nil { return nil, err } - bodyBytes, err := ioutil.ReadAll(resp.Body) + + sLogger.LogRes(resp) + + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { return nil, err } diff --git a/pagerduty/secure_logger.go b/pagerduty/secure_logger.go new file mode 100644 index 0000000..ce94c3d --- /dev/null +++ b/pagerduty/secure_logger.go @@ -0,0 +1,135 @@ +package pagerduty + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" +) + +const ( + secureLogRequestHeading = `[SECURE] PagerDuty API Request Details: +---[ REQUEST ]---------------------------------------` + secureLogResponseHeading = `[SECURE] PagerDuty API Response Details: +---[ RESPONSE ]--------------------------------------` + secureLogBottomDelimiter = `-----------------------------------------------------` + obscuredLogTag = `` +) + +type secureLogger struct { + logger *log.Logger + headersContent string + bodyContent string + logsContent string + canLog bool +} + +func (l *secureLogger) handleHeadersLogsContent(h http.Header) { + l.headersContent = "" + headers := make(http.Header) + for k, v := range h { + headers[k] = v + } + + if _, ok := headers["Authorization"]; ok { + authHeader := headers["Authorization"][0] + last4AuthChars := authHeader + if len(authHeader) > 4 { + last4AuthChars = authHeader[len(authHeader)-4:] + } + headers["Authorization"] = []string{fmt.Sprintf("%s%s", obscuredLogTag, last4AuthChars)} + } + + for k, v := range headers { + h := fmt.Sprintf("%s: %s", k, strings.Join(v, ";")) + l.headersContent = fmt.Sprintf("%s%s\n", l.headersContent, h) + } +} + +func (l *secureLogger) handleBodyLogsContent(body io.ReadCloser) io.ReadCloser { + l.bodyContent = "" + if body != nil { + bodyBytes, err := io.ReadAll(body) + if err != nil { + log.Printf("[ERROR] Error reading body: %v\n", err) + return body + } + + var jsonObj map[string]interface{} + err = json.Unmarshal(bodyBytes, &jsonObj) + if err != nil { + l.bodyContent = fmt.Sprintf("%s\n", string(bodyBytes)) + } else { + prettyBody, err := json.MarshalIndent(jsonObj, "", " ") + if err != nil { + log.Printf("[ERROR] Error pretty-printing body: %v\n", err) + } else { + l.bodyContent = fmt.Sprintf("%s\n", prettyBody) + } + } + + body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + return body +} + +func (l *secureLogger) putTogetherLogsContent(logsContent *string, heading string) { + content := *logsContent + content = fmt.Sprintf("%s\n%s\n%s\n", heading, content, l.headersContent) + if l.bodyContent != "" { + content = fmt.Sprintf("%s%s\n", content, l.bodyContent) + } + + content = fmt.Sprintf("%s%s", content, secureLogBottomDelimiter) + *logsContent = content +} + +func (l *secureLogger) LogReq(req *http.Request) { + if !l.canLog { + return + } + + logsContent := fmt.Sprintf("%s %s %s", req.Method, req.URL.Path, req.Proto) + l.handleHeadersLogsContent(req.Header) + req.Body = l.handleBodyLogsContent(req.Body) + l.putTogetherLogsContent(&logsContent, secureLogRequestHeading) + + l.logger.Print(logsContent) +} + +func (l *secureLogger) LogRes(res *http.Response) { + if !l.canLog { + return + } + + logsContent := fmt.Sprintf("%s %d %s", res.Proto, res.StatusCode, res.Status) + l.handleHeadersLogsContent(res.Header) + res.Body = l.handleBodyLogsContent(res.Body) + l.putTogetherLogsContent(&logsContent, secureLogResponseHeading) + + l.logger.Print(logsContent) +} + +func (l *secureLogger) SetCanLog(flag bool) { + l.canLog = flag +} + +func newSecureLogger() *secureLogger { + pdLogFlag := os.Getenv("TF_LOG_PROVIDER_PAGERDUTY") + pdLogFlag = strings.ToUpper(pdLogFlag) + tfLogFlag := os.Getenv("TF_LOG") + tfLogFlag = strings.ToUpper(tfLogFlag) + + secLogger := secureLogger{ + logger: log.Default(), + canLog: tfLogFlag == "INFO" && pdLogFlag == "SECURE", + } + secLogger.logger.SetFlags(log.Flags() &^ (log.Ldate | log.Ltime)) + + return &secLogger +} diff --git a/pagerduty/secure_logger_test.go b/pagerduty/secure_logger_test.go new file mode 100644 index 0000000..bf07f8d --- /dev/null +++ b/pagerduty/secure_logger_test.go @@ -0,0 +1,100 @@ +package pagerduty + +import ( + "bytes" + "io" + "log" + "net/http" + "strings" + "testing" +) + +func TestSecureLoggerHandleHeadersLogsContent(t *testing.T) { + l := newSecureLogger() + l.SetCanLog(true) + headers := http.Header{ + "Authorization": []string{"Bearer secretApiKey"}, + "Content-Type": []string{"application/json"}, + } + l.handleHeadersLogsContent(headers) + + if !strings.Contains(l.headersContent, "iKey") { + t.Errorf("Authorization header not properly obscured: got %s", l.headersContent) + } +} + +func TestSecureLoggerHandleBodyLogsContent_JSON(t *testing.T) { + l := newSecureLogger() + l.SetCanLog(true) + body := io.NopCloser(bytes.NewReader([]byte(`{"key": "value"}`))) + _ = l.handleBodyLogsContent(body) + + if !strings.Contains(l.bodyContent, ` "key": "value"`) { + t.Errorf("JSON body not properly formatted: got %s", l.bodyContent) + } +} + +func TestSecureLoggerHandleBodyLogsContent_NonJSON(t *testing.T) { + l := newSecureLogger() + l.SetCanLog(true) + body := io.NopCloser(bytes.NewReader([]byte(`non-json content`))) + _ = l.handleBodyLogsContent(body) + + if l.bodyContent != "non-json content\n" { + t.Errorf("Non-JSON body not handled correctly: got %s", l.bodyContent) + } +} + +func TestSecureLoggerCanLog(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) + + l := newSecureLogger() + l.SetCanLog(false) + req, _ := http.NewRequest("GET", "/abilities", nil) + l.LogReq(req) + + if buf.String() != "" { + t.Errorf("Logger should not have log: got %s", buf.String()) + } + + l.SetCanLog(true) + l.LogReq(req) + if !strings.Contains(buf.String(), "/abilities") { + t.Errorf("Request not logged correctly: got %s", buf.String()) + } +} + +func TestSecureLoggerLogReq(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) + + l := newSecureLogger() + l.SetCanLog(true) + req, _ := http.NewRequest("GET", "/abilities", nil) + l.LogReq(req) + + if !strings.Contains(buf.String(), "/abilities") { + t.Errorf("Request not logged correctly: got %s", buf.String()) + } +} + +func TestSecureLoggerLogRes(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) + + l := newSecureLogger() + l.SetCanLog(true) + res := &http.Response{ + Proto: "HTTP/1.1", + StatusCode: 200, + Status: "OK", + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{"key": "value"}`))), + } + l.LogRes(res) + + if !strings.Contains(buf.String(), "HTTP/1.1 200 OK") { + t.Errorf("Response not logged correctly: got %s", buf.String()) + } +}