
Commit 6d80ae4

fixed pushFilters for ReadMode.Query (#37)
1 parent 3495ea9 commit 6d80ae4

File tree: 4 files changed, +46 -3 lines changed
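
The change in all three modules is the same: filter pushdown now dispatches on the configured read mode. When the user supplies a raw AQL query (ReadMode.Query), the connector cannot rewrite that query to embed Spark's filters, so pushFilters hands the whole filter array back; under Spark's SupportsPushDownFilters contract, the returned filters are the ones Spark must still evaluate itself after the scan. Previously the collection-mode pushdown logic (now pushFiltersReadModeCollection) ran unconditionally. A minimal sketch of that contract follows; PushdownSketch and its two-case ReadMode are illustrative stand-ins, not the connector's classes.

import org.apache.spark.sql.sources.{EqualTo, Filter}

// Minimal sketch of Spark's SupportsPushDownFilters contract that this fix
// relies on: whatever pushFilters returns is the set of filters Spark must
// still evaluate itself after the scan.
object PushdownSketch {
  sealed trait ReadMode
  case object Collection extends ReadMode
  case object Query extends ReadMode

  def pushFilters(readMode: ReadMode, filters: Array[Filter]): Array[Filter] =
    readMode match {
      // Collection mode: pretend every filter was translated into an AQL
      // FILTER clause, so nothing is left for Spark to re-check.
      case Collection => Array.empty[Filter]
      // Query mode: a user-supplied AQL string cannot be rewritten, so all
      // filters are handed back for Spark to apply post-scan.
      case Query => filters
    }

  def main(args: Array[String]): Unit = {
    val remaining = pushFilters(Query, Array(EqualTo("idx", 3)))
    println(remaining.mkString(", ")) // EqualTo(idx,3): evaluated by Spark
  }
}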

arangodb-spark-datasource-2.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoDataSourceReader.scala

7 additions & 0 deletions

@@ -39,6 +39,13 @@ class ArangoDataSourceReader(tableSchema: StructType, options: ArangoDBConf) ext
       .map(it => ArangoPartition.ofCollection(it._1, it._2, new PushDownCtx(readSchema(), appliedPushableFilters), options))
 
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+    options.readOptions.readMode match {
+      case ReadMode.Collection => pushFiltersReadModeCollection(filters)
+      case ReadMode.Query => filters
+    }
+  }
+
+  private def pushFiltersReadModeCollection(filters: Array[Filter]): Array[Filter] = {
     // filters related to columnNameOfCorruptRecord are not pushed down
     val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord)
     val ignoredFilters = filters.filter(isCorruptRecordFilter)
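
The identical guard is added to the DataSource V2 ScanBuilder used by the Spark 3.1 and 3.2 modules below; apart from the surrounding API (DataSourceReader vs. ScanBuilder), the only extra change there is widening the import to bring ReadMode into scope.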

arangodb-spark-datasource-3.1/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala

8 additions & 1 deletion

@@ -1,7 +1,7 @@
 package org.apache.spark.sql.arangodb.datasource.reader
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.arangodb.commons.ArangoDBConf
+import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ReadMode}
 import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter}
 import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
@@ -22,6 +22,13 @@ class ArangoScanBuilder(options: ArangoDBConf, tableSchema: StructType) extends
   override def build(): Scan = new ArangoScan(new PushDownCtx(readSchema, appliedPushableFilters), options)
 
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+    options.readOptions.readMode match {
+      case ReadMode.Collection => pushFiltersReadModeCollection(filters)
+      case ReadMode.Query => filters
+    }
+  }
+
+  private def pushFiltersReadModeCollection(filters: Array[Filter]): Array[Filter] = {
     // filters related to columnNameOfCorruptRecord are not pushed down
     val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord)
     val ignoredFilters = filters.filter(isCorruptRecordFilter)

arangodb-spark-datasource-3.2/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala

8 additions & 1 deletion

@@ -1,7 +1,7 @@
 package org.apache.spark.sql.arangodb.datasource.reader
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.arangodb.commons.ArangoDBConf
+import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ReadMode}
 import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter}
 import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
@@ -22,6 +22,13 @@ class ArangoScanBuilder(options: ArangoDBConf, tableSchema: StructType) extends
   override def build(): Scan = new ArangoScan(new PushDownCtx(readSchema, appliedPushableFilters), options)
 
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+    options.readOptions.readMode match {
+      case ReadMode.Collection => pushFiltersReadModeCollection(filters)
+      case ReadMode.Query => filters
+    }
+  }
+
+  private def pushFiltersReadModeCollection(filters: Array[Filter]): Array[Filter] = {
     // filters related to columnNameOfCorruptRecord are not pushed down
     val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord)
     val ignoredFilters = filters.filter(isCorruptRecordFilter)

integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/ReadTest.scala

23 additions & 1 deletion

@@ -5,7 +5,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.arangodb.commons.ArangoDBConf
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.{NullType, NumericType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.assertj.core.api.Assertions.{assertThat, catchThrowable}
 import org.assertj.core.api.ThrowableAssert.ThrowingCallable
 import org.junit.jupiter.api.Assumptions.assumeTrue
@@ -167,6 +167,28 @@ class ReadTest extends BaseSparkTest {
     assertThat(df.count()).isEqualTo(10)
   }
 
+  @ParameterizedTest
+  @MethodSource(Array("provideProtocolAndContentType"))
+  def readQueryWithFilter(protocol: String, contentType: String): Unit = {
+    val query =
+      """
+        |FOR i IN 1..10
+        | RETURN { idx: i, value: SHA1(i) }
+        |""".stripMargin.replaceAll("\n", "")
+
+    val df = spark.read
+      .format(BaseSparkTest.arangoDatasource)
+      .options(options + (
+        ArangoDBConf.QUERY -> query,
+        ArangoDBConf.PROTOCOL -> protocol,
+        ArangoDBConf.CONTENT_TYPE -> contentType
+      ))
+      .load()
+      .filter(col("idx") === 3)
+
+    assertThat(df.count()).isEqualTo(1)
+  }
+
   @ParameterizedTest
   @MethodSource(Array("provideProtocolAndContentType"))
   def reatTimeout(protocol: String, contentType: String): Unit = {
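
The new test exercises the fix end to end: the AQL query yields ten rows, the idx === 3 filter is handed back by pushFilters, and Spark itself narrows the result to one row, so the count assertion can only pass if the filter is genuinely applied somewhere. Outside the test harness the same read looks roughly like the sketch below; the format name and the endpoints option are assumptions based on the connector's documented usage and are not shown in this diff.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Hedged user-facing equivalent of readQueryWithFilter above. The format
// name "com.arangodb.spark" and the "endpoints" option are assumptions;
// credentials and schema handling are elided.
val spark = SparkSession.builder().master("local[*]").getOrCreate()

val df = spark.read
  .format("com.arangodb.spark")
  .option("query", "FOR i IN 1..10 RETURN { idx: i, value: SHA1(i) }")
  .option("endpoints", "localhost:8529") // assumed connection option
  .load()
  .filter(col("idx") === 3) // returned un-pushed by pushFilters, evaluated by Spark

assert(df.count() == 1L)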
