@@ -344,84 +344,67 @@ def connection_pattern(self, node):
 
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
+            core_inputs = [
                 tensor(
                     dtype=inp.type.dtype,
-                    shape=inp.type.shape[inp.type.ndim - len(sig) :],
+                    shape=inp.type.shape[batch_ndim:],
                 )
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
-            core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
+            core_node = self._create_dummy_core_node(core_inputs)
             core_outputs = core_node.outputs
 
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
                 )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
+            ]
 
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize gradients to batch inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
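
A note on the final summation step in the new `L_op`: when an input's batch dimension has size 1 but the corresponding output batch dimension does not, the forward pass broadcasts that input, and the adjoint of broadcasting is a sum over the broadcasted axis. A minimal NumPy sketch of that rule (the shapes and arrays here are invented for illustration, not taken from the commit):

import numpy as np

# Forward: an input whose batch dimension is broadcast from 1 to 5.
x = np.random.rand(1, 3)        # input batch shape (1, 3)
y = np.broadcast_to(x, (5, 3))  # output batch shape (5, 3)

# Backward: each element of x feeds 5 output elements, so its gradient
# is the sum of the matching output gradients over the broadcasted axis,
# with keepdims=True so the result still matches x's shape.
g_out = np.random.rand(5, 3)
g_x = g_out.sum(axis=0, keepdims=True)
assert g_x.shape == x.shape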
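The vectorization step relies on `vectorize_graph` to rewrite the gradient graph, built on dummy core (unbatched) variables, onto the real batched inputs. A minimal sketch of that mechanism, with an invented one-variable graph standing in for the core gradients computed above:

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Build a gradient-like graph on a core (unbatched) variable ...
core_x = pt.vector("core_x")    # core input, shape (n,)
core_grad = 2 * core_x          # stand-in for a core gradient expression

# ... then rewrite it onto a batched replacement, as L_op does with its
# replace=dict(zip(core variables, batched variables)) mapping.
batch_x = pt.matrix("batch_x")  # batched input, shape (batch, n)
batch_grad = vectorize_graph(core_grad, replace={core_x: batch_x})

# The rewritten graph computes 2 * batch_x, with the batch dimension
# threaded through.
assert batch_grad.ndim == 2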