Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def read(cls, data: Buffer, data_file: str) -> CacheMeta | None:
# Misc classes.
EXTRA_ATTRS: Final[Tag] = 150
DT_SPEC: Final[Tag] = 151
PLUGIN_FLAGS: Final[Tag] = 152

END_TAG: Final[Tag] = 255

Expand Down
10 changes: 10 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
OverloadedFuncDef,
OverloadPart,
PassStmt,
PluginFlags,
PromoteExpr,
RaiseStmt,
RefExpr,
Expand Down Expand Up @@ -2205,6 +2206,11 @@ def check_method_override(
defn.name != "__replace__"
or defn.info.metadata.get("dataclass_tag") is None
)
and not (
defn.info
and (node := defn.info.get(defn.name))
and PluginFlags.should_skip_override_checks(node)
)
)
found_method_base_classes: list[TypeInfo] = []
for base in defn.info.mro[1:]:
Expand Down Expand Up @@ -3547,6 +3553,10 @@ def check_compatibility_all_supers(self, lvalue: RefExpr, rvalue: Expression) ->
and lvalue.kind in (MDEF, None) # None for Vars defined via self
and len(lvalue_node.info.bases) > 0
):
if not (
sym := lvalue_node.info.names.get(lvalue_node.name)
) or PluginFlags.should_skip_override_checks(sym):
return
for base in lvalue_node.info.mro[1:]:
tnode = base.names.get(lvalue_node.name)
if tnode is not None:
Expand Down
67 changes: 67 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
LIST_STR,
LITERAL_COMPLEX,
LITERAL_NONE,
PLUGIN_FLAGS,
Buffer,
Tag,
read_bool,
Expand Down Expand Up @@ -4469,6 +4470,57 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_placeholder_node(self)


class PluginFlags:
"""Checking customization for plugin-generated nodes.

This class is part of the public API. It can be used with the
`mypy.plugins.common.add_*_to_class` family of functions.

Args:
skip_override_checks: Allow this node to be an incompatible override.
A node having this flag set to True will not be required to be
LSP-compatible with the superclasses of its enclosing class.
This is helpful when the plugin generates a precise signature,
overriding a fallback signature defined in the base class.
This flag does not affect checking overrides *of* this node in
further subclasses.
"""

def __init__(self, *, skip_override_checks: bool = False) -> None:
self.skip_override_checks = skip_override_checks

@staticmethod
def should_skip_override_checks(node: SymbolTableNode) -> bool:
if node.plugin_flags is None:
return False
return node.plugin_flags.skip_override_checks

def serialize(self) -> JsonDict:
data: JsonDict = {".class": "PluginFlags"}
if self.skip_override_checks:
data["skip_override_checks"] = True
return data

@classmethod
def deserialize(cls, data: JsonDict) -> PluginFlags:
flags = PluginFlags()
if data.get("skip_override_checks"):
flags.skip_override_checks = True
return flags

def write(self, data: Buffer) -> None:
write_tag(data, PLUGIN_FLAGS)
write_bool(data, self.skip_override_checks)
write_tag(data, END_TAG)

@classmethod
def read(cls, data: Buffer) -> PluginFlags:
flags = PluginFlags()
flags.skip_override_checks = read_bool(data)
assert read_tag(data) == END_TAG
return flags


class SymbolTableNode:
"""Description of a name binding in a symbol table.

Expand Down Expand Up @@ -4537,6 +4589,7 @@ class SymbolTableNode:
"cross_ref",
"implicit",
"plugin_generated",
"plugin_flags",
"no_serialize",
)

Expand All @@ -4549,6 +4602,7 @@ def __init__(
module_hidden: bool = False,
*,
plugin_generated: bool = False,
plugin_flags: PluginFlags | None = None,
no_serialize: bool = False,
) -> None:
self.kind = kind
Expand All @@ -4558,6 +4612,7 @@ def __init__(
self.module_hidden = module_hidden
self.cross_ref: str | None = None
self.plugin_generated = plugin_generated
self.plugin_flags = plugin_flags
self.no_serialize = no_serialize

@property
Expand Down Expand Up @@ -4611,6 +4666,8 @@ def serialize(self, prefix: str, name: str) -> JsonDict:
data["implicit"] = True
if self.plugin_generated:
data["plugin_generated"] = True
if self.plugin_flags:
data["plugin_flags"] = self.plugin_flags.serialize()
if isinstance(self.node, MypyFile):
data["cross_ref"] = self.node.fullname
else:
Expand Down Expand Up @@ -4650,6 +4707,8 @@ def deserialize(cls, data: JsonDict) -> SymbolTableNode:
stnode.implicit = data["implicit"]
if "plugin_generated" in data:
stnode.plugin_generated = data["plugin_generated"]
if "plugin_flags" in data:
stnode.plugin_flags = PluginFlags.deserialize(data["plugin_flags"])
return stnode

def write(self, data: Buffer, prefix: str, name: str) -> None:
Expand Down Expand Up @@ -4681,6 +4740,10 @@ def write(self, data: Buffer, prefix: str, name: str) -> None:
if cross_ref is None:
assert self.node is not None
self.node.write(data)
if self.plugin_flags is None:
write_literal(data, None)
else:
self.plugin_flags.write(data)
write_tag(data, END_TAG)

@classmethod
Expand All @@ -4696,6 +4759,10 @@ def read(cls, data: Buffer) -> SymbolTableNode:
sym.node = read_symbol(data)
else:
sym.cross_ref = cross_ref
if (tag := read_tag(data)) == PLUGIN_FLAGS:
sym.plugin_flags = PluginFlags.read(data)
else:
assert tag == LITERAL_NONE
assert read_tag(data) == END_TAG
return sym

Expand Down
23 changes: 16 additions & 7 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Node,
OverloadedFuncDef,
PassStmt,
PluginFlags,
RefExpr,
SymbolTableNode,
TypeInfo,
Expand Down Expand Up @@ -220,6 +221,7 @@ class MethodSpec(NamedTuple):
return_type: Type
self_type: Type | None = None
tvar_defs: list[TypeVarType] | None = None
flags: PluginFlags | None = None


def add_method_to_class(
Expand All @@ -233,6 +235,7 @@ def add_method_to_class(
tvar_def: list[TypeVarType] | TypeVarType | None = None,
is_classmethod: bool = False,
is_staticmethod: bool = False,
flags: PluginFlags | None = None,
) -> FuncDef | Decorator:
"""Adds a new method to a class definition."""
_prepare_class_namespace(cls, name)
Expand All @@ -244,7 +247,13 @@ def add_method_to_class(
api,
cls.info,
name,
MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_def),
MethodSpec(
args=args,
return_type=return_type,
self_type=self_type,
tvar_defs=tvar_def,
flags=flags,
),
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
)
Expand All @@ -260,6 +269,7 @@ def add_overloaded_method_to_class(
items: list[MethodSpec],
is_classmethod: bool = False,
is_staticmethod: bool = False,
flags: PluginFlags | None = None,
) -> OverloadedFuncDef:
"""Adds a new overloaded method to a class definition."""
assert len(items) >= 2, "Overloads must contain at least two cases"
Expand Down Expand Up @@ -294,8 +304,7 @@ def add_overloaded_method_to_class(
overload_def.info = cls.info
overload_def.is_class = is_classmethod
overload_def.is_static = is_staticmethod
sym = SymbolTableNode(MDEF, overload_def)
sym.plugin_generated = True
sym = SymbolTableNode(MDEF, overload_def, plugin_generated=True, plugin_flags=flags)

cls.info.names[name] = sym
cls.info.defn.defs.body.append(overload_def)
Expand Down Expand Up @@ -330,7 +339,7 @@ def _add_method_by_spec(
is_classmethod: bool,
is_staticmethod: bool,
) -> tuple[FuncDef | Decorator, SymbolTableNode]:
args, return_type, self_type, tvar_defs = spec
args, return_type, self_type, tvar_defs, flags = spec

assert not (
is_classmethod is True and is_staticmethod is True
Expand Down Expand Up @@ -383,8 +392,7 @@ def _add_method_by_spec(
sym.plugin_generated = True
return dec, sym

sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
sym = SymbolTableNode(MDEF, func, plugin_generated=True, plugin_flags=flags)
return func, sym


Expand All @@ -399,6 +407,7 @@ def add_attribute_to_class(
fullname: str | None = None,
is_classvar: bool = False,
overwrite_existing: bool = False,
flags: PluginFlags | None = None,
) -> Var:
"""
Adds a new attribute to a class definition.
Expand Down Expand Up @@ -428,7 +437,7 @@ def add_attribute_to_class(
node._fullname = info.fullname + "." + name

info.names[name] = SymbolTableNode(
MDEF, node, plugin_generated=True, no_serialize=no_serialize
MDEF, node, plugin_generated=True, no_serialize=no_serialize, plugin_flags=flags
)
return node

Expand Down
53 changes: 53 additions & 0 deletions test-data/unit/check-custom-plugin.test
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,59 @@ plugins=<ROOT>/test-data/unit/plugins/add_method.py
enable_error_code = explicit-override
[typing fixtures/typing-override.pyi]

[case testAddMethodPluginExplicitOverrideIgnoreCompat]
# flags: --python-version 3.12 --config-file tmp/mypy.ini --debug-serialize
from typing import TypeVar

T = TypeVar('T', bound=type)
def inject_foo(t: T) -> T:
return t

class BaseWithoutFoo:
pass

@inject_foo
class Child1(BaseWithoutFoo): ...

class BaseWithSameFoo:
attr: None
def meth_ok(self) -> None: ...
def meth_bad(self) -> None: ...

@inject_foo
class Child2(BaseWithSameFoo): ...

class BaseWithOtherFoo:
attr: int
def meth_ok(self) -> int: ...
def meth_bad(self) -> int: ...

# `attr` is not reported because add_attribute_to_class does not generate a statement (yet).
@inject_foo
class Child3(BaseWithOtherFoo): ... # E: Return type "None" of "meth_bad" incompatible with return type "int" in supertype "BaseWithOtherFoo"

@inject_foo
class ImmediatelyOverridden:
attr: int
def meth_ok(self) -> int: ...
def meth_bad(self) -> int: ...

@inject_foo
class Original:
...
class FurtherOverridden(Original):
attr: int # E: Incompatible types in assignment (expression has type "int", base class "Original" defined the type as "None")
def meth_ok(self) -> int: ... # E: Return type "int" of "meth_ok" incompatible with return type "None" in supertype "Original" \
# E: Method "meth_ok" is not using @override but is overriding a method in class "__main__.Original"
def meth_bad(self) -> int: ... # E: Return type "int" of "meth_bad" incompatible with return type "None" in supertype "Original" \
# E: Method "meth_bad" is not using @override but is overriding a method in class "__main__.Original"

[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_method_ignore_compat.py
enable_error_code = explicit-override
[typing fixtures/typing-override.pyi]

[case testCustomErrorCodePlugin]
# flags: --config-file tmp/mypy.ini --show-error-codes
def main() -> int:
Expand Down
46 changes: 46 additions & 0 deletions test-data/unit/plugins/add_method_ignore_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from typing import Callable

from mypy.nodes import PluginFlags
from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import add_attribute_to_class, add_method_to_class
from mypy.types import NoneType


class AddOverrideMethodPlugin(Plugin):
def get_class_decorator_hook_2(self, fullname: str) -> Callable[[ClassDefContext], bool] | None:
if fullname == "__main__.inject_foo":
return add_extra_methods_hook
return None


def add_extra_methods_hook(ctx: ClassDefContext) -> bool:
add_method_to_class(
ctx.api,
ctx.cls,
"meth_ok",
[],
NoneType(),
flags=PluginFlags(skip_override_checks=True)
)
add_method_to_class(
ctx.api,
ctx.cls,
"meth_bad",
[],
NoneType(),
flags=PluginFlags(skip_override_checks=False)
)
add_attribute_to_class(
ctx.api,
ctx.cls,
"attr",
NoneType(),
flags=PluginFlags(skip_override_checks=True)
)
return True


def plugin(version: str) -> type[AddOverrideMethodPlugin]:
return AddOverrideMethodPlugin