Romit Girdhar 2018-09-07 16:29:38 -07:00
Parent 7f9ad3173c
Commit ea376c4f24
2 changed files with 54 additions and 54 deletions

View file

@@ -27,6 +27,7 @@ import com.microsoft.azure.sqldb.spark.bulkcopy.BulkCopyMetadata
import com.microsoft.azure.sqldb.spark.connect._
import org.apache.spark.sql.catalyst.InternalRow
import com.microsoft.azure.sqldb.spark.connect.ConnectionUtils._
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.Attribute
import scala.collection.mutable.ListBuffer
@@ -40,7 +41,7 @@ private[spark] object SQLWriter extends Logging {
var subset = false //This variable identifies whether you want to write to all columns of the SQL table or just a select few.
var DRIVER_NAME: String = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
var connection: Connection = _
var ps: PreparedStatement = null
//var ps: PreparedStatement = null
override def toString: String = "SQLWriter"
@@ -72,20 +73,27 @@ private[spark] object SQLWriter extends Logging {
if(connection.isClosed){
val test = 1
}
// connection.setCatalog(db)
//Using Bulk Copy
val bulkCopyConfig = Config(Map(
//TODO: Using Bulk Copy
/*val bulkCopyConfig = Config(Map(
"url" -> url,
"databaseName" -> writeConfig.get[String](SqlDBConfig.DatabaseName).get,
"user" -> writeConfig.get[String](SqlDBConfig.User).get,
"password" -> writeConfig.get[String](SqlDBConfig.Password).get,
"dbTable" -> table
))
try{
// data.bulkCopyToSqlDB()
data.bulkCopyToSqlDB()
} catch {
case e: Exception =>
log.error("Error writing batch data to SQL DB. Error details: "+ e)
throw e
}
*/
/* Getting Column information */
try{
val schema = queryExecution.analyzed.output
var schemaDatatype = new ListBuffer[String]()
var colNames = ""
@@ -99,46 +107,45 @@ private[spark] object SQLWriter extends Logging {
})
}
//TODO: Option to overwrite
var sql = "INSERT INTO " + table + " (" + colNames.substring(0, colNames.length-1) + ")" + " VALUES (" + values.substring(0, values.length-1) + ");"
var sql = "INSERT INTO " + table + " (" + colNames.substring(0, colNames.length-1) + " , partitionid" + ")" + " VALUES (" + values.substring(0, values.length-1) + ",?" + ");"
ps = connection.prepareStatement(sql)
queryExecution.toRdd.foreachPartition(iter => {
iter.foreach(row => {
if(ps != null) {
try{
var i = 0
for(e <- schemaDatatype) {
val testVar = row.getString(i)
e match {
case "ByteType" => ps.setByte(i+1, row.getByte(i))
case "ShortType" => ps.setShort(i+1,row.getShort(i))
case "IntegerType" => ps.setInt(i+1, row.getInt(i))
case "LongType" => ps.setLong(i+1, row.getLong(i))
case "FloatType" => ps.setFloat(i+1, row.getFloat(i))
case "DoubleType" => ps.setDouble(i+1, row.getDouble(i))
// case "DecimalType" => statement.setBigDecimal(i, ) //TODO: try to use getAccessor and find a similar method in statement
case "StringType" => ps.setString(i+1, row.getString(i))
case "BinaryType" => ps.setBytes(i+1, row.getBinary(i))
case "BooleanType" => ps.setBoolean(i+1, row.getBoolean(i))
// case "TimestamType" => statement.setTimestamp(i+1, row.get.getTimestamp(i))
// case "DateType" => statement.setDate(i+1, row.getDate(i))
}
i += 1
queryExecution.toRdd.foreachPartition(iter => {
val ps = connection.prepareStatement(sql)
iter.foreach(row => {
val pid = TaskContext.get().partitionId()
if(ps != null) {
try{
var i = 0
for(e <- schemaDatatype) {
val testVar = row.getString(i)
println("Value: " + testVar + " ; i: " + i)
e match {
case "ByteType" => ps.setByte(i+1, row.getByte(i))
case "ShortType" => ps.setShort(i+1,row.getShort(i))
case "IntegerType" => ps.setInt(i+1, row.getInt(i))
case "LongType" => ps.setLong(i+1, row.getLong(i))
case "FloatType" => ps.setFloat(i+1, row.getFloat(i))
case "DoubleType" => ps.setDouble(i+1, row.getDouble(i))
// case "DecimalType" => statement.setBigDecimal(i, ) //TODO: try to use getAccessor and find a similar method in statement
case "StringType" => ps.setString(i+1, row.getString(i))
case "BinaryType" => ps.setBytes(i+1, row.getBinary(i))
case "BooleanType" => ps.setBoolean(i+1, row.getBoolean(i))
// case "TimestamType" => statement.setTimestamp(i+1, row.get.getTimestamp(i))
// case "DateType" => statement.setDate(i+1, row.getDate(i))
}
ps.execute()
} catch {
case e: SQLException => log.error("Error writing to SQL DB on row: " + row.toString())
throw e //TODO: Give users the option to abort or continue
i += 1
}
ps.setInt(2, pid)
ps.execute()
} catch {
case e: SQLException => log.error("Error writing to SQL DB on row: " + row.toString())
throw e //TODO: Give users the option to abort or continue
}
//streamToDB(sql, row, schemaDatatype)
})
}
//streamToDB(sql, row, schemaDatatype)
})
})
} catch {
case e: Exception =>
log.error("Error writing batch data to SQL DB. Error details: "+ e)
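The hunk above prepares the statement inside foreachPartition and appends a partitionid value to every row, but it still reuses the connection object created on the driver; java.sql.Connection is not serializable, so that only works in local mode. A minimal sketch of the per-partition pattern, assuming a single string column named input (as in the test below), a partitionid column, and placeholder url/user/password/table values:

import java.sql.DriverManager
import org.apache.spark.TaskContext
import org.apache.spark.sql.DataFrame

// Illustrative only: write each partition through its own JDBC connection,
// tagging every row with its partition id. Column layout and credentials
// are placeholders, not part of this commit.
def writePerPartition(df: DataFrame, url: String, user: String, password: String, table: String): Unit = {
  val insertSql = s"INSERT INTO $table (input, partitionid) VALUES (?, ?);"
  df.rdd.foreachPartition { rows =>
    // Opened on the executor, so nothing non-serializable is captured from the driver.
    val conn = DriverManager.getConnection(url, user, password)
    val ps = conn.prepareStatement(insertSql)
    val pid = TaskContext.get().partitionId()
    try {
      rows.foreach { row =>
        ps.setString(1, row.getString(0)) // bind data columns by type, as in the match above
        ps.setInt(2, pid)                 // the last placeholder carries the partition id
        ps.execute()
      }
    } finally {
      ps.close()
      conn.close()
    }
  }
}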
@@ -182,14 +189,14 @@ private[spark] object SQLWriter extends Logging {
This method checks whether the table exists in SQL. If it doesn't, it creates the table. It also ensures that the data types of the data frame are compatible with those of the Azure SQL database; if they aren't, it converts them and returns the converted data frame.
*/
def streamToDB(
// statement: PreparedStatement,
ps: PreparedStatement,
sql: String,
row: InternalRow,
colDataTypes: ListBuffer[String]
): Unit = {
if(ps == null) {
ps = connection.prepareStatement(sql)
// ps = connection.prepareStatement(sql)
}
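The doc comment on streamToDB above describes checking whether the target table exists and creating it if it doesn't, which this commit does not implement yet. A minimal sketch of that check over plain JDBC, with a hypothetical ensureTableExists helper, a caller-supplied column DDL string, and an unqualified table name (all assumptions, not part of the connector's API):

import java.sql.Connection

// Illustrative only: create the target table if it does not already exist.
// columnsDdl is a placeholder, e.g. "input NVARCHAR(MAX), partitionid INT".
def ensureTableExists(conn: Connection, table: String, columnsDdl: String): Unit = {
  val check = conn.prepareStatement(
    "SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = ?")
  check.setString(1, table) // assumes an unqualified table name
  val rs = check.executeQuery()
  val exists = rs.next()
  rs.close()
  check.close()
  if (!exists) {
    val stmt = conn.createStatement()
    stmt.executeUpdate(s"CREATE TABLE $table ($columnsDdl)")
    stmt.close()
  }
}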

View file

@@ -52,10 +52,9 @@ class SQLSinkTest extends StreamTest with SharedSQLContext {
)
var stream: DataStreamWriter[Row] = null
val input = MemoryStream[String]
input.addData("1", "2", "3", "4")
input.addData("1", "2", "3", "4", "5", "6", "7", "8")
var df = input.toDF().withColumnRenamed("value", "input")
//val df = createReader("/sqltestdata")
//TODO: Create Util functions to create SQL DB, table and truncate it and call the functions here
withTempDir { checkpointDir =>
config += ("checkpointLocation" -> checkpointDir.getCanonicalPath) //Adding Checkpoint Location
@@ -68,18 +67,12 @@ class SQLSinkTest extends StreamTest with SharedSQLContext {
try {
failAfter(streamingTimeout){
Thread.sleep(100)
//streamStart.processAllAvailable()
streamStart.processAllAvailable()
}
//var testDF = spark.read.sqlDB(Config(config)).as[String].select("value").map(_.toInt)
checkDatasetUnorderly(spark.read.sqlDB(Config(config)).select($"input").as[String].map(_.toInt), 1, 2, 3, 4)
checkDatasetUnorderly(spark.read.sqlDB(Config(config)).select($"input").as[String].map(_.toInt), 1, 2, 3, 4, 5, 6, 7, 8)
} finally {
streamStart.stop()
}
//val df: DataFrame =
//TODO: Read data from SQL DB and check if it matches the data written (see EH implementation)
}
/*private def createWriter(inputDF: DataFrame, sqlConfig: AzureSQLConfig, withOutputMode: Option[OutputMode]): StreamingQuery = {