Skip to content

Commit

Permalink
fix: only use transaction connection in a transaction (#1598)
Browse files Browse the repository at this point in the history
* fix: only use transaction connection in a transaction

* test: fix webhook tests
  • Loading branch information
FreddyDevelop authored Aug 28, 2024
1 parent e72c112 commit 38a11de
Show file tree
Hide file tree
Showing 17 changed files with 46 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (a PasswordLogin) Execute(c flowpilot.ExecutionContext) error {
var userID uuid.UUID

if c.Stash().Get(shared.StashPathEmail).Exists() {
emailModel, err := deps.Persister.GetEmailPersister().FindByAddress(c.Stash().Get(shared.StashPathEmail).String())
emailModel, err := deps.Persister.GetEmailPersisterWithConnection(deps.Tx).FindByAddress(c.Stash().Get(shared.StashPathEmail).String())
if err != nil {
return fmt.Errorf("failed to find user by email: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (a ReSendPasscode) Execute(c flowpilot.ExecutionContext) error {
EmailAddress: c.Stash().Get(shared.StashPathEmail).String(),
Language: deps.HttpContext.Request().Header.Get("Accept-Language"),
}
passcodeResult, err := deps.PasscodeService.SendPasscode(sendParams)
passcodeResult, err := deps.PasscodeService.SendPasscode(deps.Tx, sendParams)
if err != nil {
return fmt.Errorf("passcode service failed: %w", err)
}
Expand All @@ -79,7 +79,7 @@ func (a ReSendPasscode) Execute(c flowpilot.ExecutionContext) error {
},
}

err = utils.TriggerWebhooks(deps.HttpContext, events.EmailSend, webhookData)
err = utils.TriggerWebhooks(deps.HttpContext, deps.Tx, events.EmailSend, webhookData)
if err != nil {
return fmt.Errorf("failed to trigger webhook: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions backend/flow_api/flow/credential_usage/hook_send_passcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (h SendPasscode) Execute(c flowpilot.HookExecutionContext) error {
Language: deps.HttpContext.Request().Header.Get("Accept-Language"),
}

passcodeResult, err := deps.PasscodeService.SendPasscode(sendParams)
passcodeResult, err := deps.PasscodeService.SendPasscode(deps.Tx, sendParams)
if err != nil {
return fmt.Errorf("passcode service failed: %w", err)
}
Expand Down Expand Up @@ -100,7 +100,7 @@ func (h SendPasscode) Execute(c flowpilot.HookExecutionContext) error {
},
}

err = utils.TriggerWebhooks(deps.HttpContext, events.EmailSend, webhookData)
err = utils.TriggerWebhooks(deps.HttpContext, deps.Tx, events.EmailSend, webhookData)
if err != nil {
return fmt.Errorf("failed to trigger webhook: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion backend/flow_api/flow/profile/action_account_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (a AccountDelete) Execute(c flowpilot.ExecutionContext) error {

deps.HttpContext.SetCookie(cookie)

err = utils.TriggerWebhooks(deps.HttpContext, events.UserDelete, admin.FromUserModel(*userModel))
err = utils.TriggerWebhooks(deps.HttpContext, deps.Tx, events.UserDelete, admin.FromUserModel(*userModel))
if err != nil {
return fmt.Errorf("failed to trrigger webhook: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (a RegisterLoginIdentifier) Execute(c flowpilot.ExecutionContext) error {

// Check that email is not already taken
// this check is non-exhaustive as the email is not blocked here and might be created after the check here and the user creation
emailModel, err := deps.Persister.GetEmailPersister().FindByAddress(email)
emailModel, err := deps.Persister.GetEmailPersisterWithConnection(deps.Tx).FindByAddress(email)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions backend/flow_api/services/passcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type SendPasscodeResult struct {

type Passcode interface {
ValidatePasscode(ValidatePasscodeParams) (bool, error)
SendPasscode(SendPasscodeParams) (*SendPasscodeResult, error)
SendPasscode(*pop.Connection, SendPasscodeParams) (*SendPasscodeResult, error)
VerifyPasscodeCode(tx *pop.Connection, passcodeID uuid.UUID, passcode string) error
}

Expand Down Expand Up @@ -110,7 +110,7 @@ func (s *passcode) VerifyPasscodeCode(tx *pop.Connection, passcodeID uuid.UUID,
return nil
}

func (s *passcode) SendPasscode(p SendPasscodeParams) (*SendPasscodeResult, error) {
func (s *passcode) SendPasscode(tx *pop.Connection, p SendPasscodeParams) (*SendPasscodeResult, error) {
code, err := s.passcodeGenerator.Generate()
if err != nil {
return nil, err
Expand All @@ -135,7 +135,7 @@ func (s *passcode) SendPasscode(p SendPasscodeParams) (*SendPasscodeResult, erro
UpdatedAt: now,
}

err = s.persister.GetPasscodePersister().Create(passcodeModel)
err = s.persister.GetPasscodePersisterWithConnection(tx).Create(passcodeModel)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion backend/handler/admin_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NewAdminRouter(cfg *config.Config, persister persistence.Persister, prometh
panic(fmt.Errorf("failed to create jwk manager: %w", err))
}

webhookMiddleware := hankoMiddleware.WebhookMiddleware(cfg, jwkManager, persister.GetWebhookPersister(nil))
webhookMiddleware := hankoMiddleware.WebhookMiddleware(cfg, jwkManager, persister)

userHandler := NewUserHandlerAdmin(persister)
emailHandler := NewEmailAdminHandler(cfg, persister)
Expand Down
4 changes: 2 additions & 2 deletions backend/handler/passcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,14 @@ func (h *PasscodeHandler) Init(c echo.Context) error {
return fmt.Errorf("failed to send passcode: %w", err)
}

err = utils.TriggerWebhooks(c, events.EmailSend, webhookData)
err = utils.TriggerWebhooks(c, h.persister.GetConnection(), events.EmailSend, webhookData)

if err != nil {
zeroLogger.Warn().Err(err).Msg("failed to trigger webhook")
}
} else {
webhookData.DeliveredByHanko = false
err = utils.TriggerWebhooks(c, events.EmailSend, webhookData)
err = utils.TriggerWebhooks(c, h.persister.GetConnection(), events.EmailSend, webhookData)

if err != nil {
return fmt.Errorf(fmt.Sprintf("failed to trigger webhook: %s", err))
Expand Down
2 changes: 1 addition & 1 deletion backend/handler/public_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func NewPublicRouter(cfg *config.Config, persister persistence.Persister, promet

sessionMiddleware := hankoMiddleware.Session(cfg, sessionManager)

webhookMiddleware := hankoMiddleware.WebhookMiddleware(cfg, jwkManager, persister.GetWebhookPersister(nil))
webhookMiddleware := hankoMiddleware.WebhookMiddleware(cfg, jwkManager, persister)

e.POST("/registration", flowAPIHandler.RegistrationFlowHandler, webhookMiddleware)
e.POST("/login", flowAPIHandler.LoginFlowHandler, webhookMiddleware)
Expand Down
2 changes: 1 addition & 1 deletion backend/handler/thirdparty.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (h *ThirdPartyHandler) Callback(c echo.Context) error {
}

if accountLinkingResult.WebhookEvent != nil {
err = webhookUtils.TriggerWebhooks(c, *accountLinkingResult.WebhookEvent, admin.FromUserModel(*accountLinkingResult.User))
err = webhookUtils.TriggerWebhooks(c, h.persister.GetConnection(), *accountLinkingResult.WebhookEvent, admin.FromUserModel(*accountLinkingResult.User))
if err != nil {
c.Logger().Warn(err)
}
Expand Down
4 changes: 2 additions & 2 deletions backend/handler/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (h *UserHandler) Create(c echo.Context) error {
}

if !h.cfg.Email.RequireVerification {
err = utils.TriggerWebhooks(c, events.UserCreate, admin.FromUserModel(newUser))
err = utils.TriggerWebhooks(c, tx, events.UserCreate, admin.FromUserModel(newUser))
if err != nil {
c.Logger().Warn(err)
}
Expand Down Expand Up @@ -288,7 +288,7 @@ func (h *UserHandler) Delete(c echo.Context) error {

c.SetCookie(cookie)

err = utils.TriggerWebhooks(c, events.UserDelete, admin.FromUserModel(*user))
err = utils.TriggerWebhooks(c, tx, events.UserDelete, admin.FromUserModel(*user))
if err != nil {
c.Logger().Warn(err)
}
Expand Down
4 changes: 2 additions & 2 deletions backend/handler/user_admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (h *UserHandlerAdmin) Delete(c echo.Context) error {
return fmt.Errorf("failed to delete user: %w", err)
}

err = utils.TriggerWebhooks(c, events.UserDelete, admin.FromUserModel(*user))
err = utils.TriggerWebhooks(c, h.persister.GetConnection(), events.UserDelete, admin.FromUserModel(*user))
if err != nil {
c.Logger().Warn(err)
}
Expand Down Expand Up @@ -284,7 +284,7 @@ func (h *UserHandlerAdmin) Create(c echo.Context) error {

userDto := admin.FromUserModel(*user)

err = utils.TriggerWebhooks(c, events.UserCreate, userDto)
err = utils.TriggerWebhooks(c, h.persister.GetConnection(), events.UserCreate, userDto)
if err != nil {
c.Logger().Warn(err)
}
Expand Down
2 changes: 1 addition & 1 deletion backend/middleware/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/teamhanko/hanko/backend/webhooks"
)

func WebhookMiddleware(cfg *config.Config, jwkManager hankoJwk.Manager, persister persistence.WebhookPersister) echo.MiddlewareFunc {
func WebhookMiddleware(cfg *config.Config, jwkManager hankoJwk.Manager, persister persistence.Persister) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(ctx echo.Context) error {

Expand Down
13 changes: 7 additions & 6 deletions backend/webhooks/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package webhooks

import (
"fmt"
"github.com/gobuffalo/pop/v6"
"github.com/labstack/echo/v4"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/teamhanko/hanko/backend/config"
Expand All @@ -13,7 +14,7 @@ import (
)

type Manager interface {
Trigger(evt events.Event, data interface{})
Trigger(tx *pop.Connection, evt events.Event, data interface{})
GenerateJWT(data interface{}, event events.Event) (string, error)
}

Expand All @@ -22,11 +23,11 @@ type manager struct {
webhooks Webhooks
jwtGenerator hankoJwt.Generator
audience []string
persister persistence.WebhookPersister
persister persistence.Persister
canExpireAtTime bool
}

func NewManager(cfg *config.Config, persister persistence.WebhookPersister, jwkManager hankoJwk.Manager, logger echo.Logger) (Manager, error) {
func NewManager(cfg *config.Config, persister persistence.Persister, jwkManager hankoJwk.Manager, logger echo.Logger) (Manager, error) {
hooks := make(Webhooks, 0)

if cfg.Webhooks.Enabled {
Expand Down Expand Up @@ -73,17 +74,17 @@ func NewManager(cfg *config.Config, persister persistence.WebhookPersister, jwkM
}, nil
}

func (m *manager) Trigger(evt events.Event, data interface{}) {
func (m *manager) Trigger(tx *pop.Connection, evt events.Event, data interface{}) {
// add db hooks - Done here to prevent a restart in case a hook is added or removed from the database
dbHooks, err := m.persister.List(false)
dbHooks, err := m.persister.GetWebhookPersister(tx).List(false)
if err != nil {
m.logger.Error(fmt.Errorf("unable to get database webhooks: %w", err))
return
}

hooks := m.webhooks
for _, dbHook := range dbHooks {
hooks = append(hooks, NewDatabaseHook(dbHook, m.persister, m.logger))
hooks = append(hooks, NewDatabaseHook(dbHook, m.persister.GetWebhookPersister(nil), m.logger))
}

dataToken, err := m.GenerateJWT(data, evt)
Expand Down
24 changes: 12 additions & 12 deletions backend/webhooks/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (s *managerSuite) TestNewManager() {
cfg := config.Config{}
jwkManager := test.JwkManager{}

manager, err := NewManager(&cfg, s.Storage.GetWebhookPersister(nil), jwkManager, nil)
manager, err := NewManager(&cfg, s.Storage, jwkManager, nil)
s.NoError(err)
s.NotEmpty(manager)
}
Expand All @@ -36,7 +36,7 @@ func (s *managerSuite) TestManager_GenerateJWT() {
cfg := config.Config{}
jwkManager := test.JwkManager{}

manager, err := NewManager(&cfg, s.Storage.GetWebhookPersister(nil), jwkManager, nil)
manager, err := NewManager(&cfg, s.Storage, jwkManager, nil)

testData := "lorem-ipsum"

Expand All @@ -55,10 +55,10 @@ func (s *managerSuite) TestManager_TriggerWithoutHook() {
cfg := config.Config{}
jwkManager := test.JwkManager{}

manager, err := NewManager(&cfg, s.Storage.GetWebhookPersister(nil), jwkManager, nil)
manager, err := NewManager(&cfg, s.Storage, jwkManager, nil)
s.Require().NoError(err)

manager.Trigger(events.UserCreate, "lorem-ipsum")
manager.Trigger(s.Storage.GetConnection(), events.UserCreate, "lorem-ipsum")

// give it 1 sec to trigger
time.Sleep(1 * time.Second)
Expand Down Expand Up @@ -87,10 +87,10 @@ func (s *managerSuite) TestManager_TriggerWithConfigHook() {
}

jwkManager := test.JwkManager{}
manager, err := NewManager(&cfg, s.Storage.GetWebhookPersister(nil), jwkManager, nil)
manager, err := NewManager(&cfg, s.Storage, jwkManager, nil)
s.Require().NoError(err)

manager.Trigger(events.UserCreate, "lorem-ipsum")
manager.Trigger(s.Storage.GetConnection(), events.UserCreate, "lorem-ipsum")

// give it 1 sec to trigger
time.Sleep(1 * time.Second)
Expand Down Expand Up @@ -120,10 +120,10 @@ func (s *managerSuite) TestManager_TriggerWithDisabledConfigHook() {
}

jwkManager := test.JwkManager{}
manager, err := NewManager(&cfg, s.Storage.GetWebhookPersister(nil), jwkManager, nil)
manager, err := NewManager(&cfg, s.Storage, jwkManager, nil)
s.Require().NoError(err)

manager.Trigger(events.UserCreate, "lorem-ipsum")
manager.Trigger(s.Storage.GetConnection(), events.UserCreate, "lorem-ipsum")

// give it 1 sec to trigger
time.Sleep(1 * time.Second)
Expand All @@ -145,10 +145,10 @@ func (s *managerSuite) TestManager_TriggerWithDbHook() {

s.createTestDatabaseWebhook(persister, true, server.URL)

manager, err := NewManager(&cfg, persister, jwkManager, nil)
manager, err := NewManager(&cfg, s.Storage, jwkManager, nil)
s.Require().NoError(err)

manager.Trigger(events.UserCreate, "lorem-ipsum")
manager.Trigger(s.Storage.GetConnection(), events.UserCreate, "lorem-ipsum")

// give it 1 sec to trigger
time.Sleep(1 * time.Second)
Expand All @@ -169,10 +169,10 @@ func (s *managerSuite) TestManager_TriggerWithDisabledDbHook() {

s.createTestDatabaseWebhook(persister, false, server.URL)

manager, err := NewManager(&cfg, persister, jwkManager, nil)
manager, err := NewManager(&cfg, s.Storage, jwkManager, nil)
s.Require().NoError(err)

manager.Trigger(events.UserCreate, "lorem-ipsum")
manager.Trigger(s.Storage.GetConnection(), events.UserCreate, "lorem-ipsum")

// give it 1 sec to trigger
time.Sleep(1 * time.Second)
Expand Down
7 changes: 3 additions & 4 deletions backend/webhooks/utils/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@ import (
"github.com/teamhanko/hanko/backend/webhooks/events"
)

func TriggerWebhooks(ctx echo.Context, evt events.Event, data interface{}) error {
func TriggerWebhooks(ctx echo.Context, tx *pop.Connection, evt events.Event, data interface{}) error {
webhookCtx := ctx.Get("webhook_manager")
if webhookCtx == nil {
return fmt.Errorf("unable to load webhooks manager from webhook middleware")
}

webhookManager := webhookCtx.(webhooks.Manager)
webhookManager.Trigger(evt, data)
webhookManager.Trigger(tx, evt, data)

return nil

}

func NotifyUserChange(ctx echo.Context, tx *pop.Connection, persister persistence.Persister, event events.Event, userId uuid.UUID) {
Expand All @@ -31,7 +30,7 @@ func NotifyUserChange(ctx echo.Context, tx *pop.Connection, persister persistenc
return
}

err = TriggerWebhooks(ctx, event, admin.FromUserModel(*updatedUser))
err = TriggerWebhooks(ctx, tx, event, admin.FromUserModel(*updatedUser))
if err != nil {
ctx.Logger().Warn(err)
}
Expand Down
7 changes: 4 additions & 3 deletions backend/webhooks/utils/webhook_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"github.com/gobuffalo/pop/v6"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/require"
"github.com/teamhanko/hanko/backend/webhooks/events"
Expand All @@ -13,7 +14,7 @@ type testManager struct {
TestFunc func()
}

func (tm *testManager) Trigger(evt events.Event, data interface{}) {
func (tm *testManager) Trigger(tx *pop.Connection, evt events.Event, data interface{}) {
tm.TestFunc()
}

Expand All @@ -28,7 +29,7 @@ func TestWebhook_TriggerWithoutManager(t *testing.T) {

ctx := e.NewContext(req, rec)

err := TriggerWebhooks(ctx, "user", "lorem")
err := TriggerWebhooks(ctx, nil, "user", "lorem")
require.Error(t, err)

err = e.Close()
Expand All @@ -47,7 +48,7 @@ func TestWebhook_Trigger(t *testing.T) {
ctx := e.NewContext(req, rec)
ctx.Set("webhook_manager", tm)

err := TriggerWebhooks(ctx, "user", "lorem")
err := TriggerWebhooks(ctx, nil, "user", "lorem")
require.NoError(t, err)

err = e.Close()
Expand Down

0 comments on commit 38a11de

Please sign in to comment.