Skip to content

Commit

Permalink
fix: dont panic
Browse files Browse the repository at this point in the history
  • Loading branch information
ryshoooo committed Oct 20, 2024
1 parent 4a5292f commit a12bb39
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 25 deletions.
36 changes: 18 additions & 18 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
"github.com/sirupsen/logrus"
)

func HandleErrorResponse(w http.ResponseWriter, statusCode int, message string) {
func HandleErrorResponse(logger *logrus.Logger, w http.ResponseWriter, statusCode int, message string) {
w.WriteHeader(statusCode)
err := json.NewEncoder(w).Encode(&ApiError{Detail: message})
if err != nil {
panic(err)
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("Failed to encode error response: %s", err)
}
}

Expand All @@ -26,15 +26,15 @@ func CreateNewConnection(logger *logrus.Logger, usernameLifetime int) http.Handl
err := json.NewDecoder(r.Body).Decode(data)
if err != nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
HandleErrorResponse(w, http.StatusBadRequest, "Failed to parse request")
HandleErrorResponse(logger, w, http.StatusBadRequest, "Failed to parse request")
return
}
id := uuid.New().String()
foodme.GlobalState.AddConnection(id, data.AccessToken, data.RefreshToken, usernameLifetime)
w.WriteHeader(http.StatusOK)
err = json.NewEncoder(w).Encode(&NewConnectionResponse{Username: id})
if err != nil {
panic(err)
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
}
}
}
Expand All @@ -49,47 +49,47 @@ func ApplyPermissionAgent(logger *logrus.Logger, conf *foodme.Configuration, htt
err := json.NewDecoder(r.Body).Decode(data)
if err != nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
HandleErrorResponse(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse request: %s", err))
HandleErrorResponse(logger, w, http.StatusBadRequest, fmt.Sprintf("Failed to parse request: %s", err))
return
}

if data.Username == "" {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] No username provided", r)
HandleErrorResponse(w, http.StatusBadRequest, "No username provided")
HandleErrorResponse(logger, w, http.StatusBadRequest, "No username provided")
return
}

if data.SQL == "" {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] No SQL provided", r)
HandleErrorResponse(w, http.StatusBadRequest, "No SQL provided")
HandleErrorResponse(logger, w, http.StatusBadRequest, "No SQL provided")
return
}

// Validate data and state
if !conf.OIDCEnabled {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] OIDC is disabled", r)
HandleErrorResponse(w, http.StatusFailedDependency, "OIDC is disabled")
HandleErrorResponse(logger, w, http.StatusFailedDependency, "OIDC is disabled")
return
}

if !conf.PermissionAgentEnabled {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] Permission agent is disabled", r)
HandleErrorResponse(w, http.StatusFailedDependency, "Permission agent is disabled")
HandleErrorResponse(logger, w, http.StatusFailedDependency, "Permission agent is disabled")
return
}

at, rt := foodme.GlobalState.GetTokens(data.Username)
if at == "" || rt == "" {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] No tokens found for user %s", r, data.Username)
HandleErrorResponse(w, http.StatusNotFound, "No tokens found for user "+data.Username)
HandleErrorResponse(logger, w, http.StatusNotFound, "No tokens found for user "+data.Username)
return
}

// Get userinfo
cspec, ok := conf.OIDCDatabaseClients[data.Database]
if !ok && !conf.OIDCDatabaseFallBackToBaseClient {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] No client found for database %s", r, data.Database)
HandleErrorResponse(w, http.StatusNotFound, "No client found for database "+data.Database)
HandleErrorResponse(logger, w, http.StatusNotFound, "No client found for database "+data.Database)
return
} else if !ok {
cspec = &foodme.OIDCDatabaseClientSpec{ClientID: conf.OIDCClientID, ClientSecret: conf.OIDCClientSecret}
Expand All @@ -100,52 +100,52 @@ func ApplyPermissionAgent(logger *logrus.Logger, conf *foodme.Configuration, htt
err = oidcClient.RefreshAccessToken()
if err != nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
HandleErrorResponse(w, http.StatusUnauthorized, "Failed to refresh access token: "+err.Error())
HandleErrorResponse(logger, w, http.StatusUnauthorized, "Failed to refresh access token: "+err.Error())
return
}
}

uinfo, err := oidcClient.GetUserInfo()
if err != nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
HandleErrorResponse(w, http.StatusUnauthorized, "Failed to get user info: "+err.Error())
HandleErrorResponse(logger, w, http.StatusUnauthorized, "Failed to get user info: "+err.Error())
return
}

// Establish sql handler
agent := foodme.NewPermissionAgent(conf, httpClient)
if agent == nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] Failed to create permission agent", r)
HandleErrorResponse(w, http.StatusInternalServerError, "Failed to create permission agent")
HandleErrorResponse(logger, w, http.StatusInternalServerError, "Failed to create permission agent")
return
}

sqlHandler, err := foodme.NewSQLHandler(conf.DestinationDatabaseType, logger, agent)
if err != nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
HandleErrorResponse(w, http.StatusInternalServerError, "Failed to create SQL handler: "+err.Error())
HandleErrorResponse(logger, w, http.StatusInternalServerError, "Failed to create SQL handler: "+err.Error())
return
}

err = sqlHandler.SetDDL(uinfo)
if err != nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
HandleErrorResponse(w, http.StatusInternalServerError, "Failed to set DDL: "+err.Error())
HandleErrorResponse(logger, w, http.StatusInternalServerError, "Failed to set DDL: "+err.Error())
return
}

newSQL, err := sqlHandler.Handle(data.SQL, uinfo)
if err != nil {
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
HandleErrorResponse(w, http.StatusInternalServerError, "Failed to handle SQL: "+err.Error())
HandleErrorResponse(logger, w, http.StatusInternalServerError, "Failed to handle SQL: "+err.Error())
return
}

// Respond
w.WriteHeader(http.StatusOK)
err = json.NewEncoder(w).Encode(&PermissionApplyResponse{SQL: data.SQL, NewSQL: newSQL})
if err != nil {
panic(err)
logger.WithFields(logrus.Fields{"component": "api"}).Errorf("[%p] %s", r, err)
}
}
}
36 changes: 33 additions & 3 deletions api/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ type MockHeaders struct {
}

type MockResponseWriter struct {
buffer *MockBuffer
headers *MockHeaders
buffer *MockBuffer
headers *MockHeaders
failWrite bool
}

type MockBody struct {
Expand All @@ -48,6 +49,9 @@ func (m MockResponseWriter) Header() http.Header {
}

func (m MockResponseWriter) Write(data []byte) (int, error) {
if m.failWrite {
return 0, fmt.Errorf("write failure")
}
m.buffer.buffer = append(m.buffer.buffer, data...)
return len(data), nil
}
Expand Down Expand Up @@ -90,9 +94,13 @@ func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) {

func TestHandleErrorResponse(t *testing.T) {
w := MockResponseWriter{buffer: &MockBuffer{buffer: []byte{}}, headers: &MockHeaders{headers: []int{}}}
HandleErrorResponse(w, 500, "message")
logger := logrus.StandardLogger()
HandleErrorResponse(logger, w, 500, "message")
assert.DeepEqual(t, w.headers.headers, []int{500})
assert.DeepEqual(t, w.buffer.buffer, []byte("{\"detail\":\"message\"}\n"))

w.failWrite = true
HandleErrorResponse(logger, w, 500, "message")
}

func TestCreateNewConnectionFail(t *testing.T) {
Expand Down Expand Up @@ -122,6 +130,11 @@ func TestCreateNewConnectionOK(t *testing.T) {
assert.Equal(t, at, "a")
assert.Equal(t, rt, "r")
foodme.GlobalState.DeleteConnection(data.Username)

w = MockResponseWriter{buffer: &MockBuffer{buffer: []byte{}}, headers: &MockHeaders{headers: []int{}}, failWrite: true}
body = &MockBody{Body: "{\"access_token\":\"a\",\"refresh_token\":\"r\"}"}
r = &http.Request{Body: body}
handler(w, r)
}

func TestApplyPermissionAgent(t *testing.T) {
Expand Down Expand Up @@ -306,4 +319,21 @@ func TestApplyPermissionAgent(t *testing.T) {
err = json.Unmarshal(w.buffer.buffer, &respData)
assert.NilError(t, err)
assert.DeepEqual(t, respData, map[string]interface{}{"sql": "select * from pets p", "new_sql": "SELECT * FROM pets AS p WHERE ((p.owners >= 23))"})

mockHttpClient = &MockHttpClient{
DoSucceed: true,
Response: []string{
"{\"access_token\":\"access\"}",
"{\"preferred_username\":\"test_user\"}",
"{}",
"{}",
"{}",
"{\"result\":{\"queries\":[[{\"terms\":[{\"type\":\"number\",\"value\":23},{\"type\":\"ref\",\"value\":[{\"type\":\"var\",\"value\":\"gte\"}]},{\"type\":\"ref\",\"value\":[{\"type\":\"var\",\"value\":\"data\"},{\"type\":\"string\",\"value\":\"tables\"},{\"type\":\"string\",\"value\":\"pets\"},{\"type\":\"string\",\"value\":\"owners\"}]}]}]]}}",
},
StatusCode: 200,
}
handler = ApplyPermissionAgent(log, conf, mockHttpClient)
w = MockResponseWriter{buffer: &MockBuffer{buffer: []byte{}}, headers: &MockHeaders{headers: []int{}}, failWrite: true}
r = &http.Request{Body: &MockBody{Body: "{\"username\":\"test\", \"sql\":\"select * from pets p\"}"}}
handler(w, r)
}
6 changes: 2 additions & 4 deletions internal/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,10 +812,8 @@ func (h *PostgresHandler) proxyDownstream() {
}
} else {
_, err := io.Copy(h.client, h.upstream)
if err != nil {
if err != io.EOF {
h.Logger.Errorf("Error copying from upstream to client: %v", err)
}
if err != nil && err != io.EOF {
h.Logger.Errorf("Error copying from upstream to client: %v", err)
}
}
}
Expand Down

0 comments on commit a12bb39

Please sign in to comment.