Commit 6328cb4

added write options to set numberOfShards and collectionType
1 parent f8a4431 commit 6328cb4

File tree

4 files changed: +114 -11 lines

README.md
Lines changed: 2 additions & 0 deletions

@@ -203,6 +203,8 @@ according to the related target collection definition and is different from the

 ### Write Configuration

 - `table`: target ArangoDB collection name (required)
+- `table.shards`: number of shards of the created collection (in case of SaveMode `Append` or `Overwrite`)
+- `table.type`: type (`document`|`edge`) of the created collection (in case of SaveMode `Append` or `Overwrite`)
 - `batch.size`: writing batch size, default `1000`
 - `wait.sync`: whether to wait until the documents have been synced to disk (`true`|`false`)
 - `confirm.truncate`: confirm to truncate table when using `SaveMode.Overwrite` mode, default `false`

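The new keys are ordinary datasource options, so they can be passed through the regular DataFrameWriter API. A minimal usage sketch (the format name "com.arangodb.spark" and the session setup are assumptions for illustration, not taken from this commit):

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

// Hypothetical session and data; only the option keys come from this commit.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val edges = Seq(("a/1", "b/1"), ("a/2", "b/2")).toDF("_from", "_to")

edges.write
  .format("com.arangodb.spark")  // assumed datasource format name
  .mode(SaveMode.Append)
  .option("table", "knows")      // required: target collection
  .option("table.shards", "5")   // new: number of shards for the created collection
  .option("table.type", "edge")  // new: create an edge collection
  .save()
```
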
arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoClient.scala
Lines changed: 10 additions & 10 deletions

@@ -99,16 +99,16 @@ class ArangoClient(options: ArangoOptions) {
     .collection(options.writeOptions.collection)
     .exists()

-  def createCollection(): Unit = arangoDB
-    .db(options.writeOptions.db)
-    .collection(options.writeOptions.collection)
-    .create(new CollectionCreateOptions()
-      // TODO:
-      //  .`type`()
-      //  .numberOfShards()
-      //  .replicationFactor()
-      //  .minReplicationFactor()
-    )
+  def createCollection(): Unit = {
+    val opts = new CollectionCreateOptions()
+    options.writeOptions.numberOfShards.foreach(opts.numberOfShards(_))
+    options.writeOptions.collectionType.foreach(ct => opts.`type`(ct.get()))
+
+    arangoDB
+      .db(options.writeOptions.db)
+      .collection(options.writeOptions.collection)
+      .create(opts)
+  }

   def truncate(): Unit = arangoDB
     .db(options.writeOptions.db)

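The rewritten createCollection applies each write option only when it is defined, so unset options fall back to the server defaults. A self-contained sketch of that Option-to-builder pattern against the Java driver's CollectionCreateOptions (values hardcoded here for illustration):

```scala
import com.arangodb.entity
import com.arangodb.model.CollectionCreateOptions

// Stand-ins for options.writeOptions.numberOfShards / collectionType.
val numberOfShards: Option[Int] = Some(5)
val collectionType: Option[entity.CollectionType] = Some(entity.CollectionType.EDGES)

val opts = new CollectionCreateOptions()
numberOfShards.foreach(n => opts.numberOfShards(n)) // applied only when defined
collectionType.foreach(t => opts.`type`(t))         // applied only when defined
```
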
arangodb-spark-commons/src/main/scala/org/apache/spark/sql/arangodb/commons/ArangoOptions.scala
Lines changed: 25 additions & 1 deletion

@@ -20,7 +20,7 @@

 package org.apache.spark.sql.arangodb.commons

-import com.arangodb.ArangoDB
+import com.arangodb.{ArangoDB, entity}
 import com.arangodb.model.OverwriteMode

 import java.io.ByteArrayInputStream

@@ -87,6 +87,8 @@ object ArangoOptions {
   val STREAM = "stream"

   // write options
+  val NUMBER_OF_SHARDS = "table.shards"
+  val COLLECTION_TYPE = "table.type"
   val WAIT_FOR_SYNC = "wait.sync"
   val CONFIRM_TRUNCATE = "confirm.truncate"
   val OVERWRITE_MODE = "overwrite.mode"

@@ -178,6 +180,8 @@ class ArangoReadOptions(options: Map[String, String]) extends CommonOptions(opti
 class ArangoWriteOptions(options: Map[String, String]) extends CommonOptions(options) {
   val batchSize: Int = options.get(ArangoOptions.BATCH_SIZE).map(_.toInt).getOrElse(1000)
   val collection: String = getRequired(ArangoOptions.COLLECTION)
+  val numberOfShards: Option[Int] = options.get(ArangoOptions.NUMBER_OF_SHARDS).map(_.toInt)
+  val collectionType: Option[CollectionType] = options.get(ArangoOptions.COLLECTION_TYPE).map(CollectionType(_))
   val waitForSync: Option[Boolean] = options.get(ArangoOptions.WAIT_FOR_SYNC).map(_.toBoolean)
   val confirmTruncate: Boolean = options.getOrElse(ArangoOptions.CONFIRM_TRUNCATE, "false").toBoolean
   val overwriteMode: Option[OverwriteMode] = options.get(ArangoOptions.OVERWRITE_MODE).map(OverwriteMode.valueOf)

@@ -226,3 +230,23 @@ object Protocol {
     case _ => throw new IllegalArgumentException(s"${ArangoOptions.PROTOCOL}: $value")
   }
 }
+
+sealed trait CollectionType {
+  def get(): entity.CollectionType
+}
+
+object CollectionType {
+  case object DOCUMENT extends CollectionType {
+    override def get() = entity.CollectionType.DOCUMENT
+  }
+
+  case object EDGE extends CollectionType {
+    override def get() = entity.CollectionType.EDGES
+  }
+
+  def apply(value: String): CollectionType = value match {
+    case "document" => DOCUMENT
+    case "edge" => EDGE
+    case _ => throw new IllegalArgumentException(s"${ArangoOptions.COLLECTION_TYPE}: $value")
+  }
+}
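
The sealed trait gives the two user-facing values an explicit mapping to the Java driver's enum, with invalid input failing fast. A small sketch of the expected behaviour (hypothetical checks, assuming the classes above are on the classpath):

```scala
import com.arangodb.entity
import org.apache.spark.sql.arangodb.commons.CollectionType

// "document" and "edge" map to the driver constants; note "edge" -> EDGES.
assert(CollectionType("document").get() == entity.CollectionType.DOCUMENT)
assert(CollectionType("edge").get() == entity.CollectionType.EDGES)
// CollectionType("graph") would throw IllegalArgumentException: "table.type: graph"
```
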
…/org/apache/spark/sql/arangodb/datasource/write/CreateCollectionTest.scala
Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+package org.apache.spark.sql.arangodb.datasource.write
+
+import com.arangodb.ArangoCollection
+import com.arangodb.entity.CollectionType
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.arangodb.commons.ArangoOptions
+import org.apache.spark.sql.arangodb.datasource.BaseSparkTest
+import org.assertj.core.api.Assertions.assertThat
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.MethodSource
+
+
+class CreateCollectionTest extends BaseSparkTest {
+
+  private val collectionName = "chessPlayersCreateCollection"
+  private val collection: ArangoCollection = db.collection(collectionName)
+
+  import spark.implicits._
+
+  private val df = Seq(
+    ("a/1", "b/1"),
+    ("a/2", "b/2"),
+    ("a/3", "b/3"),
+    ("a/4", "b/4"),
+    ("a/5", "b/5"),
+    ("a/6", "b/6")
+  ).toDF("_from", "_to")
+    .repartition(3)
+
+  @BeforeEach
+  def beforeEach(): Unit = {
+    if (collection.exists()) {
+      collection.drop()
+    }
+  }
+
+  @ParameterizedTest
+  @MethodSource(Array("provideProtocolAndContentType"))
+  def saveModeAppend(protocol: String, contentType: String): Unit = {
+    df.write
+      .format(BaseSparkTest.arangoDatasource)
+      .mode(SaveMode.Append)
+      .options(options + (
+        ArangoOptions.COLLECTION -> collectionName,
+        ArangoOptions.PROTOCOL -> protocol,
+        ArangoOptions.CONTENT_TYPE -> contentType,
+        ArangoOptions.NUMBER_OF_SHARDS -> "5",
+        ArangoOptions.COLLECTION_TYPE -> "edge"
+      ))
+      .save()
+
+    assertThat(collection.getProperties.getNumberOfShards).isEqualTo(5)
+    assertThat(collection.getProperties.getType.getType).isEqualTo(CollectionType.EDGES.getType)
+  }
+
+  @ParameterizedTest
+  @MethodSource(Array("provideProtocolAndContentType"))
+  def saveModeOverwrite(protocol: String, contentType: String): Unit = {
+    df.write
+      .format(BaseSparkTest.arangoDatasource)
+      .mode(SaveMode.Overwrite)
+      .options(options + (
+        ArangoOptions.COLLECTION -> collectionName,
+        ArangoOptions.PROTOCOL -> protocol,
+        ArangoOptions.CONTENT_TYPE -> contentType,
+        ArangoOptions.CONFIRM_TRUNCATE -> "true",
+        ArangoOptions.NUMBER_OF_SHARDS -> "5",
+        ArangoOptions.COLLECTION_TYPE -> "edge"
+      ))
+      .save()
+
+    assertThat(collection.getProperties.getNumberOfShards).isEqualTo(5)
+    assertThat(collection.getProperties.getType.getType).isEqualTo(CollectionType.EDGES.getType)
+  }
+
+}
