Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.assertj.core.api.ListAssert;
import org.jspecify.annotations.Nullable;

import org.springframework.lang.CheckReturnValue;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.contentOf;

Expand Down Expand Up @@ -172,6 +174,7 @@ JarAssert doesNotHaveEntryWithNameStartingWith(String prefix) {
return this;
}

@CheckReturnValue
ListAssert<String> entryNamesInPath(String path) {
List<String> matches = new ArrayList<>();
withJarFile((jarFile) -> withEntries(jarFile,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
* @author Phillip Webb
* @author Dmytro Nosan
* @author Moritz Halbritter
* @author Stefano Cordio
*/
public abstract class ArchitectureCheck extends DefaultTask {

Expand All @@ -85,6 +86,8 @@ public ArchitectureCheck() {
getRules().addAll(whenMainSources(() -> List
.of(ArchitectureRules.allBeanMethodsShouldReturnNonPrivateType(), ArchitectureRules
.allBeanMethodsShouldNotHaveConditionalOnClassAnnotation(getConditionalOnClassAnnotation().get()))));
getRules().addAll(whenMainSources(() -> Collections.singletonList(
ArchitectureRules.allCustomAssertionMethodsNotReturningSelfShouldBeAnnotatedWithCheckReturnValue())));
getRules().addAll(and(getNullMarkedEnabled(), isMainSourceSet()).map(whenTrue(() -> Collections.singletonList(
ArchitectureRules.packagesShouldBeAnnotatedWithNullMarked(getNullMarkedIgnoredPackages().get())))));
getRuleDescriptions().set(getRules().map(this::asDescriptions));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.annotation.Role;
import org.springframework.lang.CheckReturnValue;
import org.springframework.util.ResourceUtils;

/**
Expand All @@ -75,6 +76,7 @@
* @author Phillip Webb
* @author Ngoc Nhan
* @author Moritz Halbritter
* @author Stefano Cordio
*/
final class ArchitectureRules {

Expand Down Expand Up @@ -139,6 +141,30 @@ static ArchRule allBeanMethodsShouldNotHaveConditionalOnClassAnnotation(String a
.allowEmptyShould(true);
}

static ArchRule allCustomAssertionMethodsNotReturningSelfShouldBeAnnotatedWithCheckReturnValue() {
return ArchRuleDefinition.methods()
.that()
.areDeclaredInClassesThat()
.implement("org.assertj.core.api.Assert")
.and()
.arePublic()
.and(dontReturnSelfType())
.should()
.beAnnotatedWith(CheckReturnValue.class)
.allowEmptyShould(true);
}

private static DescribedPredicate<JavaMethod> dontReturnSelfType() {
return DescribedPredicate.describe("don't return self type",
(method) -> !method.getRawReturnType().equals(method.getOwner())
|| isOverridingMethodNotReturningSelfType(method));
}

private static boolean isOverridingMethodNotReturningSelfType(JavaMethod method) {
return superMethods(method).anyMatch((superMethod) -> isOverridden(superMethod, method)
&& !superMethod.getOwner().equals(method.getRawReturnType()));
}

private static ArchRule allPackagesShouldBeFreeOfTangles() {
return SlicesRuleDefinition.slices()
.matching("(**)")
Expand Down Expand Up @@ -554,39 +580,39 @@ public void check(JavaPackage item, ConditionEvents events) {
};
}

private static class OverridesPublicMethod<T extends JavaMember> extends DescribedPredicate<T> {
private static Stream<JavaMethod> superMethods(JavaMethod method) {
Stream<JavaMethod> superClassMethods = method.getOwner()
.getAllRawSuperclasses()
.stream()
.flatMap((superClass) -> superClass.getMethods().stream());
Stream<JavaMethod> interfaceMethods = method.getOwner()
.getAllRawInterfaces()
.stream()
.flatMap((iface) -> iface.getMethods().stream());
return Stream.concat(superClassMethods, interfaceMethods);
}

OverridesPublicMethod() {
super("overrides public method");
}
private static boolean isOverridden(JavaMethod superMethod, JavaMethod method) {
return superMethod.getName().equals(method.getName())
&& superMethod.getRawParameterTypes().size() == method.getRawParameterTypes().size()
&& superMethod.getDescriptor().equals(method.getDescriptor());
}

private static final class OverridesPublicMethod<T extends JavaMember> implements Predicate<T> {

@Override
public boolean test(T member) {
if (!(member instanceof JavaMethod javaMethod)) {
return false;
}
Stream<JavaMethod> superClassMethods = member.getOwner()
.getAllRawSuperclasses()
.stream()
.flatMap((superClass) -> superClass.getMethods().stream());
Stream<JavaMethod> interfaceMethods = member.getOwner()
.getAllRawInterfaces()
.stream()
.flatMap((iface) -> iface.getMethods().stream());
return Stream.concat(superClassMethods, interfaceMethods)
return superMethods(javaMethod)
.anyMatch((superMethod) -> isPublic(superMethod) && isOverridden(superMethod, javaMethod));
}

private boolean isPublic(JavaMethod method) {
private static boolean isPublic(JavaMethod method) {
return method.getModifiers().contains(JavaModifier.PUBLIC);
}

private boolean isOverridden(JavaMethod superMethod, JavaMethod method) {
return superMethod.getName().equals(method.getName())
&& superMethod.getRawParameterTypes().size() == method.getRawParameterTypes().size()
&& superMethod.getDescriptor().equals(method.getDescriptor());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,16 @@
* @author Scott Frederick
* @author Ivan Malutin
* @author Dmytro Nosan
* @author Stefano Cordio
*/
class ArchitectureCheckTests {

private static final String ASSERTJ_CORE = "org.assertj:assertj-core:3.27.4";

private static final String SPRING_CONTEXT = "org.springframework:spring-context:6.2.9";

private static final String SPRING_CORE = "org.springframework:spring-core:6.2.9";

private static final String JUNIT_JUPITER = "org.junit.jupiter:junit-jupiter:5.12.0";

private static final String SPRING_INTEGRATION_JMX = "org.springframework.integration:spring-integration-jmx:6.5.1";
Expand Down Expand Up @@ -340,6 +345,22 @@ void whenConditionalOnClassUsedOnBeanMethodsWithTestSourcesShouldSucceedAndWrite
build(gradleBuild, Task.CHECK_ARCHITECTURE_TEST);
}

@Test
void whenCustomAssertionMethodNotReturningSelfIsAnnotatedWithCheckReturnValueShouldSucceedAndWriteEmptyReport()
throws IOException {
prepareTask(Task.CHECK_ARCHITECTURE_MAIN, "assertj/checkReturnValue");
build(this.gradleBuild.withDependencies(ASSERTJ_CORE, SPRING_CORE), Task.CHECK_ARCHITECTURE_MAIN);
}

@Test
void whenCustomAssertionMethodNotReturningSelfIsNotAnnotatedWithCheckReturnValueShouldFailAndWriteReport()
throws IOException {
prepareTask(Task.CHECK_ARCHITECTURE_MAIN, "assertj/noCheckReturnValue");
buildAndFail(this.gradleBuild.withDependencies(ASSERTJ_CORE), Task.CHECK_ARCHITECTURE_MAIN,
"methods that are declared in classes that implement org.assertj.core.api.Assert"
+ " and are public and don't return self type should be annotated with" + " @CheckReturnValue");
}

private void prepareTask(Task task, String... sourceDirectories) throws IOException {
for (String sourceDirectory : sourceDirectories) {
FileSystemUtils.copyRecursively(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2012-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.build.architecture.assertj.checkReturnValue;

import org.assertj.core.api.AbstractAssert;

import org.springframework.lang.CheckReturnValue;

public class WithCheckReturnValue extends AbstractAssert<WithCheckReturnValue, Object> {

WithCheckReturnValue() {
super(null, WithCheckReturnValue.class);
}

@CheckReturnValue
public Object notReturningSelf() {
return new Object();
}

@Override
public WithCheckReturnValue isEqualTo(Object expected) {
return super.isEqualTo(expected);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2012-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.build.architecture.assertj.noCheckReturnValue;

import org.assertj.core.api.AbstractAssert;

public class NoCheckReturnValue extends AbstractAssert<NoCheckReturnValue, Object> {

NoCheckReturnValue() {
super(null, NoCheckReturnValue.class);
}

public Object notReturningSelf() {
return new Object();
}

}
3 changes: 3 additions & 0 deletions config/checkstyle/import-control.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
<allow pkg="io.micrometer.observation" />
<disallow pkg="io.micrometer" />

<!-- Improve DevEx with fluent APIs -->
<allow class="org.springframework.lang.CheckReturnValue" />

<!-- Use JSpecify for nullability (not Spring) -->
<allow class="org.springframework.lang.Contract" />
<disallow pkg="org.springframework.lang" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.assertj.core.api.AbstractObjectArrayAssert;
import org.assertj.core.api.AbstractObjectAssert;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.MapAssert;
import org.assertj.core.error.BasicErrorMessageFactory;
import org.jspecify.annotations.Nullable;
Expand All @@ -37,6 +36,7 @@
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.lang.CheckReturnValue;
import org.springframework.util.Assert;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -224,13 +224,14 @@ public ApplicationContextAssert<C> doesNotHaveBean(String name) {
* @return array assertions for the bean names
* @throws AssertionError if the application context did not start
*/
@CheckReturnValue
public <T> AbstractObjectArrayAssert<?, String> getBeanNames(Class<T> type) {
if (this.startupFailure != null) {
throwAssertionError(contextFailedToStartWhenExpecting(this.startupFailure,
"to get beans names with type:%n <%s>", type));
}
return Assertions.assertThat(getApplicationContext().getBeanNamesForType(type))
.as("Bean names of type <%s> from <%s>", type, getApplicationContext());
return assertThat(getApplicationContext().getBeanNamesForType(type)).as("Bean names of type <%s> from <%s>",
type, getApplicationContext());
}

/**
Expand All @@ -249,6 +250,7 @@ public <T> AbstractObjectArrayAssert<?, String> getBeanNames(Class<T> type) {
* @throws AssertionError if the application context contains multiple beans of the
* given type
*/
@CheckReturnValue
public <T> AbstractObjectAssert<?, T> getBean(Class<T> type) {
return getBean(type, Scope.INCLUDE_ANCESTORS);
}
Expand All @@ -270,6 +272,7 @@ public <T> AbstractObjectAssert<?, T> getBean(Class<T> type) {
* @throws AssertionError if the application context contains multiple beans of the
* given type
*/
@CheckReturnValue
public <T> AbstractObjectAssert<?, T> getBean(Class<T> type, Scope scope) {
Assert.notNull(scope, "'scope' must not be null");
if (this.startupFailure != null) {
Expand All @@ -284,7 +287,7 @@ public <T> AbstractObjectAssert<?, T> getBean(Class<T> type, Scope scope) {
getApplicationContext(), type, names));
}
T bean = (name != null) ? getApplicationContext().getBean(name, type) : null;
return Assertions.assertThat(bean).as("Bean of type <%s> from <%s>", type, getApplicationContext());
return assertThat(bean).as("Bean of type <%s> from <%s>", type, getApplicationContext());
}

private @Nullable String getPrimary(String[] names, Scope scope) {
Expand Down Expand Up @@ -330,13 +333,14 @@ private boolean isPrimary(String name, Scope scope) {
* is found
* @throws AssertionError if the application context did not start
*/
@CheckReturnValue
public AbstractObjectAssert<?, Object> getBean(String name) {
if (this.startupFailure != null) {
throwAssertionError(
contextFailedToStartWhenExpecting(this.startupFailure, "to contain a bean of name:%n <%s>", name));
}
Object bean = findBean(name);
return Assertions.assertThat(bean).as("Bean of name <%s> from <%s>", name, getApplicationContext());
return assertThat(bean).as("Bean of name <%s> from <%s>", name, getApplicationContext());
}

/**
Expand All @@ -357,6 +361,7 @@ public AbstractObjectAssert<?, Object> getBean(String name) {
* name but a different type
*/
@SuppressWarnings("unchecked")
@CheckReturnValue
public <T> AbstractObjectAssert<?, T> getBean(String name, Class<T> type) {
if (this.startupFailure != null) {
throwAssertionError(contextFailedToStartWhenExpecting(this.startupFailure,
Expand All @@ -368,8 +373,8 @@ public <T> AbstractObjectAssert<?, T> getBean(String name, Class<T> type) {
"%nExpecting:%n <%s>%nto contain a bean of name:%n <%s> (%s)%nbut found:%n <%s> of type <%s>",
getApplicationContext(), name, type, bean, bean.getClass()));
}
return Assertions.assertThat((T) bean)
.as("Bean of name <%s> and type <%s> from <%s>", name, type, getApplicationContext());
return assertThat((T) bean).as("Bean of name <%s> and type <%s> from <%s>", name, type,
getApplicationContext());
}

private @Nullable Object findBean(String name) {
Expand All @@ -395,6 +400,7 @@ public <T> AbstractObjectAssert<?, T> getBean(String name, Class<T> type) {
* no beans are found
* @throws AssertionError if the application context did not start
*/
@CheckReturnValue
public <T> MapAssert<String, T> getBeans(Class<T> type) {
return getBeans(type, Scope.INCLUDE_ANCESTORS);
}
Expand All @@ -414,14 +420,15 @@ public <T> MapAssert<String, T> getBeans(Class<T> type) {
* no beans are found
* @throws AssertionError if the application context did not start
*/
@CheckReturnValue
public <T> MapAssert<String, T> getBeans(Class<T> type, Scope scope) {
Assert.notNull(scope, "'scope' must not be null");
if (this.startupFailure != null) {
throwAssertionError(
contextFailedToStartWhenExpecting(this.startupFailure, "to get beans of type:%n <%s>", type));
}
return Assertions.assertThat(scope.getBeansOfType(getApplicationContext(), type))
.as("Beans of type <%s> from <%s>", type, getApplicationContext());
return assertThat(scope.getBeansOfType(getApplicationContext(), type)).as("Beans of type <%s> from <%s>", type,
getApplicationContext());
}

/**
Expand All @@ -434,6 +441,7 @@ public <T> MapAssert<String, T> getBeans(Class<T> type, Scope scope) {
* @return assertions on the cause of the failure
* @throws AssertionError if the application context started without a failure
*/
@CheckReturnValue
public AbstractThrowableAssert<?, ? extends Throwable> getFailure() {
hasFailed();
return assertThat(this.startupFailure);
Expand Down
Loading
Loading