@@ -767,6 +767,58 @@ def send_alt_svc(self, previous_state):
767767 (H2StreamStateMachine .send_on_closed_stream , StreamState .CLOSED ),
768768}
769769
770+ """
771+ Wraps a stream state change function to ensure that we keep
772+ the parent H2Connection's state in sync
773+ """
774+ def sync_state_change (func ):
775+ def wrapper (self , * args , ** kwargs ):
776+ # Collect state at the beginning.
777+ start_state = self .state_machine .state
778+ started_open = self .open
779+ started_closed = not started_open
780+
781+ # Do the state change (if any).
782+ result = func (self , * args , ** kwargs )
783+
784+ # Collect state at the end.
785+ end_state = self .state_machine .state
786+ ended_open = self .open
787+ ended_closed = not ended_open
788+
789+ # If at any point we've tranwsitioned to the CLOSED state
790+ # from any other state, close our stream.
791+ if end_state == StreamState .CLOSED and start_state != end_state :
792+ if self ._close_stream_callback :
793+ self ._close_stream_callback (self .stream_id )
794+ # Clear callback so we only call this once per stream
795+ self ._close_stream_callback = None
796+
797+ # If we were open, but are now closed, decrement
798+ # the open stream count, and call the close callback.
799+ if started_open and ended_closed :
800+ if self ._decrement_open_stream_count_callback :
801+ self ._decrement_open_stream_count_callback (self .stream_id ,
802+ - 1 ,)
803+ # Clear callback so we only call this once per stream
804+ self ._decrement_open_stream_count_callback = None
805+
806+ if self ._close_stream_callback :
807+ self ._close_stream_callback (self .stream_id )
808+ # Clear callback so we only call this once per stream
809+ self ._close_stream_callback = None
810+
811+ # If we were closed, but are now open, increment
812+ # the open stream count.
813+ elif started_closed and ended_open :
814+ if self ._increment_open_stream_count_callback :
815+ self ._increment_open_stream_count_callback (self .stream_id ,
816+ 1 ,)
817+ # Clear callback so we only call this once per stream
818+ self ._increment_open_stream_count_callback = None
819+ return result
820+ return wrapper
821+
770822
771823class H2Stream (object ):
772824 """
@@ -782,18 +834,29 @@ def __init__(self,
782834 stream_id ,
783835 config ,
784836 inbound_window_size ,
785- outbound_window_size ):
837+ outbound_window_size ,
838+ increment_open_stream_count_callback ,
839+ close_stream_callback ,):
786840 self .state_machine = H2StreamStateMachine (stream_id )
787841 self .stream_id = stream_id
788842 self .max_outbound_frame_size = None
789843 self .request_method = None
790844
791- # The current value of the outbound stream flow control window
845+ # The current value of the outbound stream flow control window.
792846 self .outbound_flow_control_window = outbound_window_size
793847
794848 # The flow control manager.
795849 self ._inbound_window_manager = WindowManager (inbound_window_size )
796850
851+ # Callback to increment open stream count for the H2Connection.
852+ self ._increment_open_stream_count_callback = increment_open_stream_count_callback
853+
854+ # Callback to decrement open stream count for the H2Connection.
855+ self ._decrement_open_stream_count_callback = increment_open_stream_count_callback
856+
857+ # Callback to clean up state for the H2Connection once we're closed.
858+ self ._close_stream_callback = close_stream_callback
859+
797860 # The expected content length, if any.
798861 self ._expected_content_length = None
799862
@@ -850,6 +913,7 @@ def closed_by(self):
850913 """
851914 return self .state_machine .stream_closed_by
852915
916+ @sync_state_change
853917 def upgrade (self , client_side ):
854918 """
855919 Called by the connection to indicate that this stream is the initial
@@ -868,6 +932,7 @@ def upgrade(self, client_side):
868932 self .state_machine .process_input (input_ )
869933 return
870934
935+ @sync_state_change
871936 def send_headers (self , headers , encoder , end_stream = False ):
872937 """
873938 Returns a list of HEADERS/CONTINUATION frames to emit as either headers
@@ -917,6 +982,7 @@ def send_headers(self, headers, encoder, end_stream=False):
917982
918983 return frames
919984
985+ @sync_state_change
920986 def push_stream_in_band (self , related_stream_id , headers , encoder ):
921987 """
922988 Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed
@@ -941,6 +1007,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder):
9411007
9421008 return frames
9431009
1010+ @sync_state_change
9441011 def locally_pushed (self ):
9451012 """
9461013 Mark this stream as one that was pushed by this peer. Must be called
@@ -954,6 +1021,7 @@ def locally_pushed(self):
9541021 assert not events
9551022 return []
9561023
1024+ @sync_state_change
9571025 def send_data (self , data , end_stream = False , pad_length = None ):
9581026 """
9591027 Prepare some data frames. Optionally end the stream.
@@ -981,6 +1049,7 @@ def send_data(self, data, end_stream=False, pad_length=None):
9811049
9821050 return [df ]
9831051
1052+ @sync_state_change
9841053 def end_stream (self ):
9851054 """
9861055 End a stream without sending data.
@@ -992,6 +1061,7 @@ def end_stream(self):
9921061 df .flags .add ('END_STREAM' )
9931062 return [df ]
9941063
1064+ @sync_state_change
9951065 def advertise_alternative_service (self , field_value ):
9961066 """
9971067 Advertise an RFC 7838 alternative service. The semantics of this are
@@ -1005,6 +1075,7 @@ def advertise_alternative_service(self, field_value):
10051075 asf .field = field_value
10061076 return [asf ]
10071077
1078+ @sync_state_change
10081079 def increase_flow_control_window (self , increment ):
10091080 """
10101081 Increase the size of the flow control window for the remote side.
@@ -1020,6 +1091,7 @@ def increase_flow_control_window(self, increment):
10201091 wuf .window_increment = increment
10211092 return [wuf ]
10221093
1094+ @sync_state_change
10231095 def receive_push_promise_in_band (self ,
10241096 promised_stream_id ,
10251097 headers ,
@@ -1044,6 +1116,7 @@ def receive_push_promise_in_band(self,
10441116 )
10451117 return [], events
10461118
1119+ @sync_state_change
10471120 def remotely_pushed (self , pushed_headers ):
10481121 """
10491122 Mark this stream as one that was pushed by the remote peer. Must be
@@ -1057,6 +1130,7 @@ def remotely_pushed(self, pushed_headers):
10571130 self ._authority = authority_from_headers (pushed_headers )
10581131 return [], events
10591132
1133+ @sync_state_change
10601134 def receive_headers (self , headers , end_stream , header_encoding ):
10611135 """
10621136 Receive a set of headers (or trailers).
@@ -1091,6 +1165,7 @@ def receive_headers(self, headers, end_stream, header_encoding):
10911165 )
10921166 return [], events
10931167
1168+ @sync_state_change
10941169 def receive_data (self , data , end_stream , flow_control_len ):
10951170 """
10961171 Receive some data.
@@ -1114,6 +1189,7 @@ def receive_data(self, data, end_stream, flow_control_len):
11141189 events [0 ].flow_controlled_length = flow_control_len
11151190 return [], events
11161191
1192+ @sync_state_change
11171193 def receive_window_update (self , increment ):
11181194 """
11191195 Handle a WINDOW_UPDATE increment.
@@ -1150,6 +1226,7 @@ def receive_window_update(self, increment):
11501226
11511227 return frames , events
11521228
1229+ @sync_state_change
11531230 def receive_continuation (self ):
11541231 """
11551232 A naked CONTINUATION frame has been received. This is always an error,
@@ -1162,6 +1239,7 @@ def receive_continuation(self):
11621239 )
11631240 assert False , "Should not be reachable"
11641241
1242+ @sync_state_change
11651243 def receive_alt_svc (self , frame ):
11661244 """
11671245 An Alternative Service frame was received on the stream. This frame
@@ -1189,6 +1267,7 @@ def receive_alt_svc(self, frame):
11891267
11901268 return [], events
11911269
1270+ @sync_state_change
11921271 def reset_stream (self , error_code = 0 ):
11931272 """
11941273 Close the stream locally. Reset the stream with an error code.
@@ -1202,6 +1281,7 @@ def reset_stream(self, error_code=0):
12021281 rsf .error_code = error_code
12031282 return [rsf ]
12041283
1284+ @sync_state_change
12051285 def stream_reset (self , frame ):
12061286 """
12071287 Handle a stream being reset remotely.
@@ -1217,6 +1297,7 @@ def stream_reset(self, frame):
12171297
12181298 return [], events
12191299
1300+ @sync_state_change
12201301 def acknowledge_received_data (self , acknowledged_size ):
12211302 """
12221303 The user has informed us that they've processed some amount of data
0 commit comments