@@ -767,6 +767,56 @@ def send_alt_svc(self, previous_state):
767767 (H2StreamStateMachine .send_on_closed_stream , StreamState .CLOSED ),
768768}
769769
770+ """
771+ Wraps a stream state change function to ensure that we keep
772+ the parent H2Connection's state in sync
773+ """
774+ def sync_state_change (func ):
775+ def wrapper (self , * args , ** kwargs ):
776+ # Collect state at the beginning.
777+ start_state = self .state_machine .state
778+ started_open = self .open
779+ started_closed = not started_open
780+
781+ # Do the state change (if any).
782+ result = func (self , * args , ** kwargs )
783+
784+ # Collect state at the end.
785+ end_state = self .state_machine .state
786+ ended_open = self .open
787+ ended_closed = not ended_open
788+
789+ if end_state == StreamState .CLOSED and start_state != end_state :
790+ if self ._close_stream_callback :
791+ self ._close_stream_callback (self .stream_id )
792+ # Clear callback so we only call this once per stream
793+ self ._close_stream_callback = None
794+
795+ # If we were open, but are now closed, decrement
796+ # the open stream count, and call the close callback.
797+ if started_open and ended_closed :
798+ if self ._decrement_open_stream_count_callback :
799+ self ._decrement_open_stream_count_callback (self .stream_id ,
800+ - 1 ,)
801+ # Clear callback so we only call this once per stream
802+ self ._decrement_open_stream_count_callback = None
803+
804+ if self ._close_stream_callback :
805+ self ._close_stream_callback (self .stream_id )
806+ # Clear callback so we only call this once per stream
807+ self ._close_stream_callback = None
808+
809+ # If we were closed, but are now open, increment
810+ # the open stream count.
811+ elif started_closed and ended_open :
812+ if self ._increment_open_stream_count_callback :
813+ self ._increment_open_stream_count_callback (self .stream_id ,
814+ 1 ,)
815+ # Clear callback so we only call this once per stream
816+ self ._increment_open_stream_count_callback = None
817+ return result
818+ return wrapper
819+
770820
771821class H2Stream (object ):
772822 """
@@ -782,18 +832,29 @@ def __init__(self,
782832 stream_id ,
783833 config ,
784834 inbound_window_size ,
785- outbound_window_size ):
835+ outbound_window_size ,
836+ increment_open_stream_count_callback ,
837+ close_stream_callback ,):
786838 self .state_machine = H2StreamStateMachine (stream_id )
787839 self .stream_id = stream_id
788840 self .max_outbound_frame_size = None
789841 self .request_method = None
790842
791- # The current value of the outbound stream flow control window
843+ # The current value of the outbound stream flow control window.
792844 self .outbound_flow_control_window = outbound_window_size
793845
794846 # The flow control manager.
795847 self ._inbound_window_manager = WindowManager (inbound_window_size )
796848
849+ # Callback to increment open stream count for the H2Connection.
850+ self ._increment_open_stream_count_callback = increment_open_stream_count_callback
851+
852+ # Callback to decrement open stream count for the H2Connection.
853+ self ._decrement_open_stream_count_callback = increment_open_stream_count_callback
854+
855+ # Callback to clean up state for the H2Connection once we're closed.
856+ self ._close_stream_callback = close_stream_callback
857+
797858 # The expected content length, if any.
798859 self ._expected_content_length = None
799860
@@ -850,6 +911,7 @@ def closed_by(self):
850911 """
851912 return self .state_machine .stream_closed_by
852913
914+ @sync_state_change
853915 def upgrade (self , client_side ):
854916 """
855917 Called by the connection to indicate that this stream is the initial
@@ -868,6 +930,7 @@ def upgrade(self, client_side):
868930 self .state_machine .process_input (input_ )
869931 return
870932
933+ @sync_state_change
871934 def send_headers (self , headers , encoder , end_stream = False ):
872935 """
873936 Returns a list of HEADERS/CONTINUATION frames to emit as either headers
@@ -917,6 +980,7 @@ def send_headers(self, headers, encoder, end_stream=False):
917980
918981 return frames
919982
983+ @sync_state_change
920984 def push_stream_in_band (self , related_stream_id , headers , encoder ):
921985 """
922986 Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed
@@ -941,6 +1005,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder):
9411005
9421006 return frames
9431007
1008+ @sync_state_change
9441009 def locally_pushed (self ):
9451010 """
9461011 Mark this stream as one that was pushed by this peer. Must be called
@@ -954,6 +1019,7 @@ def locally_pushed(self):
9541019 assert not events
9551020 return []
9561021
1022+ @sync_state_change
9571023 def send_data (self , data , end_stream = False , pad_length = None ):
9581024 """
9591025 Prepare some data frames. Optionally end the stream.
@@ -981,6 +1047,7 @@ def send_data(self, data, end_stream=False, pad_length=None):
9811047
9821048 return [df ]
9831049
1050+ @sync_state_change
9841051 def end_stream (self ):
9851052 """
9861053 End a stream without sending data.
@@ -992,6 +1059,7 @@ def end_stream(self):
9921059 df .flags .add ('END_STREAM' )
9931060 return [df ]
9941061
1062+ @sync_state_change
9951063 def advertise_alternative_service (self , field_value ):
9961064 """
9971065 Advertise an RFC 7838 alternative service. The semantics of this are
@@ -1005,6 +1073,7 @@ def advertise_alternative_service(self, field_value):
10051073 asf .field = field_value
10061074 return [asf ]
10071075
1076+ @sync_state_change
10081077 def increase_flow_control_window (self , increment ):
10091078 """
10101079 Increase the size of the flow control window for the remote side.
@@ -1020,6 +1089,7 @@ def increase_flow_control_window(self, increment):
10201089 wuf .window_increment = increment
10211090 return [wuf ]
10221091
1092+ @sync_state_change
10231093 def receive_push_promise_in_band (self ,
10241094 promised_stream_id ,
10251095 headers ,
@@ -1044,6 +1114,7 @@ def receive_push_promise_in_band(self,
10441114 )
10451115 return [], events
10461116
1117+ @sync_state_change
10471118 def remotely_pushed (self , pushed_headers ):
10481119 """
10491120 Mark this stream as one that was pushed by the remote peer. Must be
@@ -1057,6 +1128,7 @@ def remotely_pushed(self, pushed_headers):
10571128 self ._authority = authority_from_headers (pushed_headers )
10581129 return [], events
10591130
1131+ @sync_state_change
10601132 def receive_headers (self , headers , end_stream , header_encoding ):
10611133 """
10621134 Receive a set of headers (or trailers).
@@ -1091,6 +1163,7 @@ def receive_headers(self, headers, end_stream, header_encoding):
10911163 )
10921164 return [], events
10931165
1166+ @sync_state_change
10941167 def receive_data (self , data , end_stream , flow_control_len ):
10951168 """
10961169 Receive some data.
@@ -1114,6 +1187,7 @@ def receive_data(self, data, end_stream, flow_control_len):
11141187 events [0 ].flow_controlled_length = flow_control_len
11151188 return [], events
11161189
1190+ @sync_state_change
11171191 def receive_window_update (self , increment ):
11181192 """
11191193 Handle a WINDOW_UPDATE increment.
@@ -1150,6 +1224,7 @@ def receive_window_update(self, increment):
11501224
11511225 return frames , events
11521226
1227+ @sync_state_change
11531228 def receive_continuation (self ):
11541229 """
11551230 A naked CONTINUATION frame has been received. This is always an error,
@@ -1162,6 +1237,7 @@ def receive_continuation(self):
11621237 )
11631238 assert False , "Should not be reachable"
11641239
1240+ @sync_state_change
11651241 def receive_alt_svc (self , frame ):
11661242 """
11671243 An Alternative Service frame was received on the stream. This frame
@@ -1189,6 +1265,7 @@ def receive_alt_svc(self, frame):
11891265
11901266 return [], events
11911267
1268+ @sync_state_change
11921269 def reset_stream (self , error_code = 0 ):
11931270 """
11941271 Close the stream locally. Reset the stream with an error code.
@@ -1202,6 +1279,7 @@ def reset_stream(self, error_code=0):
12021279 rsf .error_code = error_code
12031280 return [rsf ]
12041281
1282+ @sync_state_change
12051283 def stream_reset (self , frame ):
12061284 """
12071285 Handle a stream being reset remotely.
@@ -1217,6 +1295,7 @@ def stream_reset(self, frame):
12171295
12181296 return [], events
12191297
1298+ @sync_state_change
12201299 def acknowledge_received_data (self , acknowledged_size ):
12211300 """
12221301 The user has informed us that they've processed some amount of data
0 commit comments