@@ -3162,44 +3162,19 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
31623162 return np .allclose (x , ref , rtol = rtol , atol = atol )
31633163
31643164
3165- def _skip_mul_1 (r ):
3166- if r .owner and r .owner .op == mul :
3167- not_is_1 = [i for i in r .owner .inputs if not _is_1 (i )]
3168- if len (not_is_1 ) == 1 :
3169- return not_is_1 [0 ]
3170-
3171-
3172- def _is_1 (expr ):
3173- """
3174-
3175- Returns
3176- -------
3177- bool
3178- True iff expr is a constant close to 1.
3179-
3180- """
3181- try :
3182- v = get_underlying_scalar_constant_value (expr )
3183- return isclose (v , 1 )
3184- except NotScalarConstantError :
3185- return False
3186-
3187-
31883165logsigm_to_softplus = PatternNodeRewriter (
31893166 (log , (sigmoid , "x" )),
31903167 (neg , (softplus , (neg , "x" ))),
31913168 allow_multiple_clients = True ,
31923169 values_eq_approx = values_eq_approx_remove_inf ,
3193- skip_identities_fn = _skip_mul_1 ,
31943170 tracks = [sigmoid ],
31953171 get_nodes = get_clients_at_depth1 ,
31963172)
31973173log1msigm_to_softplus = PatternNodeRewriter (
3198- (log , (sub , dict ( pattern = "y" , constraint = _is_1 ) , (sigmoid , "x" ))),
3174+ (log , (sub , 1 , (sigmoid , "x" ))),
31993175 (neg , (softplus , "x" )),
32003176 allow_multiple_clients = True ,
32013177 values_eq_approx = values_eq_approx_remove_inf ,
3202- skip_identities_fn = _skip_mul_1 ,
32033178 tracks = [sigmoid ],
32043179 get_nodes = get_clients_at_depth2 ,
32053180)
@@ -3396,10 +3371,8 @@ def local_exp_over_1_plus_exp(fgraph, node):
33963371
33973372 if len (denom_rest ) == 0 :
33983373 return [new_num ]
3399- elif len (denom_rest ) == 1 :
3400- out = new_num / denom_rest [0 ]
34013374 else :
3402- out = new_num / mul (* denom_rest )
3375+ out = new_num / variadic_mul (* denom_rest )
34033376
34043377 copy_stack_trace (node .outputs [0 ], out )
34053378 return [out ]
@@ -3769,7 +3742,7 @@ def local_reciprocal_1_plus_exp(fgraph, node):
37693742
37703743# 1 - sigmoid(x) -> sigmoid(-x)
37713744local_1msigmoid = PatternNodeRewriter (
3772- (sub , dict ( pattern = "y" , constraint = _is_1 ) , (sigmoid , "x" )),
3745+ (sub , 1 , (sigmoid , "x" )),
37733746 (sigmoid , (neg , "x" )),
37743747 tracks = [sigmoid ],
37753748 get_nodes = get_clients_at_depth1 ,
0 commit comments