|
13 | 13 | from collections.abc import Callable, Iterable, Sequence |
14 | 14 | from functools import _compose_mro, partial # type: ignore |
15 | 15 | from itertools import chain |
16 | | -from typing import TYPE_CHECKING, Literal |
| 16 | +from typing import Literal |
17 | 17 |
|
18 | | -import pytensor |
19 | 18 | from pytensor.configdefaults import config |
20 | 19 | from pytensor.graph import destroyhandler as dh |
21 | 20 | from pytensor.graph.basic import ( |
|
30 | 29 | from pytensor.graph.features import AlreadyThere, Feature |
31 | 30 | from pytensor.graph.fg import FunctionGraph, Output |
32 | 31 | from pytensor.graph.op import Op |
| 32 | +from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars |
33 | 33 | from pytensor.graph.utils import AssocList, InconsistencyError |
34 | 34 | from pytensor.misc.ordered_set import OrderedSet |
35 | 35 | from pytensor.utils import flatten |
36 | 36 |
|
37 | 37 |
|
38 | | -if TYPE_CHECKING: |
39 | | - from pytensor.graph.rewriting.unify import Var |
40 | | - |
41 | | - |
42 | 38 | _logger = logging.getLogger("pytensor.graph.rewriting.basic") |
43 | 39 |
|
44 | 40 | RemoveKeyType = Literal["remove"] |
|
59 | 55 | ] |
60 | 56 |
|
61 | 57 |
|
62 | | -class MetaNodeRewriterSkip(AssertionError): |
63 | | - """This is an `AssertionError`, but instead of having the |
64 | | - `MetaNodeRewriter` print the error, it just skip that |
65 | | - compilation. |
66 | | -
|
67 | | - """ |
68 | | - |
69 | | - |
70 | 58 | class Rewriter(abc.ABC): |
71 | 59 | """Abstract base class for graph/term rewriters.""" |
72 | 60 |
|
@@ -1414,8 +1402,6 @@ def __init__( |
1414 | 1402 | frequent `Op`, which will prevent the rewrite from being tried as often. |
1415 | 1403 |
|
1416 | 1404 | """ |
1417 | | - from pytensor.graph.rewriting.unify import convert_strs_to_vars |
1418 | | - |
1419 | 1405 | var_map: dict[str, Var] = {} |
1420 | 1406 | self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map) |
1421 | 1407 | self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) |
@@ -1457,9 +1443,6 @@ def transform(self, fgraph, node, get_nodes=True): |
1457 | 1443 | if ret is not False and ret is not None: |
1458 | 1444 | return dict(zip(real_node.outputs, ret, strict=True)) |
1459 | 1445 |
|
1460 | | - if node.op != self.op: |
1461 | | - return False |
1462 | | - |
1463 | 1446 | if len(node.outputs) != 1: |
1464 | 1447 | # PatternNodeRewriter doesn't support replacing multi-output nodes |
1465 | 1448 | return False |
@@ -1488,11 +1471,13 @@ def transform(self, fgraph, node, get_nodes=True): |
1488 | 1471 |
|
1489 | 1472 | [old_out] = node.outputs |
1490 | 1473 | if not old_out.type.is_super(ret.type): |
| 1474 | + from pytensor.tensor.type import TensorType |
| 1475 | + |
1491 | 1476 | # Type doesn't match |
1492 | 1477 | if not ( |
1493 | 1478 | self.allow_cast |
1494 | | - and isinstance(old_out.type, pytensor.tensor.TensorType) |
1495 | | - and isinstance(ret.type, pytensor.tensor.TensorType) |
| 1479 | + and isinstance(old_out.type, TensorType) |
| 1480 | + and isinstance(ret.type, TensorType) |
1496 | 1481 | ): |
1497 | 1482 | return False |
1498 | 1483 |
|
@@ -2744,10 +2729,10 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): |
2744 | 2729 | otherwise. |
2745 | 2730 |
|
2746 | 2731 | """ |
2747 | | - if isinstance(f_or_fgraph, pytensor.compile.function.types.Function): |
2748 | | - fgraph = f_or_fgraph.maker.fgraph |
2749 | | - elif isinstance(f_or_fgraph, pytensor.graph.fg.FunctionGraph): |
| 2732 | + if isinstance(f_or_fgraph, FunctionGraph): |
2750 | 2733 | fgraph = f_or_fgraph |
| 2734 | + elif hasattr(f_or_fgraph, "fgraph"): |
| 2735 | + fgraph = f_or_fgraph.fgraph |
2751 | 2736 | else: |
2752 | 2737 | raise ValueError("The type of f_or_fgraph is not supported") |
2753 | 2738 |
|
|
0 commit comments