@@ -97,80 +97,106 @@ func CloseStatus(err error) StatusCode {
9797//
9898// Close will unblock all goroutines interacting with the connection once
9999// complete.
100- func (c * Conn ) Close (code StatusCode , reason string ) error {
101- defer c .wg .Wait ()
102- return c .closeHandshake (code , reason )
100+ func (c * Conn ) Close (code StatusCode , reason string ) (err error ) {
101+ defer errd .Wrap (& err , "failed to close WebSocket" )
102+
103+ if ! c .casClosing () {
104+ err = c .waitGoroutines ()
105+ if err != nil {
106+ return err
107+ }
108+ return net .ErrClosed
109+ }
110+ defer func () {
111+ if errors .Is (err , net .ErrClosed ) {
112+ err = nil
113+ }
114+ }()
115+
116+ err = c .closeHandshake (code , reason )
117+
118+ err2 := c .close ()
119+ if err == nil && err2 != nil {
120+ err = err2
121+ }
122+
123+ err2 = c .waitGoroutines ()
124+ if err == nil && err2 != nil {
125+ err = err2
126+ }
127+
128+ return err
103129}
104130
105131// CloseNow closes the WebSocket connection without attempting a close handshake.
106132// Use when you do not want the overhead of the close handshake.
107133func (c * Conn ) CloseNow () (err error ) {
108- defer c .wg .Wait ()
109134 defer errd .Wrap (& err , "failed to close WebSocket" )
110135
111- if c .isClosed () {
136+ if ! c .casClosing () {
137+ err = c .waitGoroutines ()
138+ if err != nil {
139+ return err
140+ }
112141 return net .ErrClosed
113142 }
143+ defer func () {
144+ if errors .Is (err , net .ErrClosed ) {
145+ err = nil
146+ }
147+ }()
114148
115- c .close (nil )
116- return c .closeErr
117- }
118-
119- func (c * Conn ) closeHandshake (code StatusCode , reason string ) (err error ) {
120- defer errd .Wrap (& err , "failed to close WebSocket" )
121-
122- writeErr := c .writeClose (code , reason )
123- closeHandshakeErr := c .waitCloseHandshake ()
149+ err = c .close ()
124150
125- if writeErr != nil {
126- return writeErr
151+ err2 := c .waitGoroutines ()
152+ if err == nil && err2 != nil {
153+ err = err2
127154 }
155+ return err
156+ }
128157
129- if CloseStatus (closeHandshakeErr ) == - 1 && ! errors .Is (net .ErrClosed , closeHandshakeErr ) {
130- return closeHandshakeErr
158+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
159+ err := c .writeClose (code , reason )
160+ if err != nil {
161+ return err
131162 }
132163
164+ err = c .waitCloseHandshake ()
165+ if CloseStatus (err ) != code {
166+ return err
167+ }
133168 return nil
134169}
135170
136171func (c * Conn ) writeClose (code StatusCode , reason string ) error {
137- c .closeMu .Lock ()
138- wroteClose := c .wroteClose
139- c .wroteClose = true
140- c .closeMu .Unlock ()
141- if wroteClose {
142- return net .ErrClosed
143- }
144-
145172 ce := CloseError {
146173 Code : code ,
147174 Reason : reason ,
148175 }
149176
150177 var p []byte
151- var marshalErr error
178+ var err error
152179 if ce .Code != StatusNoStatusRcvd {
153- p , marshalErr = ce .bytes ()
154- }
155-
156- writeErr := c .writeControl (context .Background (), opClose , p )
157- if CloseStatus (writeErr ) != - 1 {
158- // Not a real error if it's due to a close frame being received.
159- writeErr = nil
180+ p , err = ce .bytes ()
181+ if err != nil {
182+ return err
183+ }
160184 }
161185
162- // We do this after in case there was an error writing the close frame.
163- c . setCloseErr ( fmt . Errorf ( "sent close frame: %w" , ce ) )
186+ ctx , cancel := context . WithTimeout ( context . Background (), time . Second * 5 )
187+ defer cancel ( )
164188
165- if marshalErr != nil {
166- return marshalErr
189+ err = c .writeControl (ctx , opClose , p )
190+ // If the connection closed as we're writing we ignore the error as we might
191+ // have written the close frame, the peer responded and then someone else read it
192+ // and closed the connection.
193+ if err != nil && ! errors .Is (err , net .ErrClosed ) {
194+ return err
167195 }
168- return writeErr
196+ return nil
169197}
170198
171199func (c * Conn ) waitCloseHandshake () error {
172- defer c .close (nil )
173-
174200 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
175201 defer cancel ()
176202
@@ -180,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error {
180206 }
181207 defer c .readMu .unlock ()
182208
183- if c .readCloseFrameErr != nil {
184- return c .readCloseFrameErr
185- }
186-
187209 for i := int64 (0 ); i < c .msgReader .payloadLength ; i ++ {
188210 _ , err := c .br .ReadByte ()
189211 if err != nil {
@@ -206,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error {
206228 }
207229}
208230
231+ func (c * Conn ) waitGoroutines () error {
232+ t := time .NewTimer (time .Second * 15 )
233+ defer t .Stop ()
234+
235+ select {
236+ case <- c .timeoutLoopDone :
237+ case <- t .C :
238+ return errors .New ("failed to wait for timeoutLoop goroutine to exit" )
239+ }
240+
241+ c .closeReadMu .Lock ()
242+ closeRead := c .closeReadCtx != nil
243+ c .closeReadMu .Unlock ()
244+ if closeRead {
245+ select {
246+ case <- c .closeReadDone :
247+ case <- t .C :
248+ return errors .New ("failed to wait for close read goroutine to exit" )
249+ }
250+ }
251+
252+ select {
253+ case <- c .closed :
254+ case <- t .C :
255+ return errors .New ("failed to wait for connection to be closed" )
256+ }
257+
258+ return nil
259+ }
260+
209261func parseClosePayload (p []byte ) (CloseError , error ) {
210262 if len (p ) == 0 {
211263 return CloseError {
@@ -276,16 +328,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
276328 return buf , nil
277329}
278330
279- func (c * Conn ) setCloseErr ( err error ) {
331+ func (c * Conn ) casClosing () bool {
280332 c .closeMu .Lock ()
281- c .setCloseErrLocked (err )
282- c .closeMu .Unlock ()
283- }
284-
285- func (c * Conn ) setCloseErrLocked (err error ) {
286- if c .closeErr == nil && err != nil {
287- c .closeErr = fmt .Errorf ("WebSocket closed: %w" , err )
333+ defer c .closeMu .Unlock ()
334+ if ! c .closing {
335+ c .closing = true
336+ return true
288337 }
338+ return false
289339}
290340
291341func (c * Conn ) isClosed () bool {
0 commit comments