diff --git a/clippy_lints/src/set_contains_or_insert.rs b/clippy_lints/src/set_contains_or_insert.rs index 688da33a1777..22790fb09d6a 100644 --- a/clippy_lints/src/set_contains_or_insert.rs +++ b/clippy_lints/src/set_contains_or_insert.rs @@ -2,10 +2,12 @@ use std::ops::ControlFlow; use clippy_utils::diagnostics::span_lint; use clippy_utils::res::MaybeDef; +use clippy_utils::ty::expr_sig; use clippy_utils::visitors::for_each_expr; use clippy_utils::{SpanlessEq, higher, peel_hir_expr_while, sym}; -use rustc_hir::{Expr, ExprKind, UnOp}; +use rustc_hir::{Expr, ExprKind, Node, UnOp}; use rustc_lint::{LateContext, LateLintPass}; +use rustc_middle::ty::{self, Binder}; use rustc_session::declare_lint_pass; use rustc_span::Span; use rustc_span::symbol::Symbol; @@ -112,6 +114,45 @@ fn try_parse_op_call<'tcx>( None } +fn is_set_mutated<'tcx>(cx: &LateContext<'tcx>, contains_expr: &OpExpr<'tcx>, expr: &'tcx Expr<'_>) -> bool { + let expr = peel_hir_expr_while(expr, |e| { + if let ExprKind::Unary(UnOp::Not, e) = e.kind { + Some(e) + } else { + None + } + }); + + if let ExprKind::MethodCall(_, receiver, ..) = expr.kind + && let receiver = receiver.peel_borrows() + && let receiver_ty = cx.typeck_results().expr_ty(receiver).peel_refs() + && (receiver_ty.is_diag_item(cx, sym::HashSet) || receiver_ty.is_diag_item(cx, sym::BTreeSet)) + && SpanlessEq::new(cx).eq_expr(contains_expr.receiver, receiver) + && let Some(method_def) = cx.typeck_results().type_dependent_def_id(expr.hir_id) + && let method_fn_sig = cx.tcx.fn_sig(method_def).instantiate_identity().skip_binder() + && let Some(self_ty) = method_fn_sig.inputs().first() + && let ty::Ref(_, _, ty::Mutability::Mut) = self_ty.kind() + { + return true; + } + + let child = expr.peel_borrows(); + let child_ty = cx.typeck_results().expr_ty(child).peel_refs(); + if (child_ty.is_diag_item(cx, sym::HashSet) || child_ty.is_diag_item(cx, sym::BTreeSet)) + && SpanlessEq::new(cx).eq_expr(contains_expr.receiver, child) + && let Node::Expr(parent) = cx.tcx.parent_hir_node(expr.hir_id) + && let ExprKind::Call(func, args) = parent.kind + && let Some(func_sig) = expr_sig(cx, func) + && let Some(set_index) = args.iter().position(|arg| arg.hir_id == expr.hir_id) + && let Some(set_ty) = func_sig.input(set_index).map(Binder::skip_binder) + && let ty::Ref(_, _, ty::Mutability::Mut) = set_ty.kind() + { + return true; + } + + false +} + fn find_insert_calls<'tcx>( cx: &LateContext<'tcx>, contains_expr: &OpExpr<'tcx>, @@ -122,9 +163,14 @@ fn find_insert_calls<'tcx>( && SpanlessEq::new(cx).eq_expr(contains_expr.receiver, insert_expr.receiver) && SpanlessEq::new(cx).eq_expr(contains_expr.value, insert_expr.value) { - ControlFlow::Break(insert_expr) - } else { - ControlFlow::Continue(()) + return ControlFlow::Break(Some(insert_expr)); + } + + if is_set_mutated(cx, contains_expr, e) { + return ControlFlow::Break(None); } + + ControlFlow::Continue(()) }) + .flatten() } diff --git a/tests/ui/set_contains_or_insert.rs b/tests/ui/set_contains_or_insert.rs index 575cfda139a4..6bba54bdb54a 100644 --- a/tests/ui/set_contains_or_insert.rs +++ b/tests/ui/set_contains_or_insert.rs @@ -164,3 +164,18 @@ fn main() { should_not_warn_hashset(); should_not_warn_btreeset(); } + +fn issue15990(s: &mut HashSet, v: usize) { + if !s.contains(&v) { + s.clear(); + s.insert(v); + } + + fn borrow_as_mut(v: usize, s: &mut HashSet) { + s.clear(); + } + if !s.contains(&v) { + borrow_as_mut(v, s); + s.insert(v); + } +}