Skip to content

Commit fb8fddc

Browse files
authored
Try to instantiate TypeVars inside pt when possible (#24231)
This PR improves type inference for functions with unknown parameter types when the expected type contains type variables within union types. Previously, the compiler would only attempt to instantiate top-level `TypeVar`s, but now it recursively searches through union types (`OrType`) and flexible types to find and instantiate a nested type variable.
2 parents 714f3b6 + d71d682 commit fb8fddc

File tree

2 files changed

+80
-9
lines changed

2 files changed

+80
-9
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,15 +1929,44 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19291929
NoType
19301930
}
19311931

1932-
pt.stripNull() match {
1933-
case pt: TypeVar
1934-
if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists =>
1935-
// try to instantiate `pt` if this is possible. If it does not
1936-
// work the error will be reported later in `inferredParam`,
1937-
// when we try to infer the parameter type.
1938-
isFullyDefined(pt, ForceDegree.flipBottom)
1939-
case _ =>
1940-
}
1932+
/** Try to instantiate one type variable bounded by function types that appear
1933+
* deeply inside `tp`, including union or intersection types.
1934+
*/
1935+
def tryToInstantiateDeeply(tp: Type): Boolean = tp.dealias match
1936+
case tp: AndOrType =>
1937+
tryToInstantiateDeeply(tp.tp1)
1938+
|| tryToInstantiateDeeply(tp.tp2)
1939+
case tp: FlexibleType =>
1940+
tryToInstantiateDeeply(tp.hi)
1941+
case tp: TypeVar if isConstrainedByFunctionType(tp) =>
1942+
// Only instantiate if the type variable is constrained by function types
1943+
isFullyDefined(tp, ForceDegree.flipBottom)
1944+
case _ => false
1945+
1946+
def isConstrainedByFunctionType(tvar: TypeVar): Boolean =
1947+
val origin = tvar.origin
1948+
val bounds = ctx.typerState.constraint.bounds(origin)
1949+
// The search is done by the best-effort, and we don't look into TypeVars recursively.
1950+
def containsFunctionType(tp: Type): Boolean = tp.dealias match
1951+
case tp if defn.isFunctionType(tp) => true
1952+
case SAMType(_, _) => true
1953+
case tp: AndOrType =>
1954+
containsFunctionType(tp.tp1) || containsFunctionType(tp.tp2)
1955+
case tp: FlexibleType =>
1956+
containsFunctionType(tp.hi)
1957+
case _ => false
1958+
containsFunctionType(bounds.lo) || containsFunctionType(bounds.hi)
1959+
1960+
if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists then
1961+
// Try to instantiate `pt` when possible.
1962+
// * If `pt` is a type variable, we try to instantiate it directly.
1963+
// * If `pt` is a more complex type, we try to instantiate it deeply by searching
1964+
// a nested type variable bounded by a function type to help infer parameter types.
1965+
// If it does not work the error will be reported later in `inferredParam`,
1966+
// when we try to infer the parameter type.
1967+
pt match
1968+
case pt: TypeVar => isFullyDefined(pt, ForceDegree.flipBottom)
1969+
case _ => tryToInstantiateDeeply(pt)
19411970

19421971
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree.srcPos)
19431972

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
def f[T](x: T): T = ???
3+
def f2[T](x: T | T): T = ???
4+
def f3[T](x: T | Null): T = ???
5+
def f4[T](x: Int | T): T = ???
6+
7+
trait MyOption[+T]
8+
9+
object MyOption:
10+
def apply[T](x: T | Null): MyOption[T] = ???
11+
12+
def test =
13+
val g: AnyRef => Boolean = f {
14+
x => x eq null // ok
15+
}
16+
val g2: AnyRef => Boolean = f2 {
17+
x => x eq null // ok
18+
}
19+
val g3: AnyRef => Boolean = f3 {
20+
x => x eq null // was error
21+
}
22+
val g4: AnyRef => Boolean = f4 {
23+
x => x eq null // was error
24+
}
25+
26+
val o1: MyOption[String] = MyOption(null)
27+
val o2: MyOption[String => Boolean] = MyOption {
28+
x => x.length > 0
29+
}
30+
val o3: MyOption[(String, String) => Boolean] = MyOption {
31+
(x, y) => x.length > y.length
32+
}
33+
34+
35+
class Box[T]
36+
val box: Box[Unit] = ???
37+
def ff1[T, U](x: T | U, y: Box[U]): T = ???
38+
def ff2[T, U](x: T & U): T = ???
39+
40+
def test2 =
41+
val a1: Any => Any = ff1(x => x, box)
42+
val a2: Any => Any = ff2(x => x)

0 commit comments

Comments
 (0)