@@ -51,9 +51,8 @@ type Conn struct {
5151 br * bufio.Reader
5252 bw * bufio.Writer
5353
54- readTimeout chan context.Context
55- writeTimeout chan context.Context
56- timeoutLoopDone chan struct {}
54+ readTimeoutStop atomic.Pointer [func () bool ]
55+ writeTimeoutStop atomic.Pointer [func () bool ]
5756
5857 // Read state.
5958 readMu * mu
@@ -113,10 +112,6 @@ func newConn(cfg connConfig) *Conn {
113112 br : cfg .br ,
114113 bw : cfg .bw ,
115114
116- readTimeout : make (chan context.Context ),
117- writeTimeout : make (chan context.Context ),
118- timeoutLoopDone : make (chan struct {}),
119-
120115 closed : make (chan struct {}),
121116 activePings : make (map [string ]chan <- struct {}),
122117 onPingReceived : cfg .onPingReceived ,
@@ -144,8 +139,6 @@ func newConn(cfg connConfig) *Conn {
144139 c .close ()
145140 })
146141
147- go c .timeoutLoop ()
148-
149142 return c
150143}
151144
@@ -175,27 +168,34 @@ func (c *Conn) close() error {
175168 return err
176169}
177170
178- func (c * Conn ) timeoutLoop () {
179- defer close (c .timeoutLoopDone )
171+ func (c * Conn ) setupWriteTimeout (ctx context.Context ) {
172+ stop := context .AfterFunc (ctx , func () {
173+ c .clearWriteTimeout ()
174+ c .close ()
175+ })
176+ swapTimeoutStop (& c .writeTimeoutStop , & stop )
177+ }
180178
181- readCtx := context .Background ()
182- writeCtx := context .Background ()
179+ func (c * Conn ) clearWriteTimeout () {
180+ swapTimeoutStop (& c .writeTimeoutStop , nil )
181+ }
183182
184- for {
185- select {
186- case <- c .closed :
187- return
188-
189- case writeCtx = <- c .writeTimeout :
190- case readCtx = <- c .readTimeout :
191-
192- case <- readCtx .Done ():
193- c .close ()
194- return
195- case <- writeCtx .Done ():
196- c .close ()
197- return
198- }
183+ func (c * Conn ) setupReadTimeout (ctx context.Context ) {
184+ stop := context .AfterFunc (ctx , func () {
185+ c .clearReadTimeout ()
186+ c .close ()
187+ })
188+ swapTimeoutStop (& c .readTimeoutStop , & stop )
189+ }
190+
191+ func (c * Conn ) clearReadTimeout () {
192+ swapTimeoutStop (& c .readTimeoutStop , nil )
193+ }
194+
195+ func swapTimeoutStop (p * atomic.Pointer [func () bool ], newStop * func () bool ) {
196+ oldStop := p .Swap (newStop )
197+ if oldStop != nil {
198+ (* oldStop )()
199199 }
200200}
201201
0 commit comments