Skip to content

Commit cdf6c86

Browse files
committed
wip
1 parent fc9ed22 commit cdf6c86

File tree

4 files changed

+135
-79
lines changed

4 files changed

+135
-79
lines changed

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2137,7 +2137,7 @@ private module Debug {
21372137
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
21382138
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
21392139
filepath.matches("%/main.rs") and
2140-
startline = 52
2140+
startline = 2909
21412141
)
21422142
}
21432143

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 120 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,8 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
748748
/**
749749
* A matching configuration for resolving types of struct expressions
750750
* like `Foo { bar = baz }`.
751+
*
752+
* This also includes nullary struct expressions like `None`.
751753
*/
752754
private module StructExprMatchingInput implements MatchingInputSig {
753755
private newtype TPos =
@@ -830,26 +832,86 @@ private module StructExprMatchingInput implements MatchingInputSig {
830832

831833
class AccessPosition = DeclarationPosition;
832834

833-
class Access extends StructExpr {
835+
abstract class Access extends AstNode {
836+
pragma[nomagic]
837+
abstract AstNode getNodeAt(AccessPosition apos);
838+
839+
pragma[nomagic]
840+
Type getInferredType(AccessPosition apos, TypePath path) {
841+
result = inferType(this.getNodeAt(apos), path)
842+
}
843+
844+
pragma[nomagic]
845+
abstract Path getStructPath();
846+
847+
pragma[nomagic]
848+
Declaration getTarget() { result = resolvePath(this.getStructPath()) }
849+
850+
pragma[nomagic]
834851
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
852+
exists(TypeMention tm, TypePath path0 |
853+
tm = this.getStructPath() and
854+
result = tm.resolveTypeAt(path0) and
855+
path0.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path)
856+
)
857+
}
858+
859+
/**
860+
* Holds if the return type of this call at `path` may have to be inferred
861+
* from the context.
862+
*/
863+
pragma[nomagic]
864+
predicate isContextTypedAt(DeclarationPosition pos, TypePath path) {
865+
// Struct declarations, such as `Foo::Bar{field = ...}`, may also be context typed
866+
exists(Declaration td, TypeParameter tp |
867+
td = this.getTarget() and
868+
pos.isStructPos() and
869+
tp = td.getDeclaredType(pos, path) and
870+
not exists(DeclarationPosition paramDpos |
871+
not paramDpos.isStructPos() and
872+
tp = td.getDeclaredType(paramDpos, _)
873+
) and
874+
// check that no explicit type arguments have been supplied for `tp`
875+
not exists(TypeArgumentPosition tapos |
876+
exists(this.getTypeArgument(tapos, _)) and
877+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
878+
)
879+
)
880+
}
881+
}
882+
883+
private class StructExprAccess extends Access, StructExpr {
884+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
885+
result = super.getTypeArgument(apos, path)
886+
or
835887
exists(TypePath suffix |
836888
suffix.isCons(TTypeParamTypeParameter(apos.asTypeParam()), path) and
837889
result = CertainTypeInference::inferCertainType(this, suffix)
838890
)
839891
}
840892

841-
AstNode getNodeAt(AccessPosition apos) {
893+
override AstNode getNodeAt(AccessPosition apos) {
842894
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
843895
or
844896
result = this and
845897
apos.isStructPos()
846898
}
847899

848-
Type getInferredType(AccessPosition apos, TypePath path) {
849-
result = inferType(this.getNodeAt(apos), path)
900+
override Path getStructPath() { result = this.getPath() }
901+
}
902+
903+
/**
904+
* A potential nullary struct/variant construction such as `None`.
905+
*/
906+
private class PathExprAccess extends Access, PathExpr {
907+
PathExprAccess() { not exists(CallExpr ce | this = ce.getFunction()) }
908+
909+
override AstNode getNodeAt(AccessPosition apos) {
910+
result = this and
911+
apos.isStructPos()
850912
}
851913

852-
Declaration getTarget() { result = resolvePath(this.getPath()) }
914+
override Path getStructPath() { result = this.getPath() }
853915
}
854916

855917
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
@@ -859,36 +921,32 @@ private module StructExprMatchingInput implements MatchingInputSig {
859921

860922
private module StructExprMatching = Matching<StructExprMatchingInput>;
861923

862-
/**
863-
* Gets the type of `n` at `path`, where `n` is either a struct expression or
864-
* a field expression of a struct expression.
865-
*/
866924
pragma[nomagic]
867-
private Type inferStructExprType(AstNode n, TypePath path) {
925+
private Type inferStructExprType0(AstNode n, boolean isReturn, TypePath path) {
868926
exists(StructExprMatchingInput::Access a, StructExprMatchingInput::AccessPosition apos |
869927
n = a.getNodeAt(apos) and
928+
if apos.isStructPos() then isReturn = true else isReturn = false
929+
|
870930
result = StructExprMatching::inferAccessType(a, apos, path)
931+
or
932+
a.isContextTypedAt(apos, path) and
933+
result = TContextType()
871934
)
872935
}
873936

937+
/**
938+
* Gets the type of `n` at `path`, where `n` is either a struct expression or
939+
* a field expression of a struct expression.
940+
*/
941+
private predicate inferStructExprType =
942+
ContextTyping::CheckContextTyping<inferStructExprType0/3>::check/2;
943+
874944
pragma[nomagic]
875945
private Type inferTupleRootType(AstNode n) {
876946
// `typeEquality` handles the non-root cases
877947
result = TTuple([n.(TupleExpr).getNumberOfFields(), n.(TuplePat).getTupleArity()])
878948
}
879949

880-
pragma[nomagic]
881-
private Type inferPathExprType(PathExpr pe, TypePath path) {
882-
// nullary struct/variant constructors
883-
not exists(CallExpr ce | pe = ce.getFunction()) and
884-
path.isEmpty() and
885-
exists(ItemNode i | i = resolvePath(pe.getPath()) |
886-
result = TEnum(i.(Variant).getEnum())
887-
or
888-
result = TStruct(i)
889-
)
890-
}
891-
892950
pragma[nomagic]
893951
private Path getCallExprPathQualifier(CallExpr ce) {
894952
result = CallExprImpl::getFunctionPath(ce).getQualifier()
@@ -982,7 +1040,7 @@ private module ContextTyping {
9821040
pragma[nomagic]
9831041
private predicate isContextTyped(AstNode n) { isContextTyped(n, _) }
9841042

985-
signature Type inferCallTypeSig(AstNode n, FunctionPosition pos, TypePath path);
1043+
signature Type inferCallTypeSig(AstNode n, boolean isReturn, TypePath path);
9861044

9871045
/**
9881046
* Given a predicate `inferCallType` for inferring the type of a call at a given
@@ -992,30 +1050,24 @@ private module ContextTyping {
9921050
*/
9931051
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
9941052
pragma[nomagic]
995-
private Type inferCallTypeFromContextCand(
996-
AstNode n, FunctionPosition pos, TypePath path, TypePath prefix
997-
) {
998-
result = inferCallType(n, pos, path) and
999-
not pos.isReturn() and
1053+
private Type inferCallTypeFromContextCand(AstNode n, TypePath path, TypePath prefix) {
1054+
result = inferCallType(n, false, path) and
10001055
isContextTyped(n) and
10011056
prefix = path
10021057
or
10031058
exists(TypePath mid |
1004-
result = inferCallTypeFromContextCand(n, pos, path, mid) and
1059+
result = inferCallTypeFromContextCand(n, path, mid) and
10051060
mid.isSnoc(prefix, _)
10061061
)
10071062
}
10081063

10091064
pragma[nomagic]
10101065
Type check(AstNode n, TypePath path) {
1011-
exists(FunctionPosition pos |
1012-
result = inferCallType(n, pos, path) and
1013-
pos.isReturn()
1014-
or
1015-
exists(TypePath prefix |
1016-
result = inferCallTypeFromContextCand(n, pos, path, prefix) and
1017-
isContextTyped(n, prefix)
1018-
)
1066+
result = inferCallType(n, true, path)
1067+
or
1068+
exists(TypePath prefix |
1069+
result = inferCallTypeFromContextCand(n, path, prefix) and
1070+
isContextTyped(n, prefix)
10191071
)
10201072
}
10211073
}
@@ -2131,11 +2183,13 @@ private Type inferMethodCallType0(
21312183
}
21322184

21332185
pragma[nomagic]
2134-
private Type inferMethodCallType1(
2135-
AstNode n, MethodCallMatchingInput::AccessPosition apos, TypePath path
2136-
) {
2137-
exists(MethodCallMatchingInput::Access a, string derefChainBorrow, TypePath path0 |
2138-
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
2186+
private Type inferMethodCallType1(AstNode n, boolean isReturn, TypePath path) {
2187+
exists(
2188+
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
2189+
string derefChainBorrow, TypePath path0
2190+
|
2191+
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0) and
2192+
if apos.isReturn() then isReturn = true else isReturn = false
21392193
|
21402194
(
21412195
not apos.isSelf()
@@ -2460,6 +2514,9 @@ private module NonMethodResolution {
24602514
/**
24612515
* A matching configuration for resolving types of calls like
24622516
* `foo::bar(baz)` where the target is not a method.
2517+
*
2518+
* This also includes "calls" to tuple variants and tuple structs such
2519+
* as `Result::Ok(42)`.
24632520
*/
24642521
private module NonMethodCallMatchingInput implements MatchingInputSig {
24652522
import FunctionPositionMatchingInput
@@ -2581,6 +2638,12 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
25812638
pragma[nomagic]
25822639
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
25832640
result = getCallExprTypeArgument(this, apos).resolveTypeAt(path)
2641+
// todo: enum variants like `Self(...)`
2642+
}
2643+
2644+
pragma[nomagic]
2645+
AstNode getNodeAt(FunctionPosition pos) {
2646+
result = NonMethodResolution::NonMethodCall.super.getNodeAt(pos)
25842647
}
25852648

25862649
pragma[nomagic]
@@ -2591,26 +2654,23 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
25912654
result = inferType(this.getNodeAt(apos), path)
25922655
}
25932656

2657+
pragma[nomagic]
25942658
Declaration getTarget() {
2595-
result = this.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
2659+
result = super.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
25962660
}
25972661

2598-
/**
2599-
* Holds if the return type of this call at `path` may have to be inferred
2600-
* from the context.
2601-
*/
26022662
pragma[nomagic]
26032663
predicate isContextTypedAt(FunctionPosition pos, TypePath path) {
26042664
exists(ImplOrTraitItemNode i |
26052665
this.isContextTypedAt(i,
26062666
[
2607-
this.resolveCallTargetViaPathResolution().(NonMethodFunction),
2608-
this.resolveCallTargetViaTypeInference(i),
2609-
this.resolveTraitFunctionViaPathResolution(i)
2667+
super.resolveCallTargetViaPathResolution().(NonMethodFunction),
2668+
super.resolveCallTargetViaTypeInference(i),
2669+
super.resolveTraitFunctionViaPathResolution(i)
26102670
], pos, path)
26112671
)
26122672
or
2613-
// Tuple declarations, such as `None`, may also be context typed
2673+
// Tuple declarations, such as `Result::Ok(...)`, may also be context typed
26142674
exists(TupleDeclaration td, TypeParameter tp |
26152675
td = this.resolveCallTargetViaPathResolution() and
26162676
pos.isReturn() and
@@ -2629,10 +2689,11 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
26292689
private module NonMethodCallMatching = Matching<NonMethodCallMatchingInput>;
26302690

26312691
pragma[nomagic]
2632-
private Type inferNonMethodCallType0(
2633-
AstNode n, NonMethodCallMatchingInput::AccessPosition apos, TypePath path
2634-
) {
2635-
exists(NonMethodCallMatchingInput::Access a | n = a.getNodeAt(apos) |
2692+
private Type inferNonMethodCallType0(AstNode n, boolean isReturn, TypePath path) {
2693+
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
2694+
n = a.getNodeAt(apos) and
2695+
if apos.isReturn() then isReturn = true else isReturn = false
2696+
|
26362697
result = NonMethodCallMatching::inferAccessType(a, apos, path)
26372698
or
26382699
a.isContextTypedAt(apos, path) and
@@ -2715,12 +2776,11 @@ private module OperationMatchingInput implements MatchingInputSig {
27152776
private module OperationMatching = Matching<OperationMatchingInput>;
27162777

27172778
pragma[nomagic]
2718-
private Type inferOperationType0(
2719-
AstNode n, OperationMatchingInput::AccessPosition apos, TypePath path
2720-
) {
2721-
exists(OperationMatchingInput::Access a |
2779+
private Type inferOperationType0(AstNode n, boolean isReturn, TypePath path) {
2780+
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
27222781
n = a.getNodeAt(apos) and
2723-
result = OperationMatching::inferAccessType(a, apos, path)
2782+
result = OperationMatching::inferAccessType(a, apos, path) and
2783+
if apos.isReturn() then isReturn = true else isReturn = false
27242784
)
27252785
}
27262786

@@ -3488,8 +3548,6 @@ private module Cached {
34883548
or
34893549
result = inferStructExprType(n, path)
34903550
or
3491-
result = inferPathExprType(n, path)
3492-
or
34933551
result = inferMethodCallType(n, path)
34943552
or
34953553
result = inferNonMethodCallType(n, path)

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2908,12 +2908,12 @@ mod context_typed {
29082908
pub fn f() {
29092909
let x = None; // $ type=x:T.i32
29102910
let x: Option<i32> = x;
2911-
let x = Option::<i32>::None; // $ MISSING: type=x:T.i32
2912-
let x = Option::None::<i32>; // $ MISSING: type=x:T.i32
2911+
let x = Option::<i32>::None; // $ type=x:T.i32
2912+
let x = Option::None::<i32>; // $ type=x:T.i32
29132913

29142914
fn pin_option<T>(opt: Option<T>, x: T) {}
29152915

2916-
let x = None; // $ MISSING: type=x:T.i32
2916+
let x = None; // $ type=x:T.i32
29172917
pin_option(x, 0); // $ target=pin_option
29182918

29192919
enum MyEither<T1, T2> {
@@ -2932,7 +2932,7 @@ mod context_typed {
29322932
fn pin_my_either<T>(e: MyEither<T, String>, x: T) {}
29332933

29342934
#[rustfmt::skip]
2935-
let x = MyEither::B { // $ type=x:T2.String $ MISSING: type=x:T1.i32
2935+
let x = MyEither::B { // $ type=x:T1.i32 type=x:T2.String
29362936
right: String::new(), // $ target=new
29372937
};
29382938
pin_my_either(x, 0); // $ target=pin_my_either

0 commit comments

Comments
 (0)