Skip to content

Commit 396d506

Browse files
committed
Cleanup imports in graph/rewriting/basic.py
1 parent b9c963e commit 396d506

File tree

1 file changed

+9
-24
lines changed

1 file changed

+9
-24
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
from collections.abc import Callable, Iterable, Sequence
1414
from functools import _compose_mro, partial # type: ignore
1515
from itertools import chain
16-
from typing import TYPE_CHECKING, Literal
16+
from typing import Literal
1717

18-
import pytensor
1918
from pytensor.configdefaults import config
2019
from pytensor.graph import destroyhandler as dh
2120
from pytensor.graph.basic import (
@@ -30,15 +29,12 @@
3029
from pytensor.graph.features import AlreadyThere, Feature
3130
from pytensor.graph.fg import FunctionGraph, Output
3231
from pytensor.graph.op import Op
32+
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
3333
from pytensor.graph.utils import AssocList, InconsistencyError
3434
from pytensor.misc.ordered_set import OrderedSet
3535
from pytensor.utils import flatten
3636

3737

38-
if TYPE_CHECKING:
39-
from pytensor.graph.rewriting.unify import Var
40-
41-
4238
_logger = logging.getLogger("pytensor.graph.rewriting.basic")
4339

4440
RemoveKeyType = Literal["remove"]
@@ -59,14 +55,6 @@
5955
]
6056

6157

62-
class MetaNodeRewriterSkip(AssertionError):
63-
"""This is an `AssertionError`, but instead of having the
64-
`MetaNodeRewriter` print the error, it just skip that
65-
compilation.
66-
67-
"""
68-
69-
7058
class Rewriter(abc.ABC):
7159
"""Abstract base class for graph/term rewriters."""
7260

@@ -1414,8 +1402,6 @@ def __init__(
14141402
frequent `Op`, which will prevent the rewrite from being tried as often.
14151403
14161404
"""
1417-
from pytensor.graph.rewriting.unify import convert_strs_to_vars
1418-
14191405
var_map: dict[str, Var] = {}
14201406
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
14211407
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
@@ -1457,9 +1443,6 @@ def transform(self, fgraph, node, get_nodes=True):
14571443
if ret is not False and ret is not None:
14581444
return dict(zip(real_node.outputs, ret, strict=True))
14591445

1460-
if node.op != self.op:
1461-
return False
1462-
14631446
if len(node.outputs) != 1:
14641447
# PatternNodeRewriter doesn't support replacing multi-output nodes
14651448
return False
@@ -1488,11 +1471,13 @@ def transform(self, fgraph, node, get_nodes=True):
14881471

14891472
[old_out] = node.outputs
14901473
if not old_out.type.is_super(ret.type):
1474+
from pytensor.tensor.type import TensorType
1475+
14911476
# Type doesn't match
14921477
if not (
14931478
self.allow_cast
1494-
and isinstance(old_out.type, pytensor.tensor.TensorType)
1495-
and isinstance(ret.type, pytensor.tensor.TensorType)
1479+
and isinstance(old_out.type, TensorType)
1480+
and isinstance(ret.type, TensorType)
14961481
):
14971482
return False
14981483

@@ -2744,10 +2729,10 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
27442729
otherwise.
27452730
27462731
"""
2747-
if isinstance(f_or_fgraph, pytensor.compile.function.types.Function):
2748-
fgraph = f_or_fgraph.maker.fgraph
2749-
elif isinstance(f_or_fgraph, pytensor.graph.fg.FunctionGraph):
2732+
if isinstance(f_or_fgraph, FunctionGraph):
27502733
fgraph = f_or_fgraph
2734+
elif hasattr(f_or_fgraph, "fgraph"):
2735+
fgraph = f_or_fgraph.fgraph
27512736
else:
27522737
raise ValueError("The type of f_or_fgraph is not supported")
27532738

0 commit comments

Comments
 (0)