From 6a72e8044ba1c2dafb732822a192856c54708348 Mon Sep 17 00:00:00 2001 From: Lajos Koszti Date: Thu, 23 May 2024 08:24:16 +0200 Subject: [PATCH] refactor: refactor entry and entry key handling The remaining reads and expiration are now handled at the entry key level instead of the entry level. The tests have been updated to reflect these changes. Additionally, some unnecessary code has been removed and several minor improvements have been made. --- api/api_test.go | 13 +- internal/api/createentry.go | 4 +- internal/api/generateentrykey.go | 8 +- internal/api/generateentrykey_test.go | 2 +- internal/models/entry.go | 41 ++-- internal/models/entry_test.go | 21 +- internal/models/entrykey.go | 10 +- internal/models/entrykey_test.go | 15 +- internal/models/migrate/entry.go | 18 ++ internal/models/mock.go | 6 +- internal/services/entrykeymanager.go | 32 +-- internal/services/entrykeymanager_test.go | 230 +++++++++++----------- internal/services/entrymanager.go | 35 ++-- internal/services/entrymanager_test.go | 97 +++------ internal/services/entryvalidation.go | 8 - internal/services/interfaces.go | 6 +- internal/services/mocks.go | 6 +- internal/views/entrycreate.go | 2 + internal/views/entrydelete.go | 3 + internal/views/entrykeycreate.go | 2 + 20 files changed, 238 insertions(+), 321 deletions(-) diff --git a/api/api_test.go b/api/api_test.go index 74462a9..8a8b492 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -640,7 +640,6 @@ func TestCreateEntryWithMaxReads(t *testing.T) { NewSecretHandler(NewHandlerConfig(db)).ServeHTTP(w, req) resp := w.Result() - model := &models.EntryModel{} savedUUID := resp.Header.Get("x-entry-uuid") @@ -652,7 +651,8 @@ func TestCreateEntryWithMaxReads(t *testing.T) { if err != nil { t.Fatal(err) } - entry, err := model.ReadEntryMeta(ctx, tx, savedUUID) + model := &models.EntryKeyModel{} + entries, err := model.Get(ctx, tx, savedUUID) if err != nil { if err := tx.Rollback(); err != nil { @@ -665,8 +665,13 @@ func TestCreateEntryWithMaxReads(t *testing.T) { t.Errorf("commit failed: %v", err) } - if entry.RemainingReads != 2 { - t.Fatalf("expected max reads to be: %d, actual: %d", 2, entry.RemainingReads) + if len(entries) != 1 { + t.Fatalf("expected to get entry key %d, got %d", 1, len(entries)) + } + + remainingReads := entries[0].RemainingReads.Int16 + if remainingReads != 2 { + t.Fatalf("expected max reads to be: %d, actual: %d", 2, remainingReads) } } diff --git a/internal/api/createentry.go b/internal/api/createentry.go index 9f02366..29c76d0 100644 --- a/internal/api/createentry.go +++ b/internal/api/createentry.go @@ -77,9 +77,7 @@ func (c CreateHandler) handle(w http.ResponseWriter, r *http.Request) error { // Handle handles http request to create secret func (c CreateHandler) Handle(w http.ResponseWriter, r *http.Request) { - err := c.handle(w, r) - - if err != nil { + if err := c.handle(w, r); err != nil { log.Println("create error", err) c.view.RenderError(w, r, err) } diff --git a/internal/api/generateentrykey.go b/internal/api/generateentrykey.go index 018967d..2b84eeb 100644 --- a/internal/api/generateentrykey.go +++ b/internal/api/generateentrykey.go @@ -2,7 +2,9 @@ package api import ( "context" + "fmt" "net/http" + "time" "github.com/Ajnasz/sekret.link/internal/key" "github.com/Ajnasz/sekret.link/internal/parsers" @@ -16,7 +18,7 @@ type GenerateEntryKeyView interface { } type GenerateEntryKeyManager interface { - GenerateEntryKey(ctx context.Context, UUID string, k key.Key) (*services.EntryKeyData, error) + GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire time.Duration, maxReads int) (*services.EntryKeyData, error) } type GenerateEntryKeyHandler struct { @@ -44,10 +46,12 @@ func (g GenerateEntryKeyHandler) handle(w http.ResponseWriter, r *http.Request) return err } + fmt.Println(request) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key) + entry, err := g.entryManager.GenerateEntryKey(ctx, request.UUID, request.Key, request.Expiration, request.MaxReads) if err != nil { return err } diff --git a/internal/api/generateentrykey_test.go b/internal/api/generateentrykey_test.go index 716b062..2290333 100644 --- a/internal/api/generateentrykey_test.go +++ b/internal/api/generateentrykey_test.go @@ -30,7 +30,7 @@ type MockGenerateEntryKeyManager struct { mock.Mock } -func (m *MockGenerateEntryKeyManager) GenerateEntryKey(ctx context.Context, UUID string, k key.Key) (*services.EntryKeyData, error) { +func (m *MockGenerateEntryKeyManager) GenerateEntryKey(ctx context.Context, UUID string, k key.Key, expire time.Duration, maxReads int) (*services.EntryKeyData, error) { args := m.Called(ctx, UUID, k) return args.Get(0).(*services.EntryKeyData), args.Error(2) } diff --git a/internal/models/entry.go b/internal/models/entry.go index fa2e3f5..dd1fe27 100644 --- a/internal/models/entry.go +++ b/internal/models/entry.go @@ -18,22 +18,18 @@ var ErrInvalidKey = errors.New("invalid key") var ErrCreateEntry = errors.New("failed to create entry") type EntryMeta struct { - UUID string - RemainingReads int - DeleteKey string - Created time.Time - Accessed sql.NullTime - Expire time.Time - ContentType string + UUID string + DeleteKey string + Created time.Time + Accessed sql.NullTime + ContentType string } // uuid uuid PRIMARY KEY, // data BYTEA, -// remaining_reads SMALLINT DEFAULT 1, // delete_key CHAR(256) NOT NULL, // created TIMESTAMPTZ, // accessed TIMESTAMPTZ, -// expire TIMESTAMPTZ type Entry struct { EntryMeta Data []byte @@ -51,14 +47,14 @@ func (e *EntryModel) getDeleteKey() (string, error) { } // CreateEntry creates a new entry into the database -func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, contenType string, data []byte, remainingReads int, expire time.Duration) (*EntryMeta, error) { +func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, contenType string, data []byte) (*EntryMeta, error) { deleteKey, err := e.getDeleteKey() if err != nil { return nil, errors.Join(err, ErrCreateEntry) } now := time.Now() - res, err := tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, expire, remaining_reads, delete_key, content_type) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING uuid, delete_key;`, uuid, data, now, now.Add(expire), remainingReads, deleteKey, contenType) + res, err := tx.ExecContext(ctx, `INSERT INTO entries (uuid, data, created, delete_key, content_type) VALUES ($1, $2, $3, $4, $5) RETURNING uuid, delete_key;`, uuid, data, now, deleteKey, contenType) if err != nil { return nil, errors.Join(err, ErrCreateEntry) @@ -74,25 +70,23 @@ func (e *EntryModel) CreateEntry(ctx context.Context, tx *sql.Tx, uuid string, c } return &EntryMeta{ - UUID: uuid, - RemainingReads: remainingReads, - DeleteKey: deleteKey, - Created: now, - Expire: now.Add(expire), + UUID: uuid, + DeleteKey: deleteKey, + Created: now, }, err } func (e *EntryModel) Use(ctx context.Context, tx *sql.Tx, uuid string) error { - _, err := tx.ExecContext(ctx, "UPDATE entries SET accessed = NOW(), remaining_reads = remaining_reads - 1 WHERE uuid = $1 AND remaining_reads > 0", uuid) + _, err := tx.ExecContext(ctx, "UPDATE entries SET accessed = NOW() WHERE uuid = $1", uuid) return err } // ReadEntry reads a entry from the database // and updates the read count func (e *EntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, uuid string) (*Entry, error) { - row := tx.QueryRow("SELECT uuid, data, remaining_reads, delete_key, created, accessed, expire, content_type FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid) + row := tx.QueryRow("SELECT uuid, data, delete_key, created, accessed, content_type FROM entries WHERE uuid=$1 LIMIT 1", uuid) var s Entry - err := row.Scan(&s.UUID, &s.Data, &s.RemainingReads, &s.DeleteKey, &s.Created, &s.Accessed, &s.Expire, &s.ContentType) + err := row.Scan(&s.UUID, &s.Data, &s.DeleteKey, &s.Created, &s.Accessed, &s.ContentType) if err != nil { if err == sql.ErrNoRows { return nil, ErrEntryNotFound @@ -104,9 +98,9 @@ func (e *EntryModel) ReadEntry(ctx context.Context, tx *sql.Tx, uuid string) (*E } func (e *EntryModel) ReadEntryMeta(ctx context.Context, tx *sql.Tx, uuid string) (*EntryMeta, error) { - row := tx.QueryRow("SELECT created, accessed, expire, remaining_reads, delete_key, content_type FROM entries WHERE uuid=$1 AND remaining_reads > 0 LIMIT 1", uuid) + row := tx.QueryRow("SELECT created, accessed, delete_key, content_type FROM entries WHERE uuid=$1 LIMIT 1", uuid) var s EntryMeta - err := row.Scan(&s.Created, &s.Accessed, &s.Expire, &s.RemainingReads, &s.DeleteKey, &s.ContentType) + err := row.Scan(&s.Created, &s.Accessed, &s.DeleteKey, &s.ContentType) if err != nil { if err == sql.ErrNoRows { return nil, ErrEntryNotFound @@ -157,7 +151,8 @@ func (e *EntryModel) DeleteEntry(ctx context.Context, tx *sql.Tx, uuid string, d } func (e *EntryModel) DeleteExpired(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, "DELETE FROM entries WHERE expire < NOW()") + // TODO join with entry_keys table and delete if no living entry found + // _, err := tx.ExecContext(ctx, "DELETE FROM entries WHERE expire < NOW()") - return err + return nil } diff --git a/internal/models/entry_test.go b/internal/models/entry_test.go index c9b32ca..5ed8d87 100644 --- a/internal/models/entry_test.go +++ b/internal/models/entry_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "testing" - "time" "github.com/Ajnasz/sekret.link/internal/test/durable" "github.com/google/uuid" @@ -26,20 +25,14 @@ func Test_EntryModel_CreateEntry(t *testing.T) { uid := uuid.New().String() data := []byte("test data") - remainingReads := 2 - expire := time.Hour * 24 model := &EntryModel{} - meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data, remainingReads, expire) + meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data) if err != nil { t.Fatal(err) } - if meta.RemainingReads != 2 { - t.Errorf("expected %d got %d", remainingReads, meta.RemainingReads) - } - if meta.UUID != uid { t.Errorf("expected %s got %s", uid, meta.UUID) } @@ -52,10 +45,6 @@ func Test_EntryModel_CreateEntry(t *testing.T) { t.Errorf("expected created to be set") } - if meta.Expire.IsZero() { - t.Errorf("expected expire to be set") - } - if meta.Accessed.Valid { t.Errorf("expected accessed not to be set") } @@ -81,12 +70,10 @@ func Test_EntryModel_Use(t *testing.T) { uid := uuid.New().String() data := []byte("test data") - remainingReads := 2 - expire := time.Hour * 24 model := &EntryModel{} - meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data, remainingReads, expire) + meta, err := model.CreateEntry(ctx, tx, uid, "text/plain", data) if err != nil { t.Fatal(err) } @@ -100,10 +87,6 @@ func Test_EntryModel_Use(t *testing.T) { t.Fatal(errors.Join(err, errors.New("failed to read entry"))) } - if entry.RemainingReads != 1 { - t.Errorf("expected %d got %d", 0, entry.RemainingReads) - } - if !entry.Accessed.Valid { t.Errorf("expected accessed to be set") } diff --git a/internal/models/entrykey.go b/internal/models/entrykey.go index 15dff71..7f0125b 100644 --- a/internal/models/entrykey.go +++ b/internal/models/entrykey.go @@ -18,13 +18,13 @@ type EntryKey struct { type EntryKeyModel struct{} -func (e *EntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte) (*EntryKey, error) { +func (e *EntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*EntryKey, error) { now := time.Now() res := tx.QueryRowContext(ctx, ` - INSERT INTO entry_key (uuid, entry_uuid, encrypted_key, key_hash, created) - VALUES (gen_random_uuid(), $1, $2, $3, $4) RETURNING uuid, created; - `, entryUUID, encryptedKey, hash, now) + INSERT INTO entry_key (uuid, entry_uuid, encrypted_key, key_hash, created, remaining_reads, expire) + VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6) RETURNING uuid, created; + `, entryUUID, encryptedKey, hash, now, remainingReads, expire) var uid string var created time.Time @@ -50,7 +50,7 @@ func (e *EntryKeyModel) Get(ctx context.Context, tx *sql.Tx, entryUUID string) ( SELECT uuid, entry_uuid, encrypted_key, key_hash, created, expire, remaining_reads FROM entry_key WHERE entry_uuid = $1 - AND (expire IS NULL OR expire > NOW()); + ; `, entryUUID) if err != nil { diff --git a/internal/models/entrykey_test.go b/internal/models/entrykey_test.go index c0a18b7..9a75fb9 100644 --- a/internal/models/entrykey_test.go +++ b/internal/models/entrykey_test.go @@ -33,7 +33,7 @@ func createTestEntryKey(ctx context.Context, tx *sql.Tx) (string, string, error) entryModel := &EntryModel{} - _, err := entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600) + _, err := entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data")) if err != nil { return "", "", err @@ -41,7 +41,7 @@ func createTestEntryKey(ctx context.Context, tx *sql.Tx) (string, string, error) model := &EntryKeyModel{} - entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx")) + entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hash entrykey use tx"), time.Now().Add(time.Hour), 2) if err != nil { return "", "", err @@ -66,14 +66,14 @@ func Test_EntryKeyModel_Create(t *testing.T) { uid := uuid.New().String() entryModel := &EntryModel{} - _, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600) + _, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data")) if err != nil { t.Fatal(err) } model := &EntryKeyModel{} - entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke")) + entryKey, err := model.Create(ctx, tx, uid, []byte("test"), []byte("hashke"), time.Now().Add(time.Hour), 2) if err != nil { if err := tx.Rollback(); err != nil { @@ -124,7 +124,7 @@ func Test_EntryKeyModel_Get(t *testing.T) { uid := uuid.New().String() entryModel := &EntryModel{} - _, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data"), 2, 3600) + _, err = entryModel.CreateEntry(ctx, tx, uid, "text/plain", []byte("test data")) if err != nil { if err := tx.Rollback(); err != nil { t.Error(err) @@ -135,7 +135,7 @@ func Test_EntryKeyModel_Get(t *testing.T) { model := &EntryKeyModel{} for i := 0; i < 10; i++ { - _, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i))) + _, err = model.Create(ctx, tx, uid, []byte("test"), []byte(fmt.Sprintf("hashke %d", i)), time.Now().Add(time.Hour), 2) if err != nil { if err := tx.Rollback(); err != nil { @@ -156,6 +156,7 @@ func Test_EntryKeyModel_Get(t *testing.T) { } entryKeys, err := model.Get(ctx, tx, uid) + fmt.Println("ENTRY KEYS", entryKeys) if err != nil { if err := tx.Rollback(); err != nil { @@ -169,7 +170,7 @@ func Test_EntryKeyModel_Get(t *testing.T) { } if len(entryKeys) != 10 { - t.Fatalf("expected 1 got %d", len(entryKeys)) + t.Fatalf("expected 10 got %d", len(entryKeys)) } if entryKeys[0].EntryUUID != uid { diff --git a/internal/models/migrate/entry.go b/internal/models/migrate/entry.go index 970c4e1..2e9c5e1 100644 --- a/internal/models/migrate/entry.go +++ b/internal/models/migrate/entry.go @@ -45,6 +45,10 @@ func (e *EntryMigration) Alter(ctx context.Context, tx *sql.Tx) error { return err } + if err := e.dropKeyFields(ctx, tx); err != nil { + return err + } + return nil } @@ -123,3 +127,17 @@ func (e *EntryMigration) addContentType(ctx context.Context, tx *sql.Tx) error { return nil } + +func (e *EntryMigration) dropKeyFields(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, "ALTER TABLE entries DROP COLUMN IF EXISTS remaining_reads;") + if err != nil { + return fmt.Errorf("failed to drop remaining_reads column: %w", err) + } + + _, err = tx.ExecContext(ctx, "ALTER TABLE entries DROP COLUMN IF EXISTS expire;") + if err != nil { + return fmt.Errorf("failed to drop expire column: %w", err) + } + + return nil +} diff --git a/internal/models/mock.go b/internal/models/mock.go index fda1853..3fec41b 100644 --- a/internal/models/mock.go +++ b/internal/models/mock.go @@ -3,7 +3,6 @@ package models import ( "context" "database/sql" - "time" "github.com/stretchr/testify/mock" ) @@ -18,9 +17,8 @@ func (m *MockEntryModel) CreateEntry( UUID string, contentType string, data []byte, - remainingReads int, - expire time.Duration) (*EntryMeta, error) { - args := m.Called(ctx, tx, UUID, data, remainingReads, expire) +) (*EntryMeta, error) { + args := m.Called(ctx, tx, UUID, data) return args.Get(0).(*EntryMeta), args.Error(1) } diff --git a/internal/services/entrykeymanager.go b/internal/services/entrykeymanager.go index fde668e..2475e00 100644 --- a/internal/services/entrykeymanager.go +++ b/internal/services/entrykeymanager.go @@ -18,7 +18,7 @@ var ErrEntryCreateFailed = errors.New("entry create failed") var ErrGetDEKFailed = errors.New("get DEK failed") type EntryKeyModel interface { - Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte) (*models.EntryKey, error) + Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*models.EntryKey, error) Get(ctx context.Context, tx *sql.Tx, entryUUID string) ([]models.EntryKey, error) Delete(ctx context.Context, tx *sql.Tx, uuid string) error SetExpire(ctx context.Context, tx *sql.Tx, uuid string, expire time.Time) error @@ -42,7 +42,7 @@ func NewEntryKeyManager(db *sql.DB, model EntryKeyModel, hasher hasher.Hasher, e } } -func (e *EntryKeyManager) Create(ctx context.Context, entryUUID string, dek key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) { +func (e *EntryKeyManager) Create(ctx context.Context, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { tx, err := e.db.BeginTx(ctx, nil) if err != nil { @@ -87,7 +87,7 @@ func modelEntryKeyToEntryKey(m *models.EntryKey) *EntryKey { } } -func (e *EntryKeyManager) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) { +func (e *EntryKeyManager) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { k, err := key.NewGeneratedKey() if err != nil { @@ -100,33 +100,11 @@ func (e *EntryKeyManager) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUI } hash := e.hasher.Hash(dek.Get()) - entryKey, err := e.model.Create(ctx, tx, entryUUID, encryptedKey, hash) + entryKey, err := e.model.Create(ctx, tx, entryUUID, encryptedKey, hash, expire, maxRead) if err != nil { return nil, nil, errors.Join(ErrEntryCreateFailed, err) } - if expire != nil { - err := e.model.SetExpire(ctx, tx, entryKey.UUID, *expire) - if err != nil { - return nil, nil, errors.Join(ErrEntryCreateFailed, err) - } - entryKey.Expire = sql.NullTime{ - Time: *expire, - Valid: true, - } - } - - if maxRead != nil { - err := e.model.SetMaxReads(ctx, tx, entryKey.UUID, *maxRead) - if err != nil { - return nil, nil, errors.Join(ErrEntryCreateFailed, err) - } - - entryKey.RemainingReads = sql.NullInt16{ - Int16: int16(*maxRead), - Valid: true, - } - } return modelEntryKeyToEntryKey(entryKey), *k, nil } @@ -224,7 +202,7 @@ func (e *EntryKeyManager) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID st } // GenerateEncryptionKey creates a new key for the entry -func (e EntryKeyManager) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) { +func (e EntryKeyManager) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { tx, err := e.db.BeginTx(ctx, nil) if err != nil { return nil, nil, err diff --git a/internal/services/entrykeymanager_test.go b/internal/services/entrykeymanager_test.go index e09605b..9971945 100644 --- a/internal/services/entrykeymanager_test.go +++ b/internal/services/entrykeymanager_test.go @@ -17,8 +17,8 @@ type MockEntryKeyModel struct { mock.Mock } -func (m *MockEntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte) (*models.EntryKey, error) { - args := m.Called(ctx, tx, entryUUID, encryptedKey, hash) +func (m *MockEntryKeyModel) Create(ctx context.Context, tx *sql.Tx, entryUUID string, encryptedKey []byte, hash []byte, expire time.Time, remainingReads int) (*models.EntryKey, error) { + args := m.Called(ctx, tx, entryUUID, encryptedKey, hash, expire, remainingReads) return args.Get(0).(*models.EntryKey), args.Error(1) } @@ -95,24 +95,24 @@ func TestEntryKeyManager_Create(t *testing.T) { encrypter.On("Encrypt", dek.Get()).Return(encryptedKey, nil) hasher.On("Hash", dek.Get()).Return(hash) - model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ + model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash, expire, maxRead).Return(&models.EntryKey{ UUID: "test-uuid", EntryUUID: entryUUID, EncryptedKey: encryptedKey, Created: time.Now(), - Expire: sql.NullTime{Time: time.Now(), Valid: false}, - RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + Expire: sql.NullTime{Time: expire, Valid: false}, + RemainingReads: sql.NullInt16{Int16: int16(maxRead), Valid: false}, }, nil) - model.On("SetExpire", ctx, mock.Anything, "test-uuid", expire).Return(nil) - model.On("SetMaxReads", ctx, mock.Anything, "test-uuid", maxRead).Return(nil) + // model.On("SetExpire", ctx, mock.Anything, "test-uuid", expire).Return(nil) + // model.On("SetMaxReads", ctx, mock.Anything, "test-uuid", maxRead).Return(nil) crypto := func(key key.Key) Encrypter { return encrypter } manager := NewEntryKeyManager(db, model, hasher, crypto) - entryKey, key, err := manager.Create(ctx, entryUUID, *dek, &expire, &maxRead) + entryKey, key, err := manager.Create(ctx, entryUUID, *dek, expire, maxRead) model.AssertExpectations(t) encrypter.AssertExpectations(t) @@ -127,107 +127,107 @@ func TestEntryKeyManager_Create(t *testing.T) { assert.NotEmpty(t, key.Get()) } -func TestEntryKeyManager_Create_NoExpire(t *testing.T) { - db, sqlMock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - defer db.Close() - - sqlMock.ExpectBegin() - sqlMock.ExpectCommit() - - ctx := context.Background() - model := &MockEntryKeyModel{} - hasher := &MockHasher{} - encrypter := &EncrypterMock{} - dek, err := key.NewGeneratedKey() - assert.NoError(t, err) - entryUUID := "test-entry-uuid" - encryptedKey := []byte("test-encrypted-key") - hash := []byte("test-hash") - - hasher.On("Hash", dek.Get()).Return(hash) - encrypter.On("Encrypt", dek.Get()).Return(encryptedKey, nil) - model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ - UUID: "test-uuid", - EntryUUID: entryUUID, - EncryptedKey: encryptedKey, - KeyHash: hash, - Created: time.Now(), - Expire: sql.NullTime{Time: time.Now(), Valid: false}, - RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, - }, nil) - - crypto := func(key key.Key) Encrypter { - return encrypter - } - manager := NewEntryKeyManager(db, model, hasher, crypto) - entryKey, key, err := manager.Create(ctx, entryUUID, *dek, nil, nil) - - hasher.AssertExpectations(t) - encrypter.AssertExpectations(t) - model.AssertExpectations(t) - if sqlMock.ExpectationsWereMet() != nil { - t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) - } - assert.NoError(t, err) - assert.Equal(t, "test-uuid", entryKey.UUID) - // assert.False(nil, entryKey.Expire) - assert.NotEmpty(t, key.Get()) -} - -func TestEntryKeyManager_Create_NoMaxRead(t *testing.T) { - db, sqlMock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - defer db.Close() - - sqlMock.ExpectBegin() - sqlMock.ExpectCommit() - - ctx := context.Background() - model := &MockEntryKeyModel{} - hasher := &MockHasher{} - encrypter := &EncrypterMock{} - entryUUID := "test-entry-uuid" - dek := []byte("test-dek") - encryptedKey := []byte("test-encrypted-key") - hash := []byte("test-hash") - - hasher.On("Hash", dek).Return(hash) - encrypter.On("Encrypt", dek).Return(encryptedKey, nil) - model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ - UUID: "test-uuid", - EntryUUID: entryUUID, - EncryptedKey: encryptedKey, - KeyHash: hash, - Created: time.Now(), - Expire: sql.NullTime{Time: time.Now(), Valid: false}, - RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, - }, nil) - - crypto := func(key key.Key) Encrypter { - return encrypter - } - - manager := NewEntryKeyManager(db, model, hasher, crypto) - entryKey, key, err := manager.Create(ctx, entryUUID, dek, nil, nil) - - model.AssertExpectations(t) - hasher.AssertExpectations(t) - encrypter.AssertExpectations(t) - if sqlMock.ExpectationsWereMet() != nil { - t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) - } - assert.NoError(t, err) - assert.Equal(t, "test-uuid", entryKey.UUID) - // key.Get should not return an empty string - assert.NotEmpty(t, key.Get()) -} +// func TestEntryKeyManager_Create_NoExpire(t *testing.T) { +// db, sqlMock, err := sqlmock.New() +// if err != nil { +// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) +// } +// +// defer db.Close() +// +// sqlMock.ExpectBegin() +// sqlMock.ExpectCommit() +// +// ctx := context.Background() +// model := &MockEntryKeyModel{} +// hasher := &MockHasher{} +// encrypter := &EncrypterMock{} +// dek, err := key.NewGeneratedKey() +// assert.NoError(t, err) +// entryUUID := "test-entry-uuid" +// encryptedKey := []byte("test-encrypted-key") +// hash := []byte("test-hash") +// +// hasher.On("Hash", dek.Get()).Return(hash) +// encrypter.On("Encrypt", dek.Get()).Return(encryptedKey, nil) +// model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ +// UUID: "test-uuid", +// EntryUUID: entryUUID, +// EncryptedKey: encryptedKey, +// KeyHash: hash, +// Created: time.Now(), +// Expire: sql.NullTime{Time: time.Now(), Valid: false}, +// RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, +// }, nil) +// +// crypto := func(key key.Key) Encrypter { +// return encrypter +// } +// manager := NewEntryKeyManager(db, model, hasher, crypto) +// entryKey, key, err := manager.Create(ctx, entryUUID, *dek, 0, 0) +// +// hasher.AssertExpectations(t) +// encrypter.AssertExpectations(t) +// model.AssertExpectations(t) +// if sqlMock.ExpectationsWereMet() != nil { +// t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) +// } +// assert.NoError(t, err) +// assert.Equal(t, "test-uuid", entryKey.UUID) +// // assert.False(nil, entryKey.Expire) +// assert.NotEmpty(t, key.Get()) +// } + +// func TestEntryKeyManager_Create_NoMaxRead(t *testing.T) { +// db, sqlMock, err := sqlmock.New() +// if err != nil { +// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) +// } +// +// defer db.Close() +// +// sqlMock.ExpectBegin() +// sqlMock.ExpectCommit() +// +// ctx := context.Background() +// model := &MockEntryKeyModel{} +// hasher := &MockHasher{} +// encrypter := &EncrypterMock{} +// entryUUID := "test-entry-uuid" +// dek := []byte("test-dek") +// encryptedKey := []byte("test-encrypted-key") +// hash := []byte("test-hash") +// +// hasher.On("Hash", dek).Return(hash) +// encrypter.On("Encrypt", dek).Return(encryptedKey, nil) +// model.On("Create", ctx, mock.Anything, entryUUID, encryptedKey, hash).Return(&models.EntryKey{ +// UUID: "test-uuid", +// EntryUUID: entryUUID, +// EncryptedKey: encryptedKey, +// KeyHash: hash, +// Created: time.Now(), +// Expire: sql.NullTime{Time: time.Now(), Valid: false}, +// RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, +// }, nil) +// +// crypto := func(key key.Key) Encrypter { +// return encrypter +// } +// +// manager := NewEntryKeyManager(db, model, hasher, crypto) +// entryKey, key, err := manager.Create(ctx, entryUUID, dek, nil, nil) +// +// model.AssertExpectations(t) +// hasher.AssertExpectations(t) +// encrypter.AssertExpectations(t) +// if sqlMock.ExpectationsWereMet() != nil { +// t.Errorf("there were unfulfilled expectations: %s", sqlMock.ExpectationsWereMet()) +// } +// assert.NoError(t, err) +// assert.Equal(t, "test-uuid", entryKey.UUID) +// // key.Get should not return an empty string +// assert.NotEmpty(t, key.Get()) +// } func TestEntryKeyManager_GetDEK(t *testing.T) { db, sqlMock, err := sqlmock.New() @@ -461,17 +461,17 @@ func TestEntryKeyManager_GenerateEncryptionKey(t *testing.T) { hasher.On("Hash", dek).Return(hash) encrypter.On("Encrypt", mock.Anything).Return(newEncryptedKey, nil) - model.On("Create", ctx, mock.Anything, entryUUID, newEncryptedKey, hash).Return(&models.EntryKey{ + model.On("Create", ctx, mock.Anything, entryUUID, newEncryptedKey, hash, expire, maxRead).Return(&models.EntryKey{ UUID: "new-test-uuid", EntryUUID: entryUUID, EncryptedKey: newEncryptedKey, KeyHash: hash, Created: time.Now(), - Expire: sql.NullTime{Time: time.Now(), Valid: false}, - RemainingReads: sql.NullInt16{Int16: 0, Valid: false}, + Expire: sql.NullTime{Time: expire, Valid: false}, + RemainingReads: sql.NullInt16{Int16: int16(maxRead), Valid: false}, }, nil) - model.On("SetExpire", ctx, mock.Anything, "new-test-uuid", expire).Return(nil) - model.On("SetMaxReads", ctx, mock.Anything, "new-test-uuid", maxRead).Return(nil) + // model.On("SetExpire", ctx, mock.Anything, "new-test-uuid", expire).Return(nil) + // model.On("SetMaxReads", ctx, mock.Anything, "new-test-uuid", maxRead).Return(nil) crypto := func(key key.Key) Encrypter { return encrypter @@ -479,7 +479,7 @@ func TestEntryKeyManager_GenerateEncryptionKey(t *testing.T) { manager := NewEntryKeyManager(db, model, hasher, crypto) - entryKey, key, err := manager.GenerateEncryptionKey(ctx, entryUUID, encryptedKey, &expire, &maxRead) + entryKey, key, err := manager.GenerateEncryptionKey(ctx, entryUUID, encryptedKey, expire, maxRead) model.AssertExpectations(t) hasher.AssertExpectations(t) diff --git a/internal/services/entrymanager.go b/internal/services/entrymanager.go index 164bf9a..2a30aa9 100644 --- a/internal/services/entrymanager.go +++ b/internal/services/entrymanager.go @@ -92,7 +92,7 @@ func (e *EntryManager) CreateEntry(ctx context.Context, contentType string, data } return nil, nil, errors.Join(ErrCreateEntryFailed, err) } - meta, err := e.model.CreateEntry(ctx, tx, uid, contentType, encryptedData, remainingReads, expire) + meta, err := e.model.CreateEntry(ctx, tx, uid, contentType, encryptedData) if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { return nil, nil, errors.Join(ErrCreateEntryFailed, err, rollbackErr) @@ -101,7 +101,7 @@ func (e *EntryManager) CreateEntry(ctx context.Context, contentType string, data } expireAt := time.Now().Add(expire) - _, kek, err := e.keyManager.CreateWithTx(ctx, tx, uid, dek.Get(), &expireAt, &remainingReads) + entryKey, kek, err := e.keyManager.CreateWithTx(ctx, tx, uid, dek.Get(), expireAt, remainingReads) if err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { @@ -116,23 +116,15 @@ func (e *EntryManager) CreateEntry(ctx context.Context, contentType string, data return &EntryMeta{ UUID: meta.UUID, - RemainingReads: meta.RemainingReads, DeleteKey: meta.DeleteKey, Created: meta.Created, Accessed: meta.Accessed.Time, - Expire: meta.Expire, ContentType: meta.ContentType, + RemainingReads: entryKey.RemainingReads, + Expire: entryKey.Expire, }, kek, nil } -func (e *EntryManager) readEntryLegacy(ctx context.Context, k key.Key, entry *models.Entry) ([]byte, error) { - if ctx.Err() != nil { - return nil, ctx.Err() - } - crypto := e.crypto(k) - return crypto.Decrypt(entry.Data) -} - // ReadEntry reads an entry // It reads the entry from the database // It reads the key from the key manager @@ -171,15 +163,10 @@ func (e *EntryManager) ReadEntry(ctx context.Context, UUID string, k key.Key) (* var decryptedData []byte if err != nil { if errors.Is(err, ErrEntryKeyNotFound) { - legacyData, legacyErr := e.readEntryLegacy(ctx, k, entry) - if legacyErr == nil { - decryptedData = legacyData - } else { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return nil, errors.Join(err, rollbackErr, ErrReadEntryFailed) - } - return nil, errors.Join(err, ErrReadEntryFailed) + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return nil, errors.Join(err, rollbackErr, ErrReadEntryFailed) } + return nil, errors.Join(err, ErrReadEntryFailed) } else { if rollbackErr := tx.Rollback(); rollbackErr != nil { return nil, errors.Join(err, rollbackErr, ErrReadEntryFailed) @@ -218,12 +205,12 @@ func (e *EntryManager) ReadEntry(ctx context.Context, UUID string, k key.Key) (* return &Entry{ EntryMeta: EntryMeta{ UUID: entry.UUID, - RemainingReads: entry.RemainingReads - 1, DeleteKey: entry.DeleteKey, Created: entry.Created, Accessed: entry.Accessed.Time, - Expire: entry.Expire, ContentType: entry.ContentType, + Expire: entryKey.Expire, + RemainingReads: entryKey.RemainingReads, }, Data: decryptedData, }, nil @@ -273,8 +260,8 @@ func (e *EntryManager) DeleteExpired(ctx context.Context) error { return nil } -func (e *EntryManager) GenerateEntryKey(ctx context.Context, entryUUID string, k key.Key) (*EntryKeyData, error) { - meta, kek, err := e.keyManager.GenerateEncryptionKey(ctx, entryUUID, k, nil, nil) +func (e *EntryManager) GenerateEntryKey(ctx context.Context, entryUUID string, k key.Key, expire time.Duration, maxReads int) (*EntryKeyData, error) { + meta, kek, err := e.keyManager.GenerateEncryptionKey(ctx, entryUUID, k, time.Now().Add(expire), maxReads) if err != nil { return nil, err } diff --git a/internal/services/entrymanager_test.go b/internal/services/entrymanager_test.go index 41604ea..61f01ba 100644 --- a/internal/services/entrymanager_test.go +++ b/internal/services/entrymanager_test.go @@ -32,13 +32,11 @@ func Test_EntryService_Create(t *testing.T) { encryptedData := []byte("encrypted") entryModel := new(models.MockEntryModel) entryModel. - On("CreateEntry", ctx, mock.Anything, mock.Anything, encryptedData, 1, mock.Anything). + On("CreateEntry", ctx, mock.Anything, mock.Anything, encryptedData). Return(&models.EntryMeta{ - UUID: "uuid", - RemainingReads: 1, - DeleteKey: "delete_key", - Created: timenow, - Expire: timenow.Add(time.Minute), + UUID: "uuid", + DeleteKey: "delete_key", + Created: timenow, }, nil) entryCrypto := new(MockEntryCrypto) @@ -53,7 +51,10 @@ func Test_EntryService_Create(t *testing.T) { if err != nil { t.Fatal(err) } - keyManager.On("CreateWithTx", ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&EntryKey{}, *kek, nil) + keyManager.On("CreateWithTx", ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&EntryKey{ + RemainingReads: 1, + Expire: time.Now().Add(time.Minute), + }, *kek, nil) service := NewEntryManager(db, entryModel, crypto, keyManager) meta, key, err := service.CreateEntry(ctx, "text/plain", data, 1, time.Minute) @@ -103,7 +104,7 @@ func TestCreateError(t *testing.T) { entryModel := new(models.MockEntryModel) entryModel. - On("CreateEntry", ctx, mock.Anything, mock.Anything, encryptedData, 1, mock.Anything). + On("CreateEntry", ctx, mock.Anything, mock.Anything, encryptedData). Return(&models.EntryMeta{}, fmt.Errorf("error")) entryCrypto := new(MockEntryCrypto) @@ -141,12 +142,10 @@ func TestReadEntry(t *testing.T) { entry := models.Entry{ EntryMeta: models.EntryMeta{ - UUID: "uuid", - RemainingReads: 1, - DeleteKey: "delete_key", - Created: timenow, - Accessed: sql.NullTime{Time: timenow, Valid: true}, - Expire: timenow.Add(time.Minute), + UUID: "uuid", + DeleteKey: "delete_key", + Created: timenow, + Accessed: sql.NullTime{Time: timenow, Valid: true}, }, Data: []byte("encrypted"), } @@ -196,7 +195,6 @@ func TestReadEntry(t *testing.T) { DeleteKey: entry.DeleteKey, Created: entry.Created, Accessed: entry.Accessed.Time, - Expire: entry.Expire, ContentType: entry.ContentType, }, }, *data) @@ -237,64 +235,11 @@ func TestReadEntry(t *testing.T) { assert.Error(t, err) assert.Nil(t, data) - entryModel.AssertExpectations(t) keyManager.AssertExpectations(t) - }) - - t.Run("it should try to decrypt with legacy method when key not found", func(t *testing.T) { - db, sqlMock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - defer db.Close() - - sqlMock.ExpectBegin() - sqlMock.ExpectCommit() - - ctx := context.Background() - - entry := models.Entry{ - EntryMeta: models.EntryMeta{ - UUID: "uuid", - RemainingReads: 1, - DeleteKey: "delete_key", - Created: timenow, - Accessed: sql.NullTime{Time: timenow, Valid: true}, - Expire: timenow.Add(time.Minute), - }, - Data: []byte("encrypted"), - } - - entryModel := new(models.MockEntryModel) - entryModel. - On("ReadEntry", ctx, mock.Anything, "uuid"). - Return(&entry, nil) - entryModel.On("Use", ctx, mock.Anything, "uuid").Return(nil) - - entryCrypto := new(MockEntryCrypto) - entryCrypto.On("Decrypt", []byte("encrypted")).Return([]byte("decrypted"), nil) - - crypto := func(key key.Key) Encrypter { - return entryCrypto - } - - var emptyEntryKey *EntryKey - var emptyDEK key.Key - - k, err := key.NewGeneratedKey() - assert.NoError(t, err) - keyManager := new(MockEntryKeyer) - keyManager.On("GetDEKTx", ctx, mock.Anything, "uuid", *k).Return(emptyDEK, emptyEntryKey, ErrEntryKeyNotFound) - - service := NewEntryManager(db, entryModel, crypto, keyManager) - data, err := service.ReadEntry(ctx, "uuid", *k) - - assert.Nil(t, err) - assert.Equal(t, "decrypted", string(data.Data)) - entryModel.AssertExpectations(t) keyManager.AssertExpectations(t) }) + } func TestReadEntryError(t *testing.T) { @@ -525,17 +470,20 @@ func Test_EntryManager_GenerateEntryKey(t *testing.T) { t.Fatal(err) } + expire := time.Minute + remainingReads := 1 + keyManager := new(MockEntryKeyer) keyManager.On("GenerateEncryptionKey", mock.Anything, entryUUID, *dek, mock.Anything, mock.Anything). Return(&EntryKey{ EntryUUID: entryUUID, - RemainingReads: 1, - Expire: time.Now().Add(time.Minute), + RemainingReads: remainingReads, + Expire: time.Now().Add(expire), }, *kek, nil) service := NewEntryManager(nil, nil, nil, keyManager) - entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek) + entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek, expire, remainingReads) assert.NoError(t, err) assert.Equal(t, entryUUID, entryKey.EntryUUID) @@ -556,9 +504,12 @@ func Test_EntryManager_GenerateEntryKey(t *testing.T) { keyManager.On("GenerateEncryptionKey", mock.Anything, entryUUID, *dek, mock.Anything, mock.Anything). Return(emptyEntryKey, emptyKey, fmt.Errorf("error")) + expire := time.Minute + remainingReads := 1 + service := NewEntryManager(nil, nil, nil, keyManager) - entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek) + entryKey, err := service.GenerateEntryKey(context.Background(), entryUUID, *dek, expire, remainingReads) assert.Error(t, err) assert.Nil(t, entryKey) diff --git a/internal/services/entryvalidation.go b/internal/services/entryvalidation.go index e8e9297..d222c8c 100644 --- a/internal/services/entryvalidation.go +++ b/internal/services/entryvalidation.go @@ -11,14 +11,6 @@ func validateEntry(entry *models.Entry) error { return ErrEntryNotFound } - if entry.Expire.Before(time.Now()) { - return ErrEntryExpired - } - - if entry.RemainingReads <= 0 { - return ErrEntryExpired - } - return nil } diff --git a/internal/services/interfaces.go b/internal/services/interfaces.go index 3d8edc9..7b4e27d 100644 --- a/internal/services/interfaces.go +++ b/internal/services/interfaces.go @@ -12,7 +12,7 @@ import ( // EntryModel is the interface for the entry model // It is used to create, read and access entries type EntryModel interface { - CreateEntry(ctx context.Context, tx *sql.Tx, UUID string, contentType string, data []byte, remainingReads int, expire time.Duration) (*models.EntryMeta, error) + CreateEntry(ctx context.Context, tx *sql.Tx, UUID string, contentType string, data []byte) (*models.EntryMeta, error) ReadEntry(ctx context.Context, tx *sql.Tx, UUID string) (*models.Entry, error) Use(ctx context.Context, tx *sql.Tx, UUID string) error DeleteEntry(ctx context.Context, tx *sql.Tx, UUID string, deleteKey string) error @@ -22,9 +22,9 @@ type EntryModel interface { // EntryKeyer is the interface for the entry key manager // It is used to create, read and access entry keys type EntryKeyer interface { - CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire *time.Time, maxRead *int) (entryKey *EntryKey, kek key.Key, err error) + CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire time.Time, maxRead int) (entryKey *EntryKey, kek key.Key, err error) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID string, kek key.Key) (dek key.Key, entryKey *EntryKey, err error) - GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) + GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) UseTx(ctx context.Context, tx *sql.Tx, entryUUID string) error } diff --git a/internal/services/mocks.go b/internal/services/mocks.go index bd3e5f3..32e768e 100644 --- a/internal/services/mocks.go +++ b/internal/services/mocks.go @@ -13,7 +13,7 @@ type MockEntryKeyer struct { mock.Mock } -func (m *MockEntryKeyer) Create(ctx context.Context, entryUUID string, dek key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) { +func (m *MockEntryKeyer) Create(ctx context.Context, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { args := m.Called(ctx, entryUUID, dek, expire, maxRead) if args.Get(1) == nil { return args.Get(0).(*EntryKey), nil, args.Error(2) @@ -21,7 +21,7 @@ func (m *MockEntryKeyer) Create(ctx context.Context, entryUUID string, dek key.K return args.Get(0).(*EntryKey), args.Get(1).(key.Key), args.Error(2) } -func (m *MockEntryKeyer) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) { +func (m *MockEntryKeyer) CreateWithTx(ctx context.Context, tx *sql.Tx, entryUUID string, dek key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { args := m.Called(ctx, tx, entryUUID, dek, expire, maxRead) return args.Get(0).(*EntryKey), args.Get(1).(key.Key), args.Error(2) } @@ -36,7 +36,7 @@ func (m *MockEntryKeyer) GetDEKTx(ctx context.Context, tx *sql.Tx, entryUUID str return args.Get(0).(key.Key), args.Get(1).(*EntryKey), args.Error(2) } -func (m *MockEntryKeyer) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire *time.Time, maxRead *int) (*EntryKey, key.Key, error) { +func (m *MockEntryKeyer) GenerateEncryptionKey(ctx context.Context, entryUUID string, existingKey key.Key, expire time.Time, maxRead int) (*EntryKey, key.Key, error) { args := m.Called(ctx, entryUUID, existingKey, expire, maxRead) return args.Get(0).(*EntryKey), args.Get(1).(key.Key), args.Error(2) } diff --git a/internal/views/entrycreate.go b/internal/views/entrycreate.go index 6dd6d57..1956140 100644 --- a/internal/views/entrycreate.go +++ b/internal/views/entrycreate.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "net/url" + "os" "strings" "time" @@ -68,6 +69,7 @@ func (e EntryCreateView) Render(w http.ResponseWriter, r *http.Request, entry En } func (e EntryCreateView) RenderError(w http.ResponseWriter, r *http.Request, err error) { + fmt.Fprintf(os.Stderr, "Render error: %s", err) if errors.Is(err, parsers.ErrInvalidExpirationDate) { http.Error(w, "Invalid expiration", http.StatusBadRequest) return diff --git a/internal/views/entrydelete.go b/internal/views/entrydelete.go index 37970fa..637d283 100644 --- a/internal/views/entrydelete.go +++ b/internal/views/entrydelete.go @@ -2,7 +2,9 @@ package views import ( "errors" + "fmt" "net/http" + "os" "github.com/Ajnasz/sekret.link/internal/models" "github.com/Ajnasz/sekret.link/internal/parsers" @@ -21,6 +23,7 @@ func (e EntryDeleteView) Render(w http.ResponseWriter, r *http.Request, data Del } func (e EntryDeleteView) RenderError(w http.ResponseWriter, r *http.Request, err error) { + fmt.Fprintf(os.Stderr, "Render error: %s", err) if errors.Is(err, models.ErrEntryNotFound) { http.Error(w, "Not Found", http.StatusNotFound) diff --git a/internal/views/entrykeycreate.go b/internal/views/entrykeycreate.go index 9f05497..ff78cea 100644 --- a/internal/views/entrykeycreate.go +++ b/internal/views/entrykeycreate.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "net/url" + "os" "time" "github.com/Ajnasz/sekret.link/internal/key" @@ -59,6 +60,7 @@ func (g GenerateEntryKeyView) Render(w http.ResponseWriter, r *http.Request, res // RenderGenerateEntryKeyError renders the error response for the GenerateEntryKey endpoint. func (v GenerateEntryKeyView) RenderError(w http.ResponseWriter, r *http.Request, err error) { + fmt.Fprintf(os.Stderr, "Render error: %s", err) if errors.Is(err, parsers.ErrInvalidUUID) { http.Error(w, "Invalid UUID", http.StatusBadRequest) return