// PicodataJDBCSparkExample.scala
package io.picodata

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.functions._
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import scala.reflect.io.Directory
import scala.util.Using

import java.nio.file.Files


object PicodataJDBCSparkExample extends App {
  val logger: Logger = LoggerFactory.getLogger(PicodataJDBCSparkExample.getClass.getSimpleName)

  // 1. Set up the Spark session
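  // Using.Manager closes every resource registered via use(...) (here, the
  // SparkSession) when the block exits, even if an exception is thrown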
  Using.Manager { use =>
    val warehouseLocation = Files.createTempDirectory("spark-warehouse").toFile
    val warehouseLocationPath = warehouseLocation.getAbsolutePath

    val spark = use(SparkSession.builder()
      .appName("Test Spark with picodata-jdbc")
      .master("local")
      .config("spark.ui.enabled", false)
      .config("spark.sql.warehouse.dir", warehouseLocationPath)
      .config("hive.metastore.warehouse.dir", warehouseLocationPath)
      .config(
        "javax.jdo.option.ConnectionURL",
        s"jdbc:derby:;databaseName=$warehouseLocationPath/tarantoolTest;create=true"
      )
      .config("spark.driver.extraJavaOptions", "-Dlog4j.configuration=log4j2.properties")
      .enableHiveSupport()
      .getOrCreate()
    )

    val sc = spark.sparkContext

    logger.info("Spark context created")

    // 2. Load the CSV into a DataFrame
    var df = spark.read
      .format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .load("src/main/resources/onemillion.csv")
      .select(col("id"), col("unique_key"), col("book_name"), col("author"), col("year"))

    logger.info("Loaded 1M rows into memory")

    val jdbcUrl = "jdbc:picodata://localhost:5432/?user=sqluser&password=P@ssw0rd&sslmode=disable"
      // only needed if the table is not created on Picodata server
      // basic JDBC connector does not support primary keys
      val options = Map(
        "driver" -> "io.picodata.jdbc.Driver",
        "url" -> jdbcUrl,
        "dbtable" -> "test"
      )
      val jdbcOptions = new JDBCOptions(options)
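      // the dialect's connection factory takes a partition id; -1 is the value
      // Spark itself uses for connections opened on the driver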
      val connection = JdbcDialects.get(jdbcUrl).createConnectionFactory(jdbcOptions)(-1)
      var statement = connection.prepareStatement("DROP TABLE test")
      try {
        // DROP TABLE IF EXISTS is only available in Picodata 24.6.1+,
        // so the "not found" error is caught and ignored instead
        statement.executeUpdate()
      } catch {
        case e: Exception => if (!e.getMessage.contains("test not found")) throw e
      }
      statement = connection.prepareStatement("CREATE TABLE test" +
        "(id INTEGER PRIMARY KEY, unique_key VARCHAR(1000), book_name VARCHAR(100), author VARCHAR(100), year INTEGER)")
      statement.executeUpdate()
      connection.close()

      // 3. Write a Dataset to a Picodata table
      df.write
        .format("jdbc")
        .option("driver", "io.picodata.jdbc.Driver")
        .mode(SaveMode.Append)
        // Picodata server connection options
        .option("url", jdbcUrl)
        // this option is important: it rewrites single-row INSERT statements
        // into multi-value INSERTs, greatly reducing round trips
        .option("reWriteBatchedInserts", "true")
        // this option value can be tuned according to the number of Spark workers you have
        .option("numPartitions", "8")
        .option("batchsize", "1000")
        // target table (created above)
        .option("dbtable", "test")
        .save()

      logger.info("Saved 1M rows into Picodata table 'test'")

      // 4. Read the table back and print the schema and first 3 rows
      df = spark.read
        .format("jdbc")
        .option("driver", "io.picodata.jdbc.Driver")
        .option("url", jdbcUrl)
        .option("sslmode", "disable")
        .option("user", "sqluser")
        .option("password", "P@ssw0rd")
        .option("dbtable", "test")
        .load()
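      // for parallel reads, Spark can additionally split the table on a numeric
      // column; the bounds here are an assumption based on the 1M-row dataset:
      //   .option("partitionColumn", "id")
      //   .option("lowerBound", "1")
      //   .option("upperBound", "1000000")
      //   .option("numPartitions", "8")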
      df.printSchema()
      df.limit(3).show()
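
      // sanity check (an addition to the original example): count the rows
      // that actually landed in Picodata
      logger.info(s"Table 'test' now holds ${df.count()} rows")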
    } catch {
      case throwable: Throwable => logger.error("Example failed", throwable)
    } finally {
      sc.stop()
      Directory(warehouseLocation).deleteRecursively()
    }
  }
}