diff --git a/session.go b/session.go index 79441e1..bf5f680 100644 --- a/session.go +++ b/session.go @@ -21,15 +21,15 @@ import ( // The MemoryManager allows management of memory allocations. type MemoryManager interface { // ReserveMemory reserves memory / buffer. - ReserveMemory(size int, prio uint8) error + ReserveMemory(conn *Session, size int, prio uint8) error // ReleaseMemory explicitly releases memory previously reserved with ReserveMemory - ReleaseMemory(size int) + ReleaseMemory(conn *Session, size int) } type nullMemoryManagerImpl struct{} -func (n nullMemoryManagerImpl) ReserveMemory(size int, prio uint8) error { return nil } -func (n nullMemoryManagerImpl) ReleaseMemory(size int) {} +func (n nullMemoryManagerImpl) ReserveMemory(conn *Session, size int, prio uint8) error { return nil } +func (n nullMemoryManagerImpl) ReleaseMemory(conn *Session, size int) {} var nullMemoryManager MemoryManager = &nullMemoryManagerImpl{} @@ -208,7 +208,7 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { return nil, s.shutdownErr } - if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil { + if err := s.memoryManager.ReserveMemory(s, initialStreamWindow, 255); err != nil { return nil, err } @@ -759,7 +759,7 @@ func (s *Session) incomingStream(id uint32) error { } // Allocate a new stream - if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil { + if err := s.memoryManager.ReserveMemory(s, initialStreamWindow, 255); err != nil { return err } stream := newStream(s, id, streamSYNReceived, initialStreamWindow) @@ -773,14 +773,14 @@ func (s *Session) incomingStream(id uint32) error { if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } - s.memoryManager.ReleaseMemory(initialStreamWindow) + s.memoryManager.ReleaseMemory(s, initialStreamWindow) return ErrDuplicateStream } if s.numIncomingStreams >= s.config.MaxIncomingStreams { // too many active streams at the same time s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset") - s.memoryManager.ReleaseMemory(initialStreamWindow) + s.memoryManager.ReleaseMemory(s, initialStreamWindow) hdr := encode(typeWindowUpdate, flagRST, id, 0) return s.sendMsg(hdr, nil, nil) } @@ -827,7 +827,7 @@ func (s *Session) deleteStream(id uint32) { if !ok { return } - s.memoryManager.ReleaseMemory(int(str.recvWindow)) + s.memoryManager.ReleaseMemory(s, int(str.recvWindow)) delete(s.streams, id) } diff --git a/stream.go b/stream.go index 7ae0be3..bbc40de 100644 --- a/stream.go +++ b/stream.go @@ -225,7 +225,7 @@ func (s *Stream) sendWindowUpdate() error { recvWindow = min(s.recvWindow*2, s.session.config.MaxStreamWindowSize) } if recvWindow > s.recvWindow { - if err := s.session.memoryManager.ReserveMemory(int(delta), 128); err == nil { + if err := s.session.memoryManager.ReserveMemory(s.session, int(delta), 128); err == nil { s.recvWindow = recvWindow _, delta = s.recvBuf.GrowTo(s.recvWindow, true) }