package io.picodata

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.slf4j.{Logger, LoggerFactory}
import scala.reflect.io.Directory
import scala.util.Using

import java.nio.file.Files


object PicodataJDBCSparkExample extends App {
  val logger: Logger = LoggerFactory.getLogger(PicodataJDBCSparkExample.getClass.getSimpleName)

  // 1. Set up the Spark session
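  // Using.Manager closes every resource registered via `use` (here, the SparkSession)
  // when the block exits; a throwaway temp directory serves as the Spark/Hive warehouse.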
  Using.Manager { use =>
    val warehouseLocation = Files.createTempDirectory("spark-warehouse").toFile
    val warehouseLocationPath = warehouseLocation.getAbsolutePath

    val spark = use(SparkSession.builder()
      .appName("Test Spark with picodata-jdbc")
      .master("local")
      .config("spark.ui.enabled", false)
      .config("spark.sql.warehouse.dir", warehouseLocationPath)
      .config("hive.metastore.warehouse.dir", warehouseLocationPath)
      .config(
        "javax.jdo.option.ConnectionURL",
        s"jdbc:derby:;databaseName=$warehouseLocationPath/tarantoolTest;create=true"
      )
      .config("spark.driver.extraJavaOptions", "-Dlog4j.configuration=log4j2.properties")
      .enableHiveSupport()
      .getOrCreate()
    )

    val sc = spark.sparkContext

    logger.info("Spark context created")

    // 2. Load the CSV into a DataFrame
    var df = spark.read
      .format("csv")
      .load("src/main/resources/onemillion.csv")

    logger.info("Loaded 1M rows into memory")

    val jdbcUrl = "jdbc:picodata://localhost:5432/"

    try {
      // 3. Write a Dataset to a Picodata table
      df.write
        .format("jdbc")
        .option("driver", "io.picodata.jdbc.Driver")
        .mode(SaveMode.Overwrite)
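        // Overwrite drops and recreates the target table if it already exists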
        // Picodata server connection options
        .option("url", jdbcUrl)
        .option("sslmode", "disable")
        .option("user", "sqluser")
        .option("password", "P@ssw0rd")
        // this option is important: it rewrites batches of single-row INSERTs into multi-value INSERTs
        .option("reWriteBatchedInserts", "true")
        // this option value can be tuned according to the number of Spark workers you have
        .option("numPartitions", "8")
        .option("batchsize", "1000")
        // table to create / overwrite
        .option("dbtable", "test")
        .save()

      logger.info("Saved 1M rows into Picodata table 'test'")

      // 4. Read the table back and print its schema and the first 3 rows
      df = spark.read
        .format("jdbc")
        .option("driver", "io.picodata.jdbc.Driver")
        .option("url", jdbcUrl)
        .option("sslmode", "disable")
        .option("user", "sqluser")
        .option("password", "P@ssw0rd")
        .option("dbtable", "test")
        .load()
      df.printSchema()
      df.limit(3).show()
    } catch {
      case e: Throwable => logger.error("Spark JDBC example failed", e)
    } finally {
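      // always stop Spark and remove the temporary warehouse directory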
      sc.stop()
      Directory(warehouseLocation).deleteRecursively()
    }
  }
}