Skip to content

Commit

Permalink
Move more code to common lib
Browse files Browse the repository at this point in the history
  • Loading branch information
Antti Paloposki committed Jun 9, 2022
1 parent 820f43d commit 97974dc
Show file tree
Hide file tree
Showing 13 changed files with 934 additions and 1 deletion.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# golang output binary directory
bin
vendor

# test results
gotest-coverage.out
gotest-report.out
14 changes: 14 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
OPERATOR_NAME := go-common

ensure:
GO111MODULE=on go mod tidy -compat=1.17
GO111MODULE=on go mod vendor

build:
rm -f bin/$(OPERATOR_NAME)
GO111MODULE=on go build -mod vendor -v -o bin/$(OPERATOR_NAME) .

test:
GO111MODULE=on go test -failfast -mod vendor ./*.go -v -covermode atomic -coverprofile=gotest-coverage.out $(GOTEST_REPORT_FORMAT) > gotest-report.out && cat gotest-report.out || (cat gotest-report.out; exit 1)
git diff --exit-code go.mod go.sum

26 changes: 26 additions & 0 deletions arrays.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@ package common

import "strings"

// Unique returns unique array items
func Unique(values []string) []string {
keys := make(map[string]bool)
list := []string{}
for _, value := range values {
if _, ok := keys[value]; !ok {
keys[value] = true
list = append(list, value)
}
}
return list
}

// EqualStringArrays compares equality of two string arrays
func EqualStringArrays(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}

// ContainsInteger returns true if integer is found from array
func ContainsInteger(array []int, value int) bool {
for _, currentValue := range array {
Expand Down
63 changes: 63 additions & 0 deletions crypto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package common

import (
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
)

// Base64decode ...
func Base64decode(v string) (string, error) {
data, err := base64.StdEncoding.DecodeString(v)
if err != nil {
return "", fmt.Errorf("base64 decode failed: %w", err)
}
return string(data), nil
}

func createHash(key string) string {
hasher := md5.New() //nolint:gosec // TODO https://atlas.elisa.fi/jira/browse/DEV-3364
hasher.Write([]byte(key))
return hex.EncodeToString(hasher.Sum(nil))
}

// Encrypt the secret
// source https://www.thepolyglotdeveloper.com/2018/02/encrypt-decrypt-data-golang-application-crypto-packages/
func Encrypt(data []byte, passphrase string) []byte {
block, _ := aes.NewCipher([]byte(createHash(passphrase)))
gcm, err := cipher.NewGCM(block)
if err != nil {
panic(err.Error())
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
panic(err.Error())
}
ciphertext := gcm.Seal(nonce, nonce, data, nil)
return ciphertext
}

// Decrypt the secret
func Decrypt(data []byte, passphrase string) []byte {
key := []byte(createHash(passphrase))
block, err := aes.NewCipher(key)
if err != nil {
panic(err.Error())
}
gcm, err := cipher.NewGCM(block)
if err != nil {
panic(err.Error())
}
nonceSize := gcm.NonceSize()
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
panic(err.Error())
}
return plaintext
}
18 changes: 18 additions & 0 deletions crypto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package common

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestBase64decode(t *testing.T) {
encoded := "U1VDQ0VTUw=="
decoded, err := Base64decode(encoded)
assert.Nil(t, err)
assert.Equal(t, "SUCCESS", decoded)

failing := "^"
_, err = Base64decode(failing)
assert.NotNil(t, err)
}
134 changes: 134 additions & 0 deletions csrf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package common

import (
"fmt"
"net/http"
"net/url"

"github.com/gin-gonic/gin"
)

const (
badToken = "CSRF token missing or incorrect." //nolint:gosec
tokenMissing = "CSRF cookie not set." //nolint:gosec
noReferer = "Referer checking failed - no Referer."
malformedReferer = "Referer checking failed - Referer is malformed."
insecureReferer = "Referer checking failed - Referer is insecure while host is secure."
// CsrfTokenKey ...
CsrfTokenKey = "csrftoken"
// Xcsrf ...
Xcsrf = "X-CSRF-Token"
// Authorization ...
Authorization = "Authorization"
)

var ignoreMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}

func getHeader(c *gin.Context) string {
return c.Request.Header.Get(Xcsrf)
}

func isAPIUser(c *gin.Context) bool {
return c.Request.Header.Get(Authorization) != ""
}

func isJWTMachineUser(c *gin.Context) bool {
session, err := c.Request.Cookie("session")
if err == nil && session.Value != "" {
return true
}
return false
}

func getCookie(c *gin.Context) string {
session, err := c.Request.Cookie(CsrfTokenKey)
if err == nil && session.Value != "" {
return session.Value
}
return ""
}

// CSRF is middleware for handling CSRF protection in gin
func CSRF() gin.HandlerFunc {
return func(c *gin.Context) {
// allow machineusers
if isAPIUser(c) || isJWTMachineUser(c) {
c.Next()
return
}

csrfToken := getCookie(c)

// Assume that anything not defined as 'safe' by RFC7231 needs protection
if ContainsString(ignoreMethods, c.Request.Method) {
// set cookie in response if not found
if csrfToken == "" {
val, err := RandomToken()
if err != nil {
c.JSON(403, ErrorResponse{Code: 403, Message: malformedReferer})
c.Abort()
return
}
http.SetCookie(c.Writer, &http.Cookie{
Name: CsrfTokenKey,
Value: val,
Path: "/",
Domain: c.Request.URL.Host,
HttpOnly: false,
Secure: IsHTTPS(c.Request),
MaxAge: 12 * 60 * 60,
SameSite: http.SameSiteLaxMode,
})
}
// Set the Vary: Cookie header to protect clients from caching the response.
c.Header("Vary", "Cookie")
c.Next()
return
}

if IsHTTPS(c.Request) {
referer := c.Request.Header.Get("Referer")
if referer == "" {
c.JSON(403, ErrorResponse{Code: 403, Message: noReferer})
c.Abort()
return
}

parsedURL, err := url.Parse(referer)
if err != nil {
c.JSON(403, ErrorResponse{Code: 403, Message: malformedReferer})
c.Abort()
return
}

if parsedURL.Scheme != "https" {
c.JSON(403, ErrorResponse{Code: 403, Message: insecureReferer})
c.Abort()
return
}

if parsedURL.Host != c.Request.Host {
msg := fmt.Sprintf("Referer checking failed - %s does not match any trusted origins.", parsedURL.Host)
c.JSON(403, ErrorResponse{Code: 403, Message: msg})
c.Abort()
return
}
}

requestCSRFToken := getHeader(c)
if csrfToken == "" {
c.JSON(403, ErrorResponse{Code: 403, Message: tokenMissing})
c.Abort()
return
}

if requestCSRFToken != csrfToken {
c.JSON(403, ErrorResponse{Code: 403, Message: badToken})
c.Abort()
return
}

// process request
c.Next()
}
}
Loading

0 comments on commit 97974dc

Please sign in to comment.