diff --git a/contracts/session/session.go b/contracts/session/session.go index f4273f9eb..3a04b6dae 100644 --- a/contracts/session/session.go +++ b/contracts/session/session.go @@ -1,5 +1,7 @@ package session +import "github.com/goravel/framework/contracts/foundation" + // Session is the interface that defines the methods that should be implemented by a session. type Session interface { // All returns all attributes of the session. @@ -42,8 +44,12 @@ type Session interface { Remove(key string) any // Save saves the session. Save() error + // SetDriver sets the session driver + SetDriver(driver Driver) Session // SetID sets the ID of the session. SetID(id string) Session + // SetJson sets the JSON parser for the session, which is used to marshal and unmarshal data. + SetJson(json foundation.Json) Session // SetName sets the name of the session. SetName(name string) Session // Start initiates the session. diff --git a/session/errors.go b/session/errors.go new file mode 100644 index 000000000..a8904053d --- /dev/null +++ b/session/errors.go @@ -0,0 +1,7 @@ +package session + +import "errors" + +var ( + ErrDriverNotSet = errors.New("session driver is not set") +) diff --git a/session/manager.go b/session/manager.go index 3fea6453c..61c3f35c5 100644 --- a/session/manager.go +++ b/session/manager.go @@ -25,9 +25,7 @@ func NewManager(config config.Config, json foundation.Json) *Manager { drivers: make(map[string]sessioncontract.Driver), json: json, sessionPool: sync.Pool{New: func() any { - return &Session{ - attributes: make(map[string]any), - } + return NewSession("", nil, json) }, }, } @@ -40,9 +38,10 @@ func (m *Manager) BuildSession(handler sessioncontract.Driver, sessionID ...stri panic("session driver cannot be nil") } session := m.acquireSession() - session.setDriver(handler) - session.setJson(m.json) - session.SetName(m.config.GetString("session.cookie")) + session.SetDriver(handler). + SetJson(m.json). + SetName(m.config.GetString("session.cookie")) + if len(sessionID) > 0 { session.SetID(sessionID[0]) } else { @@ -81,13 +80,12 @@ func (m *Manager) Extend(driver string, handler func() sessioncontract.Driver) e } func (m *Manager) ReleaseSession(session sessioncontract.Session) { - s := session.(*Session) - s.reset() - m.sessionPool.Put(s) + session.Flush() + m.sessionPool.Put(session) } -func (m *Manager) acquireSession() *Session { - session := m.sessionPool.Get().(*Session) +func (m *Manager) acquireSession() sessioncontract.Session { + session := m.sessionPool.Get().(sessioncontract.Session) return session } diff --git a/session/session.go b/session/session.go index 7c796e09f..2a64d442c 100644 --- a/session/session.go +++ b/session/session.go @@ -8,6 +8,7 @@ import ( "github.com/goravel/framework/contracts/foundation" sessioncontract "github.com/goravel/framework/contracts/session" + "github.com/goravel/framework/support/color" supportmaps "github.com/goravel/framework/support/maps" "github.com/goravel/framework/support/str" ) @@ -155,6 +156,10 @@ func (s *Session) Save() error { return err } + if err = s.validateDriver(); err != nil { + return err + } + if err = s.driver.Write(s.GetID(), string(data)); err != nil { return err } @@ -164,6 +169,11 @@ func (s *Session) Save() error { return nil } +func (s *Session) SetDriver(driver sessioncontract.Driver) sessioncontract.Session { + s.driver = driver + return s +} + func (s *Session) SetID(id string) sessioncontract.Session { if s.isValidID(id) { s.id = id @@ -174,6 +184,11 @@ func (s *Session) SetID(id string) sessioncontract.Session { return s } +func (s *Session) SetJson(json foundation.Json) sessioncontract.Session { + s.json = json + return s +} + func (s *Session) SetName(name string) sessioncontract.Session { s.name = name @@ -210,6 +225,13 @@ func (s *Session) loadSession() { } } +func (s *Session) validateDriver() error { + if s.driver == nil { + return ErrDriverNotSet + } + return nil +} + func (s *Session) migrate(destroy ...bool) error { shouldDestroy := false if len(destroy) > 0 { @@ -217,8 +239,11 @@ func (s *Session) migrate(destroy ...bool) error { } if shouldDestroy { - err := s.driver.Destroy(s.GetID()) - if err != nil { + if err := s.validateDriver(); err != nil { + return err + } + + if err := s.driver.Destroy(s.GetID()); err != nil { return err } } @@ -229,12 +254,19 @@ func (s *Session) migrate(destroy ...bool) error { } func (s *Session) readFromHandler() map[string]any { + if err := s.validateDriver(); err != nil { + color.Red().Println(err) + return nil + } + value, err := s.driver.Read(s.GetID()) if err != nil { + color.Red().Println(err) return nil } var data map[string]any if err := s.json.Unmarshal([]byte(value), &data); err != nil { + color.Red().Println(err) return nil } return data @@ -280,14 +312,6 @@ func (s *Session) reset() { s.started = false } -func (s *Session) setDriver(driver sessioncontract.Driver) { - s.driver = driver -} - -func (s *Session) setJson(json foundation.Json) { - s.json = json -} - // toStringSlice converts an interface slice to a string slice. func toStringSlice(anySlice []any) []string { strSlice := make([]string, len(anySlice))